1 /*
2  * Copyright (c) 2023 Shenzhen Kaihong Digital Industry Development Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "socket_utils.h"
17 #include <arpa/inet.h>
18 #include <fcntl.h>
19 #include <iostream>
20 #include <netinet/tcp.h>
21 #include <unistd.h>
22 #include "common/common_macro.h"
23 #include "common/const_def.h"
24 #include "common/media_log.h"
25 #include "utils.h"
26 
27 namespace OHOS {
28 namespace Sharing {
29 uint16_t SocketUtils::minPort_ = MIN_PORT;
30 uint16_t SocketUtils::maxPort_ = MAX_PORT;
31 
CreateTcpServer(const char *ip, unsigned port, int32_t &fd)32 bool SocketUtils::CreateTcpServer(const char *ip, unsigned port, int32_t &fd)
33 {
34     SHARING_LOGD("trace.");
35     RETURN_FALSE_IF_NULL(ip);
36     return CreateSocket(SOCK_STREAM, fd) && SetReuseAddr(fd, true) && SetNoDelay(fd, true) &&
37            BindSocket(fd, ip, port) && ListenSocket(fd);
38 }
39 
GetAvailableUdpPortPair()40 uint16_t SocketUtils::GetAvailableUdpPortPair()
41 {
42     SHARING_LOGD("trace.");
43     static uint16_t gAvailablePort = minPort_;
44 
45     SHARING_LOGD("current udp port: %{public}d.", gAvailablePort);
46     if (gAvailablePort >= maxPort_) {
47         gAvailablePort = minPort_;
48     }
49 
50     uint16_t port = GetAvailableUdpPortPair(gAvailablePort, maxPort_);
51     if (port != 0) {
52         gAvailablePort = port + 2; // 2: pair port
53         return port;
54     }
55 
56     port = GetAvailableUdpPortPair(minPort_, gAvailablePort);
57     if (port != 0) {
58         gAvailablePort = port + 2; // 2: pair port
59     }
60 
61     return port;
62 }
63 
GetAvailableUdpPortPair(uint16_t minPort, uint16_t maxPort)64 uint16_t SocketUtils::GetAvailableUdpPortPair(uint16_t minPort, uint16_t maxPort)
65 {
66     SHARING_LOGD("trace.");
67     if (minPort == maxPort) {
68         return 0;
69     }
70 
71     uint16_t port = minPort;
72     bool portAvalaible = false;
73     while (!portAvalaible) {
74         if (port >= maxPort) {
75             port = 0;
76             portAvalaible = true;
77         } else if (IsUdpPortAvailable(port) && IsUdpPortAvailable(port + 1)) {
78             portAvalaible = true;
79         } else {
80             port += 2; // 2: pair port
81         }
82     }
83 
84     return port;
85 }
86 
IsUdpPortAvailable(uint16_t port)87 bool SocketUtils::IsUdpPortAvailable(uint16_t port)
88 {
89     SHARING_LOGD("trace.");
90     int32_t fd = -1;
91     auto ret = CreateSocket(SOCK_DGRAM, fd) && BindSocket(fd, "", port);
92     if (fd != -1) {
93         CloseSocket(fd);
94     }
95 
96     return ret;
97 }
98 
CreateTcpClient(const char *ip, unsigned port, int32_t &fd, int32_t &ret)99 bool SocketUtils::CreateTcpClient(const char *ip, unsigned port, int32_t &fd, int32_t &ret)
100 {
101     SHARING_LOGD("trace.");
102     RETURN_FALSE_IF_NULL(ip);
103     return CreateSocket(SOCK_STREAM, fd) && ConnectSocket(fd, true, ip, port, ret);
104 }
105 
CreateUdpSession(unsigned port, int32_t &fd)106 bool SocketUtils::CreateUdpSession(unsigned port, int32_t &fd)
107 {
108     SHARING_LOGD("trace.");
109     return CreateSocket(SOCK_DGRAM, fd) && SetRecvBuf(fd) && SetSendBuf(fd) && BindSocket(fd, "", port);
110 }
111 
CreateSocket(int32_t socketType, int32_t &fd)112 bool SocketUtils::CreateSocket(int32_t socketType, int32_t &fd)
113 {
114     SHARING_LOGD("trace.");
115     fd = -1;
116     if (socketType != SOCK_STREAM && socketType != SOCK_DGRAM) {
117         SHARING_LOGE("type error: %{public}d!", socketType);
118     }
119 
120     fd = socket(AF_INET, socketType, (socketType == SOCK_STREAM ? IPPROTO_TCP : IPPROTO_UDP));
121     if (fd < 0) {
122         SHARING_LOGE("error: %{public}s!", strerror(errno));
123         return false;
124     }
125     SHARING_LOGD("success fd: %{public}d.", fd);
126     return true;
127 }
128 
BindSocket(int32_t fd, const std::string &host, uint16_t port)129 bool SocketUtils::BindSocket(int32_t fd, const std::string &host, uint16_t port)
130 {
131     SHARING_LOGD("trace.");
132     struct sockaddr_in addr = {};
133     addr.sin_family = AF_INET;
134     addr.sin_port = htons(port);
135     if (host == "" || host == "::") {
136         addr.sin_addr.s_addr = INADDR_ANY;
137     } else {
138         if (inet_pton(AF_INET, host.c_str(), &addr.sin_addr) <= 0) {
139             SHARING_LOGE("error: %{public}s!", strerror(errno));
140             return false;
141         }
142     }
143     if (::bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
144         SHARING_LOGE("error: %{public}s!", strerror(errno));
145         return false;
146     }
147 
148     return true;
149 }
150 
ListenSocket(int32_t fd, uint32_t backlog)151 bool SocketUtils::ListenSocket(int32_t fd, uint32_t backlog)
152 {
153     SHARING_LOGD("trace.");
154     if (::listen(fd, backlog) == -1) {
155         SHARING_LOGE("error: %{public}s!", strerror(errno));
156         return false;
157     }
158 
159     return true;
160 }
161 
ConnectSocket(int32_t fd, bool isAsync, const std::string &ip, uint16_t port, int32_t &ret)162 bool SocketUtils::ConnectSocket(int32_t fd, bool isAsync, const std::string &ip, uint16_t port, int32_t &ret)
163 {
164     SHARING_LOGD("trace.");
165     if (ip == "") {
166         SHARING_LOGE("ip null!");
167         return false;
168     }
169 
170     struct sockaddr_in serverAddr = {};
171     serverAddr.sin_family = AF_INET;
172     serverAddr.sin_port = htons(port);
173     if (inet_pton(AF_INET, ip.c_str(), &serverAddr.sin_addr) <= 0) {
174         SHARING_LOGE("inet_pton ip error: %{public}s!", strerror(errno));
175         return false;
176     }
177 
178     int32_t res = ::connect(fd, (struct sockaddr *)&serverAddr, sizeof(serverAddr));
179     ret = res;
180     if (isAsync) {
181         if (res == 0) {
182             SHARING_LOGI("connect immediately.");
183             return true;
184         } else {
185             if (errno == EINPROGRESS) {
186                 SHARING_LOGI("connecting.");
187                 return SocketUtils::CheckAsyncConnect(fd);
188             } else {
189                 return false;
190             }
191         }
192     } else {
193         if (res == 0) {
194             return true;
195         } else {
196             return false;
197         }
198     }
199 }
200 
CheckAsyncConnect(int32_t fd)201 bool SocketUtils::CheckAsyncConnect(int32_t fd)
202 {
203     SHARING_LOGD("trace.");
204     struct timeval timeout;
205     timeout.tv_sec = 2;           // 2: wait +2 second
206     timeout.tv_usec = 500 * 1000; // 500 * 1000: wait +0.5 second
207 
208     fd_set fdr;
209     fd_set fdw;
210     FD_ZERO(&fdr);
211     FD_ZERO(&fdw);
212     FD_SET(fd, &fdr);
213     FD_SET(fd, &fdw);
214 
215     int32_t rc = select(fd + 1, &fdr, &fdw, nullptr, &timeout);
216     if (rc == 1 && FD_ISSET(fd, &fdw)) {
217         SHARING_LOGI("async connect success\n");
218         return true;
219     }
220 
221     if (rc == 0) {
222         SHARING_LOGE("async connect timeout.");
223     }
224 
225     if ((rc < 0) || (rc == 2)) { // 2: select error
226         SHARING_LOGE("async connect error: %{public}s!", strerror(errno));
227     }
228 
229     return false;
230 }
231 
ShutDownSocket(int32_t fd)232 void SocketUtils::ShutDownSocket(int32_t fd)
233 {
234     SHARING_LOGD("trace.");
235     if (fd >= 0) {
236         SHARING_LOGD("shutdown fd: %{public}d.", fd);
237         shutdown(fd, SHUT_RDWR);
238     }
239 }
240 
SetReuseAddr(int32_t fd, bool isReuse)241 bool SocketUtils::SetReuseAddr(int32_t fd, bool isReuse)
242 {
243     SHARING_LOGD("trace.");
244     int32_t on = isReuse ? 1 : 0;
245     if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != 0) {
246         SHARING_LOGE("error: %{public}s!", strerror(errno));
247         return false;
248     }
249 
250     return true;
251 }
252 
SetReusePort(int32_t fd, bool isReuse)253 bool SocketUtils::SetReusePort(int32_t fd, bool isReuse)
254 {
255     SHARING_LOGD("trace.");
256     int32_t on = isReuse ? 1 : 0;
257     if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on)) != 0) {
258         SHARING_LOGE("error: %{public}s!", strerror(errno));
259         return false;
260     }
261 
262     return true;
263 }
264 
SetCloseWait(int32_t fd, int32_t second)265 bool SocketUtils::SetCloseWait(int32_t fd, int32_t second)
266 {
267     SHARING_LOGD("trace.");
268     linger sLinger;
269     sLinger.l_onoff = (second > 0);
270     sLinger.l_linger = second;
271     if (setsockopt(fd, SOL_SOCKET, SO_LINGER, &sLinger, sizeof(linger)) == -1) {
272         SHARING_LOGE("error: %{public}s!", strerror(errno));
273         return false;
274     }
275 
276     return true;
277 }
278 
SetCloExec(int32_t fd, bool isOn)279 bool SocketUtils::SetCloExec(int32_t fd, bool isOn)
280 {
281     SHARING_LOGD("trace.");
282 
283     int32_t flags = fcntl(fd, F_GETFD);
284     if (flags == -1) {
285         SHARING_LOGE("fcntl error: %{public}s!", strerror(errno));
286         return false;
287     }
288 
289     if (isOn) {
290         flags |= FD_CLOEXEC;
291     } else {
292         int32_t cloexec = FD_CLOEXEC;
293         flags &= ~cloexec;
294     }
295 
296     if (fcntl(fd, F_SETFD, flags) == -1) {
297         SHARING_LOGE("error: %{public}s!", strerror(errno));
298         return false;
299     }
300 
301     return true;
302 }
303 
SetNoDelay(int32_t fd, bool isOn)304 bool SocketUtils::SetNoDelay(int32_t fd, bool isOn)
305 {
306     SHARING_LOGD("trace.");
307     int32_t on = isOn;
308     if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) != 0) {
309         SHARING_LOGE("error: %{public}s!", strerror(errno));
310         return false;
311     }
312 
313     return true;
314 }
315 
SetKeepAlive(int32_t sockfd)316 void SocketUtils::SetKeepAlive(int32_t sockfd)
317 {
318     SHARING_LOGD("trace.");
319     int32_t on = 1;
320     setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&on), sizeof(on));
321 }
322 
SetNonBlocking(int32_t fd, bool isNonBlock, uint32_t writeTimeout)323 bool SocketUtils::SetNonBlocking(int32_t fd, bool isNonBlock, uint32_t writeTimeout)
324 {
325     SHARING_LOGD("trace.");
326     int32_t flags = -1;
327     if (isNonBlock) {
328         flags = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) | O_NONBLOCK);
329     } else {
330         flags = fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) & ~O_NONBLOCK);
331         if (writeTimeout > 0) {
332             struct timeval tv = {writeTimeout / 1000, (writeTimeout % 1000) * 1000};
333             setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<char *>(&tv), sizeof tv);
334         }
335     }
336 
337     return (flags != -1) ? true : false;
338 }
339 
SetSendBuf(int32_t fd, int32_t size)340 bool SocketUtils::SetSendBuf(int32_t fd, int32_t size)
341 {
342     SHARING_LOGD("trace.");
343     if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) != 0) {
344         SHARING_LOGE("error: %{public}s!", strerror(errno));
345         return false;
346     }
347 
348     return true;
349 }
350 
SetRecvBuf(int32_t fd, int32_t size)351 bool SocketUtils::SetRecvBuf(int32_t fd, int32_t size)
352 {
353     SHARING_LOGD("trace.");
354     if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) != 0) {
355         SHARING_LOGE("error: %{public}s!", strerror(errno));
356         return false;
357     }
358 
359     return true;
360 }
361 
CloseSocket(int32_t fd)362 void SocketUtils::CloseSocket(int32_t fd)
363 {
364     SHARING_LOGD("trace.");
365     if (fd >= 0) {
366         SHARING_LOGD("close fd: %{public}d.", fd);
367         close(fd);
368     }
369 }
370 
SendSocket(int32_t fd, const char *buf, int32_t len)371 int32_t SocketUtils::SendSocket(int32_t fd, const char *buf, int32_t len)
372 {
373     SHARING_LOGD("trace.");
374     if (fd < 0 || buf == nullptr || len == 0) {
375         return -1;
376     }
377     SHARING_LOGD("sendSocket: \r\n%{public}s.", buf);
378     int32_t bytes = 0;
379     bool sending = true;
380     while (sending) {
381         if (bytes >= len || bytes < 0) {
382             sending = false;
383             break;
384         }
385 
386         int32_t retCode = send(fd, &buf[bytes], len - bytes, 0);
387         if ((retCode < 0) && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) {
388             SHARING_LOGD("sendSocket: continue.");
389             continue;
390         } else if (retCode > 0) {
391             bytes += retCode;
392             if (bytes == len) {
393                 sending = false;
394                 break;
395             }
396         } else {
397             SHARING_LOGE("error: %{public}s!", strerror(errno));
398             bytes = 0;
399             sending = false;
400             break;
401         }
402     }
403 
404     SHARING_LOGD("finish! fd: %{public}d size: %{public}d.", fd, bytes);
405     return bytes;
406 }
407 
Sendto(int32_t fd, const char *buf, size_t len, const char *ip, int32_t nPort)408 int32_t SocketUtils::Sendto(int32_t fd, const char *buf, size_t len, const char *ip, int32_t nPort)
409 {
410     RETURN_INVALID_IF_NULL(buf);
411     RETURN_INVALID_IF_NULL(ip);
412     SHARING_LOGD("trace: \r\n%{public}s.", buf);
413     struct sockaddr_in addr = {};
414     addr.sin_family = AF_INET;
415     addr.sin_port = htons(nPort);
416     if (inet_pton(AF_INET, ip, &addr.sin_addr) <= 0) {
417         SHARING_LOGE("inet_pton error: %{public}s!", strerror(errno));
418         return -1;
419     }
420     int32_t retCode = sendto(fd, buf, len, 0, (struct sockaddr *)&addr, sizeof(addr));
421     if (retCode < 0 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) {
422     } else if (retCode > 0) {
423     } else {
424         SHARING_LOGE("sendto error: %{public}s!", strerror(errno));
425         retCode = 0;
426     }
427 
428     return retCode;
429 }
430 
ReadSocket(int32_t fd, char *buf, uint32_t len, int32_t &error)431 int32_t SocketUtils::ReadSocket(int32_t fd, char *buf, uint32_t len, int32_t &error)
432 {
433     SHARING_LOGD("trace.");
434     if (fd < 0 || buf == nullptr || len == 0) {
435         SHARING_LOGE("invalid param!");
436         return -1;
437     }
438 
439     int32_t retCode = read(fd, buf, len);
440     error = errno;
441     if (retCode < 0 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) {
442     } else if (retCode > 0) {
443     } else {
444         SHARING_LOGE("error: %{public}s!", strerror(errno));
445         retCode = 0;
446     }
447 
448     return retCode;
449 }
450 
ReadSocket(int32_t fd, DataBuffer::Ptr buf, int32_t &error)451 int32_t SocketUtils::ReadSocket(int32_t fd, DataBuffer::Ptr buf, int32_t &error)
452 {
453     SHARING_LOGD("trace.");
454     if (fd < 0 || !buf) {
455         SHARING_LOGE("readSocket:invalid param!");
456         return -1;
457     }
458 
459     auto size = buf->Capacity() - buf->Size();
460     if (size < READ_BUF_SIZE) {
461         uint32_t bufferReaderSize = buf->Size();
462         if (bufferReaderSize > MAX_READ_BUF_SIZE) {
463             SHARING_LOGE("error data size!");
464             return -1;
465         }
466         buf->Resize(bufferReaderSize + READ_BUF_SIZE);
467     }
468 
469     int32_t bytesRead = read(fd, buf->Data() + buf->Size(), READ_BUF_SIZE);
470     error = errno;
471     if (bytesRead > 0) {
472         buf->UpdateSize(buf->Size() + bytesRead);
473     } else if (bytesRead < 0 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) {
474     } else {
475         bytesRead = 0;
476         SHARING_LOGE("error: %{public}s!", strerror(errno));
477     }
478 
479     return bytesRead;
480 }
481 
RecvSocket(int32_t fd, char *buf, uint32_t len, int32_t flags, int32_t &error)482 int32_t SocketUtils::RecvSocket(int32_t fd, char *buf, uint32_t len, int32_t flags, int32_t &error)
483 {
484     SHARING_LOGD("trace.");
485     if (fd < 0 || buf == nullptr || len == 0) {
486         SHARING_LOGE("invalid param.");
487         return -1;
488     }
489 
490     int32_t retCode = recv(fd, buf, len, flags);
491     error = errno;
492     if (retCode < 0 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) {
493     } else if (retCode > 0) {
494     } else {
495         SHARING_LOGE("error: %{public}s!", strerror(errno));
496         retCode = 0;
497     }
498 
499     return retCode;
500 }
501 
AcceptSocket(int32_t fd, struct sockaddr_in *clientAddr, socklen_t *addrLen)502 int32_t SocketUtils::AcceptSocket(int32_t fd, struct sockaddr_in *clientAddr, socklen_t *addrLen)
503 {
504     SHARING_LOGD("trace.");
505     RETURN_INVALID_IF_NULL(clientAddr);
506     RETURN_INVALID_IF_NULL(addrLen);
507     int32_t clientFd = accept(fd, reinterpret_cast<struct sockaddr *>(clientAddr), addrLen);
508     if (clientFd < 0) {
509         SHARING_LOGE("accept error: %{public}s!", strerror(errno));
510     }
511 
512     return clientFd;
513 }
514 
GetIpPortInfo(int32_t fd, std::string &strLocalAddr, std::string &strRemoteAddr, uint16_t &localPort, uint16_t &remotePort)515 bool SocketUtils::GetIpPortInfo(int32_t fd, std::string &strLocalAddr, std::string &strRemoteAddr, uint16_t &localPort,
516                                 uint16_t &remotePort)
517 {
518     SHARING_LOGD("trace.");
519     struct sockaddr_in localAddr;
520     socklen_t localAddrLen = sizeof(localAddr);
521     if (-1 == getsockname(fd, (struct sockaddr *)&localAddr, &localAddrLen)) {
522         SHARING_LOGE("getsockname error: %{public}s!", strerror(errno));
523         return false;
524     }
525 
526     struct sockaddr_in remoteAddr;
527     socklen_t remoteAddrLen = sizeof(remoteAddr);
528     if (-1 == getpeername(fd, (struct sockaddr *)&remoteAddr, &remoteAddrLen)) {
529         SHARING_LOGE("getpeername error: %{public}s!", strerror(errno));
530         return false;
531     }
532 
533     strLocalAddr = inet_ntoa(localAddr.sin_addr);
534     strRemoteAddr = inet_ntoa(remoteAddr.sin_addr);
535 
536     localPort = ntohs(localAddr.sin_port);
537     remotePort = ntohs(remoteAddr.sin_port);
538     SHARING_LOGD("localAddr: %{public}s localPort: %{public}d remoteAddr: %{public}s remotePort: %{public}d",
539                  GetAnonyString(strLocalAddr).c_str(), localPort, GetAnonyString(strRemoteAddr).c_str(), remotePort);
540     return true;
541 }
542 } // namespace Sharing
543 } // namespace OHOS