1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <linux/in.h>
#include <linux/socket.h>
#include <linux/tcp.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/un.h>

#include "loop_systest.h"
#include "securec.h"
#include "loop_event.h"
#include "le_socket.h"
#include "le_task.h"
#include "list.h"
38 
#define RETRY_TIME (200 * 1000)     // delay between connect/send retries in microseconds (passed to usleep): 200 ms
#define MAX_RETRY_SEND_COUNT 2      // initial ReqMsgMgr.maxRetryCount: number of connect/send attempts
41 
42 typedef Agent_ {
43     TaskHandle task;
44     WatcherHandle input;
45     WatcherHandle reader;
46     int ptyfd;
47 } Agent;
48 
// One chunk of message data; chunks are chained together through `node`
// (HandleMsgSend walks ReqMsgNode.msgBlocks this way).
typedef struct {
    struct ListNode node;
    uint32_t blockSize;     // capacity of `buffer` in bytes
    uint32_t currentIndex;  // number of bytes of `buffer` filled so far
    // Zero-length trailing array (GNU extension). Its storage lies past the end of the
    // struct; for ReqMsgMgr.recvBlock it is the RECV_BLOCK_LEN extra bytes allocated by
    // InitClientInstance. Kept as [0] rather than a C99 flexible array member because
    // MsgBlock is embedded as a member of ReqMsgMgr, which ISO C forbids for FAM structs.
    uint8_t buffer[0];
} MsgBlock;
55 
// Client-side request manager: connection state, retry policy and receive buffer.
typedef struct {
    uint32_t maxRetryCount;  // attempts made by TryCreateSocket and ClientSendMsg
    uint32_t timeout;        // send/receive timeout in seconds (used as tv_sec in CreateClientSocket)
    uint32_t msgNextId;      // id given to the next request whose msgId is 0; reset to 1 after a failed attempt
    int socketId;            // connected socket fd, or -1 when not connected
    pthread_mutex_t mutex;   // initialised in InitClientInstance, destroyed in ClientDestroy;
                             // NOTE(review): not locked by any code visible in this file
    MsgBlock recvBlock;      // message receive buffer; must stay the LAST member because
                             // recvBlock.buffer extends into the bytes allocated past the struct
} ReqMsgMgr;
64 
// Guards lazy creation of g_clientInstance in InitClientInstance (ClientDestroy also takes it).
static pthread_mutex_t g_mutex = PTHREAD_MUTEX_INITIALIZER;
// Shared client instance; NULL until InitClientInstance allocates it.
static ReqMsgMgr *g_clientInstance = NULL;
67 
CreateAgent(const char *server, int flags)68 Agent *CreateAgent(const char *server, int flags)
69 {
70     if (server == NULL) {
71         printf("Invalid parameter \n");
72         return NULL;
73     }
74 
75     TaskHandle task = NULL;
76     LE_StreamInfo info = {};
77     info.baseInfo.flags = flags;
78     info.server = (char *)server;
79     info.baseInfo.userDataSize = sizeof(Agent);
80     info.baseInfo.close = OnClose;
81     info.disConnectComplete = DisConnectComplete;
82     info.connectComplete = OnConnectComplete;
83     info.sendMessageComplete = OnSendMessageComplete;
84     info.recvMessage = ClientOnRecvMessage;
85 
86     LE_STATUS status = LE_CreateStreamClient(LE_GetDefaultLoop(), &task, &info);
87     if (status != 0) {
88         printf("Failed create client \n");
89         return NULL;
90     }
91     Agent *agent = (Agent *)LE_GetUserData(task);
92     if (agent == NULL) {
93         printf("Invalid agent \n");
94         return NULL;
95     }
96 
97     agent->task = task;
98     return agent;
99 }
100 
InitClientInstancenull101 static int InitClientInstance()
102 {
103     pthread_mutex_lock(&g_mutex);
104     if (g_clientInstance != NULL) {
105         pthread_mutex_unlock(&g_mutex);
106         return 0;
107     }
108     ReqMsgMgr *clientInstance = malloc(sizeof(ReqMsgMgr) + RECV_BLOCK_LEN);
109     if (clientInstance == NULL) {
110         pthread_mutex_unlock(&g_mutex);
111         return -1;
112     }
113     // init
114     clientInstance->msgNextId = 1;
115     clientInstance->timeout = GetDefaultTimeout(TIMEOUT_DEF);
116     clientInstance->maxRetryCount = MAX_RETRY_SEND_COUNT;
117     clientInstance->socketId = -1;
118     pthread_mutex_init(&clientInstance->mutex, NULL);
119     // init recvBlock
120     OH_ListInit(&clientInstance->recvBlock.node);
121     clientInstance->recvBlock.blockSize = RECV_BLOCK_LEN;
122     clientInstance->recvBlock.currentIndex = 0;
123     g_clientInstance = clientInstance;
124     pthread_mutex_unlock(&g_mutex);
125     return 0;
126 }
127 
ClientInit(const char *socketPath, int flags)128 void ClientInit(const char *socketPath, int flags)
129 {
130     if (socketPath == NULL) {
131         printf("Invalid parameter \n");
132     }
133     printf("AgentInit \n");
134     Agent *agent = CreateAgent(socketPath, flags);
135     if (agent == NULL) {
136         printf("Failed to create agent \n");
137         return;
138     }
139 
140     printf(" Client exit \n");
141 }
142 
// CloseClientSocket is defined after ClientDestroy. Calling it with no prior
// declaration is an implicit declaration with external linkage, which conflicts
// with its later `static` definition, so declare it here.
static void CloseClientSocket(int socketId);

