1/*
2 * Copyright (C) 2021-2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "server.h"
#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdlib.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include "common.h"
#include "log.h"
#include "net.h"
27
28#undef LOG_TAG
29#define LOG_TAG "WifiRpcServer"
30
/* Listen queue size handed to CreateUnixServer() when the server is created. */
const int DEFAULT_LISTEN_QUEUE_SIZE = 10;
/* Size handed to CreateEventLoop(): support max clients online. */
const int MAX_SUPPORT_CLIENT_FD_SIZE = 256;
/* Slot count handed to InitHashTable() for the client table. */
const int DEFAULT_HASHTABLE_SLOTS = 7;
/* RPC message head size: the "N|" / "C|" prefix is just 2 characters. */
const int SERIAL_DATA_HEAD_SIZE = 2;
35
36static int BeforeLoop(RpcServer *server);
37static int RemoveCallback(RpcServer *server, const Context *context);
38
39static int OnAccept(RpcServer *server, unsigned int mask)
40{
41    if (server == NULL) {
42        return -1;
43    }
44
45    if ((mask & READ_EVENT) == 0) {
46        return 0;
47    }
48    int fd = accept(server->listenFd, NULL, NULL);
49    if (fd < 0) {
50        return -1;
51    }
52    if (SetNonBlock(fd, 1) != 0) {
53        LOGE("OnAccept  SetNonBlock failed!");
54        close(fd);
55        return -1;
56    }
57    if (fcntl(fd, F_SETFD, FD_CLOEXEC) == -1) {
58        LOGE("OnAccept  fcntl failed!");
59        close(fd);
60        return -1;
61    }
62    Context *context = CreateContext(CONTEXT_BUFFER_MIN_SIZE);
63    if (context != NULL) {
64        context->fd = fd;
65        InsertHashTable(server->clients, context);
66        AddFdEvent(server->loop, fd, READ_EVENT | WRIT_EVENT);
67    } else {
68        close(fd);
69        LOGE("Init Client context failed!");
70        return -1;
71    }
72    return 0;
73}
74
75RpcServer *CreateRpcServer(const char *path)
76{
77    if (path == NULL) {
78        return NULL;
79    }
80    RpcServer *server = (RpcServer *)calloc(1, sizeof(RpcServer));
81    if (server == NULL) {
82        return NULL;
83    }
84    int flag = 1;
85    do {
86        int ret = CreateUnixServer(path, DEFAULT_LISTEN_QUEUE_SIZE);
87        if (ret < 0) {
88            break;
89        }
90        server->listenFd = ret;
91        server->isHandlingMsg = false;
92        server->loop = CreateEventLoop(MAX_SUPPORT_CLIENT_FD_SIZE);
93        if (server->loop == NULL) {
94            break;
95        }
96        server->clients = InitHashTable(DEFAULT_HASHTABLE_SLOTS);
97        if (server->clients == NULL) {
98            break;
99        }
100        if (AddFdEvent(server->loop, server->listenFd, READ_EVENT) < 0) {
101            break;
102        }
103        pthread_mutex_init(&server->mutex, NULL);
104        flag = 0;
105    } while (0);
106    if (flag) {
107        ReleaseRpcServer(server);
108        return NULL;
109    }
110    return server;
111}
112
113static int DealReadMessage(RpcServer *server, Context *client)
114{
115    if ((server == NULL) || (client == NULL)) {
116        return 0;
117    }
118    char *buf = ContextGetReadRecord(client);
119    if (buf == NULL) {
120        return 0;
121    }
122    client->oneProcess = buf;
123    client->nPos = SERIAL_DATA_HEAD_SIZE; /* N| */
124    client->nSize = strlen(buf);
125    OnTransact(server, client);
126    free(buf);
127    buf = NULL;
128    AddFdEvent(server->loop, client->fd, WRIT_EVENT);
129    return 1;
130}
131
132static unsigned int CheckEventMask(const struct epoll_event *e)
133{
134    if (e == NULL) {
135        return 0;
136    }
137    unsigned int mask = NONE_EVENT;
138    if ((e->events & EPOLLERR) || (e->events & EPOLLHUP)) {
139        mask |= READ_EVENT | WRIT_EVENT | EXCP_EVENT;
140    } else {
141        if (e->events & EPOLLIN) {
142            mask |= READ_EVENT;
143        }
144        if (e->events & EPOLLOUT) {
145            mask |= WRIT_EVENT;
146        }
147    }
148    return mask;
149}
150
151static void DealFdReadEvent(RpcServer *server, Context *client, unsigned int mask)
152{
153    if ((server == NULL) || (client == NULL)) {
154        return;
155    }
156    DealReadMessage(server, client);
157    int ret = ContextReadNet(client);
158    if ((ret == SOCK_ERR) || ((ret == SOCK_CLOSE) && (mask & EXCP_EVENT))) {
159        LOGE("ContextReadNet failed: %{public}d", ret);
160        DelFdEvent(server->loop, client->fd, READ_EVENT | WRIT_EVENT);
161    } else if (ret == SOCK_CLOSE) {
162        LOGE("Socket close.");
163        DelFdEvent(server->loop, client->fd, READ_EVENT);
164    } else if (ret > 0) {
165        int haveMsg;
166        do {
167            haveMsg = DealReadMessage(server, client);
168        } while (haveMsg);
169    }
170    return;
171}
172
173static void DealFdWriteEvent(RpcServer *server, Context *client)
174{
175    if ((server == NULL) || (client == NULL)) {
176        return;
177    }
178
179    if (client->wBegin != client->wEnd) {
180        int tmp = ContextWriteNet(client);
181        if (tmp < 0) {
182            LOGE("ContextWriteNet failed: %{public}d", tmp);
183            DelFdEvent(server->loop, client->fd, READ_EVENT | WRIT_EVENT);
184        }
185    } else {
186        LOGE("Del write event.");
187        DelFdEvent(server->loop, client->fd, WRIT_EVENT);
188    }
189    return;
190}
191
192static void DealFdEvents(RpcServer *server, int fd, unsigned int mask)
193{
194    if (server == NULL) {
195        return;
196    }
197    Context *client = FindContext(server->clients, fd);
198    if (client == NULL) {
199        LOGD("not find %{public}d clients!", fd);
200        return;
201    }
202    if (mask & READ_EVENT) {
203        DealFdReadEvent(server, client, mask);
204    }
205    if (mask & WRIT_EVENT) {
206        DealFdWriteEvent(server, client);
207    }
208    if (server->loop->fdMasks[fd].mask == NONE_EVENT) {
209        close(fd);
210        DeleteHashTable(server->clients, client);
211        RemoveCallback(server, client);
212        ReleaseContext(client);
213    }
214    return;
215}
216
217int RunRpcLoop(RpcServer *server)
218{
219    if (server == NULL) {
220        return -1;
221    }
222
223    EventLoop *loop = server->loop;
224    while (!loop->stop) {
225        BeforeLoop(server);
226        server->isHandlingMsg = false;
227        int retval = epoll_wait(loop->epfd, loop->epEvents, loop->setSize, -1);
228        server->isHandlingMsg = true;
229        for (int i = 0; i < retval; ++i) {
230            struct epoll_event *e = loop->epEvents + i;
231            int fd = e->data.fd;
232            unsigned int mask = CheckEventMask(e);
233            if (fd == server->listenFd) {
234                OnAccept(server, mask);
235            } else {
236                DealFdEvents(server, fd, mask);
237            }
238        }
239    }
240    return 0;
241}
242
243void ReleaseRpcServer(RpcServer *server)
244{
245    if (server != NULL) {
246        if (server->clients != NULL) {
247            DestroyHashTable(server->clients);
248        }
249        if (server->loop != NULL) {
250            DestroyEventLoop(server->loop);
251        }
252        if (server->listenFd > 0) {
253            close(server->listenFd);
254        }
255        pthread_mutex_destroy(&server->mutex);
256        free(server);
257        server = NULL;
258    }
259}
260
261static int BeforeLoop(RpcServer *server)
262{
263    if (server == NULL) {
264        return -1;
265    }
266    pthread_mutex_lock(&server->mutex);
267    for (int i = 0; i < server->nEvents; ++i) {
268        int event = server->events[i];
269        uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
270        int pos = event % num;
271        struct Node *p = server->eventNode[pos].head;
272        while (p != NULL) {
273            Context *context = p->context;
274            OnCallbackTransact(server, event, context);
275            AddFdEvent(server->loop, context->fd, WRIT_EVENT);
276            p = p->next;
277        }
278        EndCallbackTransact(server, event);
279    }
280    server->nEvents = 0;
281    pthread_mutex_unlock(&server->mutex);
282    return 0;
283}
284
285int EmitEvent(RpcServer *server, int event)
286{
287    if (server == NULL) {
288        return -1;
289    }
290    int num = sizeof(server->events) / sizeof(server->events[0]);
291    pthread_mutex_lock(&server->mutex);
292    if (server->nEvents >= num) {
293        pthread_mutex_unlock(&server->mutex);
294        return -1;
295    }
296    server->events[server->nEvents] = event;
297    ++server->nEvents;
298    pthread_mutex_unlock(&server->mutex);
299    /* Triger write to socket */
300    if (server->isHandlingMsg == false) {
301        BeforeLoop(server);
302    }
303    return 0;
304}
305
306int RegisterCallback(RpcServer *server, int event, Context *context)
307{
308    if ((server == NULL) || (context == NULL)) {
309        return -1;
310    }
311
312    uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
313    int pos = event % num;
314    if (pos >= MAX_EVENT_NODE_COUNT) {
315        return -1;
316    }
317    server->eventNode[pos].event = event;
318    struct Node *p = server->eventNode[pos].head;
319    while (p != NULL && p->context->fd != context->fd) {
320        p = p->next;
321    }
322    if (p == NULL) {
323        p = (struct Node *)calloc(1, sizeof(struct Node));
324        if (p != NULL) {
325            p->next = server->eventNode[pos].head;
326            p->context = context;
327            server->eventNode[pos].head = p;
328        }
329    }
330    return 0;
331}
332
333int UnRegisterCallback(RpcServer *server, int event, const Context *context)
334{
335    if ((server == NULL) || (context == NULL)) {
336        return -1;
337    }
338
339    uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
340    int pos = event % num;
341    if (pos >= MAX_EVENT_NODE_COUNT) {
342        return -1;
343    }
344    server->eventNode[pos].event = event;
345    struct Node *p = server->eventNode[pos].head;
346    struct Node *q = p;
347    while (p != NULL && p->context->fd != context->fd) {
348        q = p;
349        p = p->next;
350    }
351    if (p != NULL) {
352        if (p == server->eventNode[pos].head) {
353            server->eventNode[pos].head = p->next;
354        } else {
355            q->next = p->next;
356        }
357        free(p);
358        p = NULL;
359    }
360    return 0;
361}
362
363static int RemoveCallback(RpcServer *server, const Context *context)
364{
365    if ((server == NULL) || (context == NULL)) {
366        return -1;
367    }
368
369    uint32_t num = sizeof(server->eventNode) / sizeof(server->eventNode[0]);
370    for (int i = 0; i < num; ++i) {
371        struct Node *p = server->eventNode[i].head;
372        if (p == NULL) {
373            continue;
374        }
375        struct Node *q = p;
376        while (p != NULL && p->context->fd != context->fd) {
377            q = p;
378            p = p->next;
379        }
380        if (p != NULL) {
381            if (p == server->eventNode[i].head) {
382                server->eventNode[i].head = p->next;
383            } else {
384                q->next = p->next;
385            }
386            free(p);
387            p = NULL;
388        }
389    }
390    return 0;
391}
392