1/*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <linux/in.h>
#include <linux/socket.h>
#include <linux/tcp.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/un.h>

#include "loop_systest.h"
#include "securec.h"
#include "loop_event.h"
#include "le_socket.h"
#include "le_task.h"
#include "list.h"
38
#define RETRY_TIME (200 * 1000)     // Delay between connect/send retries in microseconds (passed to usleep): 200 ms
#define MAX_RETRY_SEND_COUNT 2      // Default ReqMsgMgr::maxRetryCount: attempts made when connecting and when sending
41
42typedef Agent_ {
43    TaskHandle task;
44    WatcherHandle input;
45    WatcherHandle reader;
46    int ptyfd;
47} Agent;
48
/*
 * A chunk of message bytes that can be linked into a list of blocks.
 * The bytes live in the trailing `buffer`, whose storage is allocated together with the
 * enclosing object (InitClientInstance allocates RECV_BLOCK_LEN extra bytes for
 * ReqMsgMgr::recvBlock).
 */
typedef struct {
    struct ListNode node;   // list link (HandleMsgSend walks ReqMsgNode::msgBlocks through it)
    uint32_t blockSize;     // capacity of buffer, in bytes
    uint32_t currentIndex;  // number of bytes of buffer already filled
    uint8_t buffer[0];      // zero-length trailing array (GNU extension); a C99 flexible array
                            // member is not allowed here because MsgBlock is embedded in ReqMsgMgr
} MsgBlock;
55
/* Client request manager: connection state, retry policy and the response receive buffer. */
typedef struct {
    uint32_t maxRetryCount;  // attempts allowed when connecting / sending (MAX_RETRY_SEND_COUNT by default)
    uint32_t timeout;        // socket send/receive timeout handed to CreateClientSocket (tv_sec, i.e. seconds)
    uint32_t msgNextId;      // next id given to an outgoing message whose msgId is still 0
    int socketId;            // connected socket fd, or -1 when not connected
    pthread_mutex_t mutex;   // initialized in InitClientInstance; NOTE(review): not locked in the visible code
    MsgBlock recvBlock;  // response receive buffer; must remain the LAST member because its trailing
                         // buffer extends into the RECV_BLOCK_LEN bytes allocated past sizeof(ReqMsgMgr)
} ReqMsgMgr;
64
// Serializes lazy creation of g_clientInstance (see InitClientInstance).
static pthread_mutex_t g_mutex = PTHREAD_MUTEX_INITIALIZER;
// Process-wide request manager; NULL until InitClientInstance creates it.
static ReqMsgMgr *g_clientInstance = NULL;
67
68Agent *CreateAgent(const char *server, int flags)
69{
70    if (server == NULL) {
71        printf("Invalid parameter \n");
72        return NULL;
73    }
74
75    TaskHandle task = NULL;
76    LE_StreamInfo info = {};
77    info.baseInfo.flags = flags;
78    info.server = (char *)server;
79    info.baseInfo.userDataSize = sizeof(Agent);
80    info.baseInfo.close = OnClose;
81    info.disConnectComplete = DisConnectComplete;
82    info.connectComplete = OnConnectComplete;
83    info.sendMessageComplete = OnSendMessageComplete;
84    info.recvMessage = ClientOnRecvMessage;
85
86    LE_STATUS status = LE_CreateStreamClient(LE_GetDefaultLoop(), &task, &info);
87    if (status != 0) {
88        printf("Failed create client \n");
89        return NULL;
90    }
91    Agent *agent = (Agent *)LE_GetUserData(task);
92    if (agent == NULL) {
93        printf("Invalid agent \n");
94        return NULL;
95    }
96
97    agent->task = task;
98    return agent;
99}
100
101static int InitClientInstance()
102{
103    pthread_mutex_lock(&g_mutex);
104    if (g_clientInstance != NULL) {
105        pthread_mutex_unlock(&g_mutex);
106        return 0;
107    }
108    ReqMsgMgr *clientInstance = malloc(sizeof(ReqMsgMgr) + RECV_BLOCK_LEN);
109    if (clientInstance == NULL) {
110        pthread_mutex_unlock(&g_mutex);
111        return -1;
112    }
113    // init
114    clientInstance->msgNextId = 1;
115    clientInstance->timeout = GetDefaultTimeout(TIMEOUT_DEF);
116    clientInstance->maxRetryCount = MAX_RETRY_SEND_COUNT;
117    clientInstance->socketId = -1;
118    pthread_mutex_init(&clientInstance->mutex, NULL);
119    // init recvBlock
120    OH_ListInit(&clientInstance->recvBlock.node);
121    clientInstance->recvBlock.blockSize = RECV_BLOCK_LEN;
122    clientInstance->recvBlock.currentIndex = 0;
123    g_clientInstance = clientInstance;
124    pthread_mutex_unlock(&g_mutex);
125    return 0;
126}
127
128void ClientInit(const char *socketPath, int flags)
129{
130    if (socketPath == NULL) {
131        printf("Invalid parameter \n");
132    }
133    printf("AgentInit \n");
134    Agent *agent = CreateAgent(socketPath, flags);
135    if (agent == NULL) {
136        printf("Failed to create agent \n");
137        return;
138    }
139
140    printf(" Client exit \n");
141}
142
143int ClientDestroy(ReqMsgMgr *reqMgr)
144{
145    if (reqMgr == NULL) {
146        printf("Invalid reqMgr \n");
147        return -1;
148    }
149
150    pthread_mutex_lock(&g_mutex);
151    pthread_mutex_unlock(&g_mutex);
152    pthread_mutex_destroy(&reqMgr->mutex);
153    if (reqMgr->socketId >= 0) {
154        CloseClientSocket(reqMgr->socketId);
155        reqMgr->socketId = -1;
156    }
157    free(reqMgr);
158    return 0;
159}
160
/*
 * Close a client socket descriptor. A negative descriptor is only logged.
 * TCP_NODELAY is cleared first on a best-effort basis and its result is ignored.
 * NOTE(review): CreateClientSocket makes AF_UNIX sockets, where TCP_NODELAY has no effect.
 */