/*
 * Release a ReqMsgMgr. If it is the shared g_clientInstance, the global
 * pointer is cleared first. Then its mutex is destroyed, its socket (if
 * any) is closed and its memory freed. reqMgr must not be used afterwards.
 *
 * @return 0 on success, -1 if reqMgr is NULL.
 */
int ClientDestroy(ReqMsgMgr *reqMgr)
{
    if (reqMgr == NULL) {
        printf("Invalid reqMgr \n");
        return -1;
    }

    // Detach from the global under g_mutex. Otherwise g_clientInstance dangles
    // after free() and InitClientInstance() would return a freed instance as
    // "already initialised".
    pthread_mutex_lock(&g_mutex);
    if (g_clientInstance == reqMgr) {
        g_clientInstance = NULL;
    }
    pthread_mutex_unlock(&g_mutex);

    pthread_mutex_destroy(&reqMgr->mutex);
    if (reqMgr->socketId >= 0) {
        CloseClientSocket(reqMgr->socketId);
        reqMgr->socketId = -1;
    }
    free(reqMgr);
    return 0;
}
160 
/*
 * Log and close a client socket. Negative descriptors are only logged.
 *
 * Before closing, TCP_NODELAY is cleared; the result of that setsockopt
 * call is deliberately ignored.
 */
