1 /*
2  * Copyright (c) 2023 Shenzhen Kaihong Digital Industry Development Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "tcp_server.h"
17 #include <arpa/inet.h>
18 #include "common/media_log.h"
19 #include "network/session/tcp_session.h"
20 #include "network/socket/socket_utils.h"
21 #include "network/socket/tcp_socket.h"
22 #include "utils/utils.h"
23 namespace OHOS {
24 namespace Sharing {
~TcpServer()25 TcpServer::~TcpServer()
26 {
27     SHARING_LOGD("trace.");
28     Stop();
29 }
30 
TcpServer()31 TcpServer::TcpServer()
32 {
33     SHARING_LOGD("trace.");
34 }
35 
Start(uint16_t port, const std::string &host, bool enableReuse, uint32_t backlog)36 bool TcpServer::Start(uint16_t port, const std::string &host, bool enableReuse, uint32_t backlog)
37 {
38     SHARING_LOGD("server ip:%{public}s, Port:%{public}d, thread_id: %{public}llu.", GetAnonyString(host).c_str(), port,
39                  GetThreadId());
40     socket_ = std::make_unique<TcpSocket>();
41     if (socket_) {
42         if (socket_->Bind(port, host, enableReuse, backlog)) {
43             SHARING_LOGD("start success fd: %{public}d.", socket_->GetLocalFd());
44             socketLocalFd_ = socket_->GetLocalFd();
45 
46             auto eventRunner = OHOS::AppExecFwk::EventRunner::Create(true);
47             eventHandler_ = std::make_shared<TcpServerEventHandler>();
48             eventHandler_->SetServer(shared_from_this());
49             eventHandler_->SetEventRunner(eventRunner);
50             eventRunner->Run();
51 
52             eventListener_ = std::make_shared<TcpServerEventListener>();
53             eventListener_->SetServer(shared_from_this());
54 
55             return eventListener_->AddFdListener(socket_->GetLocalFd(), eventListener_, eventHandler_);
56         }
57     }
58 
59     SHARING_LOGE("start failed!");
60     return false;
61 }
62 
Stop()63 void TcpServer::Stop()
64 {
65     SHARING_LOGD("trace.");
66     std::unique_lock<std::shared_mutex> lk(mutex_);
67     if (socket_ != nullptr) {
68         if (eventListener_) {
69             eventListener_->RemoveFdListener(socket_->GetLocalFd());
70         }
71         SocketUtils::ShutDownSocket(socket_->GetLocalFd());
72         SocketUtils::CloseSocket(socket_->GetLocalFd());
73         for (auto it = sessions_.begin(); it != sessions_.end(); it++) {
74             SHARING_LOGD("closeClientSocket:erase fd: %{public}d,size: %{public}zu.", it->first, sessions_.size());
75             if (it->second) {
76                 it->second->Shutdown();
77                 it->second.reset();
78             }
79         }
80         sessions_.clear();
81         socket_.reset();
82     }
83 }
84 
CloseClientSocket(int32_t fd)85 void TcpServer::CloseClientSocket(int32_t fd)
86 {
87     SHARING_LOGD("fd: %{public}d.", fd);
88     if (fd > 0) {
89         auto itemItr = sessions_.find(fd);
90         if (itemItr != sessions_.end()) {
91             if (itemItr->second) {
92                 itemItr->second->Shutdown();
93                 itemItr->second.reset();
94             }
95             sessions_.erase(itemItr);
96             SHARING_LOGD("erase fd: %{public}d.", fd);
97         }
98     }
99 }
100 
GetSocketInfo()101 SocketInfo::Ptr TcpServer::GetSocketInfo()
102 {
103     SHARING_LOGD("trace.");
104     return socket_;
105 }
106 
OnServerReadable(int32_t fd)107 void TcpServer::OnServerReadable(int32_t fd)
108 {
109     SHARING_LOGD("fd: %{public}d, socketLocalFd: %{public}d, thread_id: %{public}llu.", fd, socketLocalFd_,
110                  GetThreadId());
111     std::unique_lock<std::shared_mutex> lk(mutex_);
112     struct sockaddr_in clientAddr;
113     socklen_t addrLen = sizeof(sockaddr_in);
114     if (fd == socketLocalFd_) {
115         int32_t clientFd = SocketUtils::AcceptSocket(fd, &clientAddr, &addrLen);
116         if (clientFd < 0) {
117             SHARING_LOGE("onReadable accept client error!");
118             return;
119         }
120         SocketUtils::SetNonBlocking(clientFd);
121         SocketUtils::SetNoDelay(clientFd, true);
122         SocketUtils::SetSendBuf(clientFd);
123         SocketUtils::SetRecvBuf(clientFd);
124         SocketUtils::SetCloseWait(clientFd);
125         SocketUtils::SetCloExec(clientFd, true);
126         SocketUtils::SetKeepAlive(clientFd);
127         SHARING_LOGD("onReadable accept client fd: %{public}d.", clientFd);
128         if (socket_) {
129             socket_->socketPeerFd_ = clientFd;
130 
131             std::string strLocalAddr = "";
132             std::string strRemoteAddr = "";
133             uint16_t localPort = 0;
134             uint16_t remotePort = 0;
135             SocketUtils::GetIpPortInfo(clientFd, strLocalAddr, strRemoteAddr, localPort, remotePort);
136 
137             SocketInfo::Ptr socketInfo =
138                 std::make_shared<SocketInfo>(strLocalAddr, strRemoteAddr, fd, clientFd, localPort, remotePort);
139             if (socketInfo) {
140                 socketInfo->SetSocketType(SOCKET_TYPE_TCP);
141                 BaseNetworkSession::Ptr session = std::make_shared<TcpSession>(std::move(socketInfo));
142                 if (session) {
143                     MEDIA_LOGE("[TcpServer] OnReadable new session start.");
144                     sessions_.insert(make_pair(clientFd, std::move(session)));
145                     auto callback = callback_.lock();
146                     if (callback) {
147                         callback->OnAccept(sessions_[clientFd]);
148                     }
149                 } else {
150                     MEDIA_LOGE("onReadable create session failed!");
151                 }
152             } else {
153                 MEDIA_LOGE("onReadable create SocketInfo failed!");
154             }
155         }
156     } else {
157         MEDIA_LOGD("onReadable receive msg!");
158     }
159 }
160 } // namespace Sharing
161 } // namespace OHOS