1 /*
2  * Copyright (C) 2021-2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "server.h"
17 #include <fcntl.h>
18 #include <pthread.h>
19 #include <unistd.h>
20 #include <stdlib.h>
21 #include <sys/epoll.h>
22 #include <sys/socket.h>
23 #include <sys/un.h>
24 #include "common.h"
25 #include "log.h"
26 #include "net.h"
27 
28 #undef LOG_TAG
29 #define LOG_TAG "WifiRpcServer"
30 
/* Passed to CreateUnixServer; presumably the listen() backlog of the server socket. */
const int DEFAULT_LISTEN_QUEUE_SIZE = 10;
/* Capacity passed to CreateEventLoop; the comment says it bounds the clients online. */
const int MAX_SUPPORT_CLIENT_FD_SIZE = 256; /* support max clients online */
/* Initial slot count handed to InitHashTable for the client table. */
const int DEFAULT_HASHTABLE_SLOTS = 7;
/* Length of the RPC message prefix ("N|" or "C|") skipped before decoding the payload. */
const int SERIAL_DATA_HEAD_SIZE = 2; /* RPC message head size: N| / C| just 2 */

/* Forward declarations: both helpers are defined at the end of this file. */
static int BeforeLoop(RpcServer *server);
static int RemoveCallback(RpcServer *server, const Context *context);
38 
OnAccept(RpcServer *server, unsigned int mask)39 static int OnAccept(RpcServer *server, unsigned int mask)
40 {
41     if (server == NULL) {
42         return -1;
43     }
44 
45     if ((mask & READ_EVENT) == 0) {
46         return 0;
47     }
48     int fd = accept(server->listenFd, NULL, NULL);
49     if (fd < 0) {
50         return -1;
51     }
52     if (SetNonBlock(fd, 1) != 0) {
53         LOGE("OnAccept  SetNonBlock failed!");
54         close(fd);
55         return -1;
56     }
57     if (fcntl(fd, F_SETFD, FD_CLOEXEC) == -1) {
58         LOGE("OnAccept  fcntl failed!");
59         close(fd);
60         return -1;
61     }
62     Context *context = CreateContext(CONTEXT_BUFFER_MIN_SIZE);
63     if (context != NULL) {
64         context->fd = fd;
65         InsertHashTable(server->clients, context);
66         AddFdEvent(server->loop, fd, READ_EVENT | WRIT_EVENT);
67     } else {
68         close(fd);
69         LOGE("Init Client context failed!");
70         return -1;
71     }
72     return 0;
73 }
74 
CreateRpcServer(const char *path)75 RpcServer *CreateRpcServer(const char *path)
76 {
77     if (path == NULL) {
78         return NULL;
79     }
80     RpcServer *server = (RpcServer *)calloc(1, sizeof(RpcServer));
81     if (server == NULL) {
82         return NULL;
83     }
84     int flag = 1;
85     do {
86         int ret = CreateUnixServer(path, DEFAULT_LISTEN_QUEUE_SIZE);
87         if (ret < 0) {
88             break;
89         }
90         server->listenFd = ret;
91         server->isHandlingMsg = false;
92         server->loop = CreateEventLoop(MAX_SUPPORT_CLIENT_FD_SIZE);
93         if (server->loop == NULL) {
94             break;
95         }
96         server->clients = InitHashTable(DEFAULT_HASHTABLE_SLOTS);
97         if (server->clients == NULL) {
98             break;
99         }
100         if (AddFdEvent(server->loop, server->listenFd, READ_EVENT) < 0) {
101             break;
102         }
103         pthread_mutex_init(&server->mutex, NULL);
104         flag = 0;
105     } while (0);
106     if (flag) {
107         ReleaseRpcServer(server);
108         return NULL;
109     }
110     return server;
111 }
112 
DealReadMessage(RpcServer *server, Context *client)113 static int DealReadMessage(RpcServer *server, Context *client)
114 {
115     if ((server == NULL) || (client == NULL)) {
116         return 0;
117     }
118     char *buf = ContextGetReadRecord(client);
119     if (buf == NULL) {
120         return 0;
121     }
122     client->oneProcess = buf;
123     client->nPos = SERIAL_DATA_HEAD_SIZE; /* N| */
124     client->nSize = strlen(buf);
125     OnTransact(server, client);
126     free(buf);
127     buf = NULL;
128     AddFdEvent(server->loop, client->fd, WRIT_EVENT);
129     return 1;
130 }
131 
CheckEventMask(const struct epoll_event *e)132 static unsigned int CheckEventMask(const struct epoll_event *e)
133 {
134     if (e == NULL) {
135         return 0;
136     }
137     unsigned int mask = NONE_EVENT;
138     if ((e->events & EPOLLERR) || (e->events & EPOLLHUP)) {
139         mask |= READ_EVENT | WRIT_EVENT | EXCP_EVENT;
140     } else {
141         if (e->events & EPOLLIN) {
142             mask |= READ_EVENT;
143         }
144         if (e->events & EPOLLOUT) {
145             mask |= WRIT_EVENT;
146         }
147     }
148     return mask;
149 }
150 
DealFdReadEvent(RpcServer *server, Context *client, unsigned int mask)151 static void DealFdReadEvent(RpcServer *server, Context *client, unsigned int mask)
152 {
153     if ((server == NULL) || (client == NULL)) {
154         return;
155     }
156     DealReadMessage(server, client);
157     int ret = ContextReadNet(client);
158     if ((ret == SOCK_ERR) || ((ret == SOCK_CLOSE) && (mask & EXCP_EVENT))) {
159         LOGE("ContextReadNet failed: %{public}d", ret);
160         DelFdEvent(server->loop, client->fd, READ_EVENT | WRIT_EVENT);
161     } else if (ret == SOCK_CLOSE) {
162         LOGE("Socket close.");
163         DelFdEvent(server->loop, client->fd, READ_EVENT);
164     } else if (ret > 0) {
165         int haveMsg;
166         do {
167             haveMsg = DealReadMessage(server, client);
168         } while (haveMsg);
169     }
170     return;
171 }
172 
DealFdWriteEvent(RpcServer *server, Context *client)173 static void DealFdWriteEvent(RpcServer *server, Context *client)
174 {
175     if ((server == NULL) || (client == NULL)) {
176         return;
177     }
178 
179     if (client->wBegin != client->wEnd) {
180         int tmp = ContextWriteNet(client);
181         if (tmp < 0) {
182             LOGE("ContextWriteNet failed: %{public}d", tmp);
183             DelFdEvent(server->loop, client->fd, READ_EVENT | WRIT_EVENT);
184         }
185     } else {
186         LOGE("Del write event.");
187         DelFdEvent(server->loop, client->fd, WRIT_EVENT);
188     }
189     return;
190 }
191 
DealFdEvents(RpcServer *server, int fd, unsigned int mask)192 static void DealFdEvents(RpcServer *server, int fd, unsigned int mask)
193 {
194     if (server == NULL) {
195         return;
196     }
197     Context *client = FindContext(server->clients, fd);
198     if (client == NULL) {
199         LOGD("not find %{public}d clients!", fd);
200         return;
201     }
202     if (mask & READ_EVENT) {
203         DealFdReadEvent(server, client, mask);
204     }
205     if (mask & WRIT_EVENT) {
206         DealFdWriteEvent(server, client);
207     }
208     if (server->loop->fdMasks[fd].mask == NONE_EVENT) {
209         close(fd);
210         DeleteHashTable(server->clients, client);
211         RemoveCallback(server, client);
212         ReleaseContext(client);
213     }
214     return;
215 }
216 
RunRpcLoop(RpcServer *server)217 int RunRpcLoop(RpcServer *server)
218 {
219     if (server == NULL) {
220         return -1;
221     }
222 
223     EventLoop *loop = server->loop;
224     while (!loop->stop) {
225         BeforeLoop(server);
226         server->isHandlingMsg = false;
227         int retval = epoll_wait(loop->epfd, loop->epEvents, loop->setSize, -1);
228         server->isHandlingMsg = true;
229         for (int i = 0; i < retval; ++i) {
230             struct epoll_event *e = loop->epEvents + i;
231             int fd = e->data.fd;
232             unsigned int mask = CheckEventMask(e);
233             if (fd == server->listenFd) {
234                 OnAccept(server, mask);
235             } else {
236                 DealFdEvents(server, fd, mask);
237             }
238         }
239     }
240     return 0;
241 }
242 
ReleaseRpcServer(RpcServer *server)243 void ReleaseRpcServer(RpcServer *server)
244 {
245     if (server != NULL) {
246         if (server->clients != NULL) {
247             DestroyHashTable(server->clients);
248         }
249         if (server->loop != NULL) {
250             DestroyEventLoop(server->loop);
251         }
252         if (server->listenFd > 0) {
253             close(server->listenFd);
254         }
255         pthread_mutex_destroy(&server->mutex);
256         free(server);
257         server = NULL;
258     }
259 }
260 
BeforeLoop(RpcServer *server)261 static int BeforeLoop(RpcServer *server)
262 {
263     if (server == NULL) {
264         return -1;
265     }
266     pthread_mutex_lock(&server->mutex);
267     for (int i = 0; i < server->nEvents; ++i) {
268         int event = server->events[i];
269         uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
270         int pos = event % num;
271         struct Node *p = server->eventNode[pos].head;
272         while (p != NULL) {
273             Context *context = p->context;
274             OnCallbackTransact(server, event, context);
275             AddFdEvent(server->loop, context->fd, WRIT_EVENT);
276             p = p->next;
277         }
278         EndCallbackTransact(server, event);
279     }
280     server->nEvents = 0;
281     pthread_mutex_unlock(&server->mutex);
282     return 0;
283 }
284 
EmitEvent(RpcServer *server, int event)285 int EmitEvent(RpcServer *server, int event)
286 {
287     if (server == NULL) {
288         return -1;
289     }
290     int num = sizeof(server->events) / sizeof(server->events[0]);
291     pthread_mutex_lock(&server->mutex);
292     if (server->nEvents >= num) {
293         pthread_mutex_unlock(&server->mutex);
294         return -1;
295     }
296     server->events[server->nEvents] = event;
297     ++server->nEvents;
298     pthread_mutex_unlock(&server->mutex);
299     /* Triger write to socket */
300     if (server->isHandlingMsg == false) {
301         BeforeLoop(server);
302     }
303     return 0;
304 }
305 
RegisterCallback(RpcServer *server, int event, Context *context)306 int RegisterCallback(RpcServer *server, int event, Context *context)
307 {
308     if ((server == NULL) || (context == NULL)) {
309         return -1;
310     }
311 
312     uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
313     int pos = event % num;
314     if (pos >= MAX_EVENT_NODE_COUNT) {
315         return -1;
316     }
317     server->eventNode[pos].event = event;
318     struct Node *p = server->eventNode[pos].head;
319     while (p != NULL && p->context->fd != context->fd) {
320         p = p->next;
321     }
322     if (p == NULL) {
323         p = (struct Node *)calloc(1, sizeof(struct Node));
324         if (p != NULL) {
325             p->next = server->eventNode[pos].head;
326             p->context = context;
327             server->eventNode[pos].head = p;
328         }
329     }
330     return 0;
331 }
332 
UnRegisterCallback(RpcServer *server, int event, const Context *context)333 int UnRegisterCallback(RpcServer *server, int event, const Context *context)
334 {
335     if ((server == NULL) || (context == NULL)) {
336         return -1;
337     }
338 
339     uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
340     int pos = event % num;
341     if (pos >= MAX_EVENT_NODE_COUNT) {
342         return -1;
343     }
344     server->eventNode[pos].event = event;
345     struct Node *p = server->eventNode[pos].head;
346     struct Node *q = p;
347     while (p != NULL && p->context->fd != context->fd) {
348         q = p;
349         p = p->next;
350     }
351     if (p != NULL) {
352         if (p == server->eventNode[pos].head) {
353             server->eventNode[pos].head = p->next;
354         } else {
355             q->next = p->next;
356         }
357         free(p);
358         p = NULL;
359     }
360     return 0;
361 }
362 
RemoveCallback(RpcServer *server, const Context *context)363 static int RemoveCallback(RpcServer *server, const Context *context)
364 {
365     if ((server == NULL) || (context == NULL)) {
366         return -1;
367     }
368 
369     uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
370     for (int i = 0; i < num; ++i) {
371         struct Node *p = server->eventNode[i].head;
372         if (p == NULL) {
373             continue;
374         }
375         struct Node *q = p;
376         while (p != NULL && p->context->fd != context->fd) {
377             q = p;
378             p = p->next;
379         }
380         if (p != NULL) {
381             if (p == server->eventNode[i].head) {
382                 server->eventNode[i].head = p->next;
383             } else {
384                 q->next = p->next;
385             }
386             free(p);
387             p = NULL;
388         }
389     }
390     return 0;
391 }
392