static void CloseClientSocket(int socketId)
{
    printf("Closed socket with fd %d \n", socketId);
    if (socketId < 0) {
        return;
    }
    int noDelay = 0;
    (void)setsockopt(socketId, IPPROTO_TCP, TCP_NODELAY, (char *)&noDelay, sizeof(noDelay));
    (void)close(socketId);
}
170 
CreateClientSocket(uint32_t timeout)171 static int CreateClientSocket(uint32_t timeout)
172 {
173     const char *socketName = "loopserver";
174     int socketFd = socket(AF_UNIX, SOCK_STREAM, 0);  // SOCK_SEQPACKET
175     if (socketFd < 0) {
176         printf("Socket fd: %d error: %d \n", socketFd, errno);
177     }
178 
179     int ret = 0;
180     do {
181         int flag = 1;
182         ret = setsockopt(socketFd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
183         flag = 1;
184         ret = setsockopt(socketFd, SOL_SOCKET, SO_PASSCRED, &flag, sizeof(flag));
185         if (ret != 0) {
186             printf("Set opt SO_PASSCRED name: %s error: %d \n", socketName, errno);
187             break;
188         }
189         struct timeval timeoutVal = {timeout, 0};
190         ret = setsockopt(socketFd, SOL_SOCKET, SO_SNDTIMEO, &timeoutVal, sizeof(timeoutVal));
191         if (ret != 0) {
192             printf("Set opt SO_SNDTIMEO name: %s error: %d \n", socketName, errno);
193             break;
194         }
195         ret = setsockopt(socketFd, SOL_SOCKET, SO_RCVTIMEO, &timeoutVal, sizeof(timeoutVal));
196         if (ret != 0) {
197             printf("Set opt SO_RCVTIMEO name: %s error: %d \n", socketName, errno);
198             break;
199         }
200         ret = _SYSTEM_ERROR;
201         struct sockaddr_un addr;
202         socklen_t pathSize = sizeof(addr.sun_path);
203         int pathLen = snprintf_s(addr.sun_path, pathSize, (pathSize - 1), "%s%s", _SOCKET_DIR, socketName);
204         if (pathLen <= 0) {
205             printf("Format path: %s error: %d \n", socketName, errno);
206             break;
207         }
208         addr.sun_family = AF_LOCAL;
209         socklen_t socketAddrLen = offsetof(struct sockaddr_un, sun_path) + pathLen + 1;
210         ret = connect(socketFd, (struct sockaddr *)(&addr), socketAddrLen);
211         if (ret != 0) {
212             printf("Failed to connect %s error: %d \n", addr.sun_path, errno);
213             break;
214         }
215         printf("Create socket success %s socketFd: %d", addr.sun_path, socketFd);
216         return socketFd;
217     } while (0);
218     CloseClientSocket(socketFd);
219     return -1;
220 }
221 
/*
 * Make up to reqMgr->maxRetryCount attempts to give reqMgr a connected socket.
 * An existing socket (socketId >= 0) is kept as is. After each failed attempt
 * the function sleeps RETRY_TIME microseconds. On return socketId may still
 * be -1; the caller must check it.
 */
static void TryCreateSocket(ReqMsgMgr *reqMgr)
{
    for (uint32_t attempt = 1; attempt <= reqMgr->maxRetryCount; attempt++) {
        if (reqMgr->socketId < 0) {
            reqMgr->socketId = CreateClientSocket(reqMgr->timeout);
        }
        if (reqMgr->socketId >= 0) {
            return;
        }
        printf("Failed to create socket, try again \n");
        usleep(RETRY_TIME);
    }
}
238 
/*
 * Send `len` bytes from `buf` on stream socket `socketFd` using sendmsg(),
 * optionally passing file descriptors as SCM_RIGHTS ancillary data.
 *
 * Partial writes continue from where the previous sendmsg() stopped, and a
 * sendmsg() interrupted by a signal (EINTR) is retried. The descriptors are
 * attached only until the first successful sendmsg(), so the peer receives
 * them once. When len is 0 nothing is sent, including the descriptors.
 *
 * @param socketFd connected socket to write to.
 * @param buf      data to send (may be NULL only when len is 0).
 * @param len      number of bytes of buf to send; must not be negative.
 * @param fds      descriptors to pass, or NULL.
 * @param fdCount  number of entries in fds, or NULL. Descriptors are attached
 *                 only if fds and fdCount are non-NULL and *fdCount > 0.
 * @return 0 when all len bytes were written; -1 if building the control
 *         message failed; otherwise a negative errno (-EFAULT if sendmsg
 *         made no progress without reporting an error).
 */
static int WriteMessage(int socketFd, const uint8_t *buf, ssize_t len, int *fds, int *fdCount)
{
    if (len < 0 || (len > 0 && buf == NULL)) {
        printf("WriteMessage invalid buffer, len %zd \n", len);
        return -EINVAL;
    }

    struct iovec iov = {
        .iov_base = (void *)buf,
        .iov_len = (size_t)len,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
    };
    char *ctrlBuffer = NULL;
    if (fdCount != NULL && fds != NULL && *fdCount > 0) {
        size_t fdBytes = (size_t)*fdCount * sizeof(int);
        msg.msg_controllen = CMSG_SPACE(fdBytes);
        // calloc: CMSG_SPACE includes alignment padding that would otherwise be sent uninitialised.
        ctrlBuffer = calloc(1, msg.msg_controllen);
        if (ctrlBuffer == NULL) {
            printf("WriteMessage fail to alloc memory for msg_control %zu %d \n",
                (size_t)msg.msg_controllen, errno);
            return -1;
        }
        msg.msg_control = ctrlBuffer;
        struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
        if (cmsg == NULL) {
            free(ctrlBuffer);
            printf("WriteMessage fail to get CMSG_FIRSTHDR %d \n", errno);
            return -1;
        }
        cmsg->cmsg_len = CMSG_LEN(fdBytes);
        cmsg->cmsg_type = SCM_RIGHTS;
        cmsg->cmsg_level = SOL_SOCKET;
        // The data area holds exactly fdBytes; cmsg_len also counts the header,
        // so it is not the destination capacity.
        int ret = memcpy_s(CMSG_DATA(cmsg), fdBytes, fds, fdBytes);
        if (ret != 0) {
            free(ctrlBuffer);
            printf("WriteMessage fail to memcpy_s fd, errno: %d \n", errno);
            return -1;
        }
        printf("build fd info count %d \n", *fdCount);
    }

    ssize_t written = 0;
    while (written < len) {
        // Resume after the bytes already accepted by the kernel.
        iov.iov_base = (void *)(buf + written);
        iov.iov_len = (size_t)(len - written);
        ssize_t wLen = sendmsg(socketFd, &msg, MSG_NOSIGNAL);
        int sendErr = errno;  // saved before printf/free can overwrite it
        if (wLen < 0 && sendErr == EINTR) {
            continue;  // interrupted before anything was sent: retry unchanged
        }
        if (wLen <= 0) {
            free(ctrlBuffer);
            printf("Failed to write message to fd %d, wLen %zd errno: %d \n", socketFd, wLen, sendErr);
            return (wLen < 0 && sendErr != 0) ? -sendErr : -EFAULT;
        }
        written += wLen;
        // The descriptors went out with the first successful chunk; do not resend them.
        msg.msg_control = NULL;
        msg.msg_controllen = 0;
    }
    free(ctrlBuffer);
    return 0;
}
290 
/*
 * Write every MsgBlock queued on reqNode->msgBlocks to `socketId`, in list
 * order. reqNode's fds/fdCount are passed only with the first block; later
 * blocks are sent without descriptors.
 *
 * @return 0 when all blocks were written, otherwise the first non-zero
 *         WriteMessage() result (remaining blocks are not sent).
 * `reqMgr` is not used by this function.
 */
static int HandleMsgSend(ReqMsgMgr *reqMgr, int socketId, ReqMsgNode *reqNode)
{
    printf("HandleMsgSend reqId: %u msgId: %d \n", reqNode->reqId, reqNode->msg->msgId);
    uint32_t bytesSent = 0;
    bool firstBlock = true;
    for (ListNode *cursor = reqNode->msgBlocks.next;
         cursor != NULL && cursor != &reqNode->msgBlocks;
         cursor = cursor->next) {
        MsgBlock *block = (MsgBlock *)ListEntry(cursor, MsgBlock, node);
        int *blockFds = firstBlock ? reqNode->fds : NULL;
        int ret = WriteMessage(socketId, block->buffer, block->currentIndex,
            blockFds, firstBlock ? &reqNode->fdCount : NULL);
        bytesSent += block->currentIndex;
        printf("Write msg ret: %d msgId: %u %u %u \n",
            ret, reqNode->msg->msgId, reqNode->msg->msgLen, bytesSent);
        if (ret != 0) {
            printf("Send msg fail reqId: %u msgId: %d ret: %d \n",
                reqNode->reqId, reqNode->msg->msgId, ret);
            return ret;
        }
        firstBlock = false;
    }
    return 0;
}
316 
/*
 * Read one response from `socketFd` into `buf`. If it answers the request
 * `sendMsgId`, copy its result into `*result`.
 *
 * The original had the msgId check inverted: a matching reply returned
 * TIMEOUT and a mismatched one was accepted.
 *
 * @param socketFd  connected socket to read from.
 * @param sendMsgId msgId of the request being answered.
 * @param buf       receive buffer of `len` bytes.
 * @param len       capacity of buf; must be positive.
 * @param result    destination for the response's result.
 * @return 0 when a matching response was copied into *result; TIMEOUT on
 *         invalid arguments, read failure, a response shorter than
 *         ResponseMsg, or a msgId mismatch; otherwise memcpy_s's error code.
 *
 * NOTE(review): a single read() is assumed to deliver the whole ResponseMsg;
 * a short read is treated as failure rather than continued.
 */
static int ReadMessage(int socketFd, uint32_t sendMsgId, uint8_t *buf, int len, Result *result)
{
    if (buf == NULL || len <= 0 || result == NULL) {
        printf("Invalid read parameter fd %d len %d \n", socketFd, len);
        return TIMEOUT;
    }

    ssize_t rLen;
    do {
        rLen = read(socketFd, buf, (size_t)len);
    } while (rLen < 0 && errno == EINTR);  // a signal is not a receive failure
    if (rLen < 0) {
        printf("Read message from fd %d rLen %zd errno: %d \n", socketFd, rLen, errno);
        return TIMEOUT;
    }

    if ((size_t)rLen < sizeof(ResponseMsg)) {
        return TIMEOUT;
    }
    const ResponseMsg *msg = (const ResponseMsg *)(buf);
    if (sendMsgId != msg->msgHdr.msgId) {
        // A reply that does not answer this request must not be reported as its
        // result; returning non-zero makes ClientSendMsg retry on a new socket.
        printf("Invalid msg recvd %u %u \n", sendMsgId, msg->msgHdr.msgId);
        return TIMEOUT;
    }
    return memcpy_s(result, sizeof(Result), &msg->result, sizeof(msg->result));
}
334 
/*
 * Send `reqNode` and wait for its response, making up to
 * reqMgr->maxRetryCount attempts.
 *
 * Each attempt (re)connects if needed, assigns the next msgId when the
 * message has none, writes all blocks and reads the reply into *result.
 * After a failed attempt the socket is closed, message numbering restarts
 * at 1 and the message's msgId is cleared, followed by a RETRY_TIME sleep.
 *
 * @return 0 on success, TIMEOUT when every attempt failed.
 */
static int ClientSendMsg(ReqMsgMgr *reqMgr, ReqMsgNode *reqNode, Result *result)
{
    for (uint32_t attempt = 1; attempt <= reqMgr->maxRetryCount; attempt++) {
        if (reqMgr->socketId < 0) {
            TryCreateSocket(reqMgr);
            if (reqMgr->socketId < 0) {  // still not connected: back off and retry
                usleep(RETRY_TIME);
                continue;
            }
        }

        if (reqNode->msg->msgId == 0) {
            reqNode->msg->msgId = reqMgr->msgNextId++;
        }
        int status = HandleMsgSend(reqMgr, reqMgr->socketId, reqNode);
        if (status == 0) {
            status = ReadMessage(reqMgr->socketId, reqNode->msg->msgId,
                reqMgr->recvBlock.buffer, reqMgr->recvBlock.blockSize, result);
            if (status == 0) {
                return 0;
            }
        }

        // Drop the connection and restart message numbering before the next attempt.
        CloseClientSocket(reqMgr->socketId);
        reqMgr->socketId = -1;
        reqMgr->msgNextId = 1;
        reqNode->msg->msgId = 0;
        usleep(RETRY_TIME);
    }
    return TIMEOUT;
}
369 
main(int argc, char *const argv[])370 int main(int argc, char *const argv[])
371 {
372     printf("main argc: %d \n", argc);
373     if (argc <= 0) {
374         return 0;
375     }
376 
377     printf("请输入创建socket的类型:(pipe, tcp)\n");
378     char type[128];
379     int ret = scanf_s("%s", type, sizeof(type));
380     if (ret <= 0) {
381         printf("input error \n");
382         return 0;
383     }
384 
385     int flags;
386     char *server;
387     if (strcmp(type, "pipe") == 0) {
388         flags = TASK_STREAM | TASK_PIPE |TASK_SERVER | TASK_TEST;
389         server = (char *)"/data/testpipe";
390     } else if (strcmp(type, "tcp") == 0) {
391         flags = TASK_STREAM | TASK_TCP |TASK_SERVER | TASK_TEST;
392         server = (char *)"127.0.0.1:7777";
393     } else {
394         printf("输入有误,请输入pipe或者tcp!");
395         system("pause");
396         return 0;
397     }
398 
399     uint32_t timeout = 200;
400     int fd = CreateClientSocket(timeout);
401     return 0;
402 }