static void CloseClientSocket(int socketId)
{
    printf("Closed socket with fd %d \n", socketId);
    if (socketId < 0) {
        return;
    }
    int noDelay = 0;
    (void)setsockopt(socketId, IPPROTO_TCP, TCP_NODELAY, (char *)&noDelay, sizeof(noDelay));
    (void)close(socketId);
}
170
171static int CreateClientSocket(uint32_t timeout)
172{
173    const char *socketName = "loopserver";
174    int socketFd = socket(AF_UNIX, SOCK_STREAM, 0);  // SOCK_SEQPACKET
175    if (socketFd < 0) {
176        printf("Socket fd: %d error: %d \n", socketFd, errno);
177    }
178
179    int ret = 0;
180    do {
181        int flag = 1;
182        ret = setsockopt(socketFd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
183        flag = 1;
184        ret = setsockopt(socketFd, SOL_SOCKET, SO_PASSCRED, &flag, sizeof(flag));
185        if (ret != 0) {
186            printf("Set opt SO_PASSCRED name: %s error: %d \n", socketName, errno);
187            break;
188        }
189        struct timeval timeoutVal = {timeout, 0};
190        ret = setsockopt(socketFd, SOL_SOCKET, SO_SNDTIMEO, &timeoutVal, sizeof(timeoutVal));
191        if (ret != 0) {
192            printf("Set opt SO_SNDTIMEO name: %s error: %d \n", socketName, errno);
193            break;
194        }
195        ret = setsockopt(socketFd, SOL_SOCKET, SO_RCVTIMEO, &timeoutVal, sizeof(timeoutVal));
196        if (ret != 0) {
197            printf("Set opt SO_RCVTIMEO name: %s error: %d \n", socketName, errno);
198            break;
199        }
200        ret = _SYSTEM_ERROR;
201        struct sockaddr_un addr;
202        socklen_t pathSize = sizeof(addr.sun_path);
203        int pathLen = snprintf_s(addr.sun_path, pathSize, (pathSize - 1), "%s%s", _SOCKET_DIR, socketName);
204        if (pathLen <= 0) {
205            printf("Format path: %s error: %d \n", socketName, errno);
206            break;
207        }
208        addr.sun_family = AF_LOCAL;
209        socklen_t socketAddrLen = offsetof(struct sockaddr_un, sun_path) + pathLen + 1;
210        ret = connect(socketFd, (struct sockaddr *)(&addr), socketAddrLen);
211        if (ret != 0) {
212            printf("Failed to connect %s error: %d \n", addr.sun_path, errno);
213            break;
214        }
215        printf("Create socket success %s socketFd: %d", addr.sun_path, socketFd);
216        return socketFd;
217    } while (0);
218    CloseClientSocket(socketFd);
219    return -1;
220}
221
222static void TryCreateSocket(ReqMsgMgr *reqMgr)
223{
224    uint32_t retryCount = 1;
225    while (retryCount <= reqMgr->maxRetryCount) {
226        if (reqMgr->socketId < 0) {
227            reqMgr->socketId = CreateClientSocket(reqMgr->timeout);
228        }
229        if (reqMgr->socketId < 0) {
230            printf("Failed to create socket, try again \n");
231            usleep(RETRY_TIME);
232            retryCount++;
233            continue;
234        }
235        break;
236    }
237}
238
/*
 * Send all `len` bytes of `buf` on socketFd using sendmsg(MSG_NOSIGNAL).
 * When fds/fdCount describe at least one descriptor, they are passed as SCM_RIGHTS
 * ancillary data with the first chunk that is actually sent.
 * Short writes continue from where the previous sendmsg stopped; calls interrupted
 * by a signal (EINTR) are retried.
 *
 * Returns 0 when all bytes were sent, -1 if the control message could not be built,
 * otherwise a negative errno (-EFAULT if sendmsg made no progress without setting errno,
 * or if len is negative).
 */
static int WriteMessage(int socketFd, const uint8_t *buf, ssize_t len, int *fds, int *fdCount)
{
    ssize_t written = 0;
    struct iovec iov = {
        .iov_base = (void *)buf,
        .iov_len = 0,  // set per chunk in the send loop
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
    };
    char *ctrlBuffer = NULL;
    if (fdCount != NULL && fds != NULL && *fdCount > 0) {
        size_t fdBytes = (size_t)*fdCount * sizeof(int);
        msg.msg_controllen = CMSG_SPACE(fdBytes);
        ctrlBuffer = (char *)calloc(1, msg.msg_controllen);  // zeroed: no garbage in cmsg padding
        if (ctrlBuffer == NULL) {
            printf("WriteMessage fail to alloc memory for msg_control %zu %d \n",
                (size_t)msg.msg_controllen, errno);
            return -1;
        }
        msg.msg_control = ctrlBuffer;
        struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
        if (cmsg == NULL) {
            free(ctrlBuffer);
            printf("WriteMessage fail to get CMSG_FIRSTHDR %d \n", errno);
            return -1;
        }
        cmsg->cmsg_len = CMSG_LEN(fdBytes);
        cmsg->cmsg_type = SCM_RIGHTS;
        cmsg->cmsg_level = SOL_SOCKET;
        // destMax is the payload capacity; cmsg_len would also count the header.
        int ret = memcpy_s(CMSG_DATA(cmsg), fdBytes, fds, fdBytes);
        if (ret != 0) {
            free(ctrlBuffer);
            printf("WriteMessage fail to memcpy_s fd, errno: %d \n", errno);
            return -1;
        }
        printf("build fd info count %d \n", *fdCount);
    }

    while (written < len) {
        // Point the iovec at the part of buf that has not been sent yet.
        iov.iov_base = (void *)(buf + written);
        iov.iov_len = (size_t)(len - written);
        errno = 0;
        ssize_t wLen = sendmsg(socketFd, &msg, MSG_NOSIGNAL);
        if (wLen < 0 && errno == EINTR) {
            continue;  // interrupted before anything was sent: retry the same chunk
        }
        if (wLen <= 0) {
            int savedErrno = errno;  // save before free()/printf() can touch errno
            free(ctrlBuffer);
            printf("Failed to write message to fd %d, wLen %zd errno: %d \n", socketFd, wLen, savedErrno);
            return (savedErrno != 0) ? -savedErrno : -EFAULT;
        }
        written += wLen;
        // The descriptors travelled with the first successful chunk; never resend them.
        msg.msg_control = NULL;
        msg.msg_controllen = 0;
    }
    free(ctrlBuffer);
    return written == len ? 0 : -EFAULT;
}
290
291static int HandleMsgSend(ReqMsgMgr *reqMgr, int socketId, ReqMsgNode *reqNode)
292{
293    printf("HandleMsgSend reqId: %u msgId: %d \n", reqNode->reqId, reqNode->msg->msgId);
294    ListNode *sendNode = reqNode->msgBlocks.next;
295    uint32_t currentIndex = 0;
296    bool sendFd = true;
297    while (sendNode != NULL && sendNode != &reqNode->msgBlocks) {
298        MsgBlock *sendBlock = (MsgBlock *)ListEntry(sendNode, MsgBlock, node);
299        int ret = WriteMessage(socketId, sendBlock->buffer, sendBlock->currentIndex,
300            sendFd ? reqNode->fds : NULL,
301            sendFd ? &reqNode->fdCount : NULL);
302        currentIndex += sendBlock->currentIndex;
303        printf("Write msg ret: %d msgId: %u %u %u \n",
304            ret, reqNode->msg->msgId, reqNode->msg->msgLen, currentIndex);
305        if (ret == 0) {
306            sendFd = false;
307            sendNode = sendNode->next;
308            continue;
309        }
310        printf("Send msg fail reqId: %u msgId: %d ret: %d \n",
311            reqNode->reqId, reqNode->msg->msgId, ret);
312        return ret;
313    }
314    return 0;
315}
316
317static int ReadMessage(int socketFd, uint32_t sendMsgId, uint8_t *buf, int len, Result *result)
318{
319    int rLen = read(socketFd, buf, len);
320    if (rLen < 0) {
321        printf("Read message from fd %d rLen %d errno: %d \n", socketFd, rLen, errno);
322        return TIMEOUT;
323    }
324
325    if ((size_t)rLen >= sizeof(ResponseMsg)) {
326        ResponseMsg *msg = (ResponseMsg *)(buf);
327        if (sendMsgId != msg->msgHdr.msgId) {
328            printf("Invalid msg recvd %u %u \n", sendMsgId, msg->msgHdr.msgId);
329            return memcpy_s(result, sizeof(Result), &msg->result, sizeof(msg->result));
330        }
331    }
332    return TIMEOUT;
333}
334
335static int ClientSendMsg(ReqMsgMgr *reqMgr, ReqMsgNode *reqNode, Result *result)
336{
337    uint32_t retryCount = 1;
338    while (retryCount <= reqMgr->maxRetryCount) {
339        if (reqMgr->socketId < 0) { // try create socket
340            TryCreateSocket(reqMgr);
341            if (reqMgr->socketId < 0) {
342                usleep(RETRY_TIME);
343                retryCount++;
344                continue;
345            }
346        }
347
348        if (reqNode->msg->msgId == 0) {
349            reqNode->msg->msgId = reqMgr->msgNextId++;
350        }
351        int ret = HandleMsgSend(reqMgr, reqMgr->socketId, reqNode);
352        if (ret == 0) {
353            ret = ReadMessage(reqMgr->socketId, reqNode->msg->msgId,
354                reqMgr->recvBlock.buffer, reqMgr->recvBlock.blockSize, result);
355        }
356        if (ret == 0) {
357            return 0;
358        }
359        // retry
360        CloseClientSocket(reqMgr->socketId);
361        reqMgr->socketId = -1;
362        reqMgr->msgNextId = 1;
363        reqNode->msg->msgId = 0;
364        usleep(RETRY_TIME);
365        retryCount++;
366    }
367    return TIMEOUT;
368}
369
370int main(int argc, char *const argv[])
371{
372    printf("main argc: %d \n", argc);
373    if (argc <= 0) {
374        return 0;
375    }
376
377    printf("请输入创建socket的类型:(pipe, tcp)\n");
378    char type[128];
379    int ret = scanf_s("%s", type, sizeof(type));
380    if (ret <= 0) {
381        printf("input error \n");
382        return 0;
383    }
384
385    int flags;
386    char *server;
387    if (strcmp(type, "pipe") == 0) {
388        flags = TASK_STREAM | TASK_PIPE |TASK_SERVER | TASK_TEST;
389        server = (char *)"/data/testpipe";
390    } else if (strcmp(type, "tcp") == 0) {
391        flags = TASK_STREAM | TASK_TCP |TASK_SERVER | TASK_TEST;
392        server = (char *)"127.0.0.1:7777";
393    } else {
394        printf("输入有误,请输入pipe或者tcp!");
395        system("pause");
396        return 0;
397    }
398
399    uint32_t timeout = 200;
400    int fd = CreateClientSocket(timeout);
401    return 0;
402}