1/*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include <stdbool.h>
17#include <stdio.h>
18#include <stdlib.h>
19#include <string.h>
20
21#include "list.h"
22#include "beget_ext.h"
23#include "hookmgr.h"
24
// Forward declaration so that HOOK_ITEM (below) can point back to its owning stage
typedef struct tagHOOK_STAGE HOOK_STAGE;
27
/*
 * Internal hook item: one registered hook, with its priority, inside a stage.
 * node must stay the first member: list nodes are cast directly to
 * HOOK_ITEM * throughout this file.
 */
typedef struct tagHOOK_ITEM {
    ListNode node;      // link in the owning stage's hooks list
    HOOK_INFO info;     // stage, prio, hook function and cookie as registered
    HOOK_STAGE *stage;  // back pointer to the owning stage
} HOOK_ITEM;
36
/*
 * Internal hook stage: groups all hooks registered for the same stage id.
 * node must stay the first member: list nodes are cast directly to
 * HOOK_STAGE * throughout this file.
 */
struct tagHOOK_STAGE {
    ListNode node;   // link in the manager's stages list
    int stage;       // stage id this entry represents
    ListNode hooks;  // head of this stage's HOOK_ITEM list, inserted in prio order
};
45
/*
 * HookManager: a named collection of hook stages.
 */
struct tagHOOK_MGR {
    const char *name;  // heap copy of the name given to HookMgrCreate; freed by HookMgrDestroy
    ListNode stages;   // head of the HOOK_STAGE list; new stages are appended at the tail
};
53
/*
 * Default HookManager, used whenever a NULL manager is passed to the public API.
 * It is created lazily on the first HookMgrAdd/HookMgrAddEx with a NULL manager,
 * and reset to NULL when it is destroyed via HookMgrDestroy.
 * NOTE(review): access is not synchronized - confirm callers are single-threaded.
 */
static HOOK_MGR *defaultHookMgr = NULL;
58
59static HOOK_MGR *getHookMgr(HOOK_MGR *hookMgr, int autoCreate)
60{
61    BEGET_CHECK(hookMgr == NULL, return hookMgr);
62    // Use default HOOK_MGR if possible
63    BEGET_CHECK(defaultHookMgr == NULL, return defaultHookMgr);
64
65    BEGET_CHECK(autoCreate, return NULL);
66
67    // Create default HOOK_MGR if not created
68    defaultHookMgr = HookMgrCreate("default");
69    return defaultHookMgr;
70}
71
72static int hookStageCompare(ListNode *node, void *data)
73{
74    const HOOK_STAGE *stage;
75    int compareStage = *((int *)data);
76
77    stage = (const HOOK_STAGE *)node;
78    return (stage->stage - compareStage);
79}
80
81static void hookStageDestroy(ListNode *node)
82{
83    HOOK_STAGE *stage;
84
85    BEGET_CHECK(node != NULL, return);
86
87    stage = (HOOK_STAGE *)node;
88    OH_ListRemoveAll(&(stage->hooks), NULL);
89    free((void *)stage);
90}
91
92// Get HOOK_STAGE if found, otherwise create it
93static HOOK_STAGE *getHookStage(HOOK_MGR *hookMgr, int stage, int createIfNotFound)
94{
95    HOOK_STAGE *stageItem;
96
97    stageItem = (HOOK_STAGE *)OH_ListFind(&(hookMgr->stages), (void *)(&stage), hookStageCompare);
98    BEGET_CHECK(stageItem == NULL, return stageItem);
99
100    BEGET_CHECK(createIfNotFound, return NULL);
101
102    // Not found, create it
103    stageItem = (HOOK_STAGE *)malloc(sizeof(HOOK_STAGE));
104    BEGET_CHECK(stageItem != NULL, return NULL);
105    stageItem->stage = stage;
106    OH_ListInit(&(stageItem->hooks));
107    OH_ListAddTail(&(hookMgr->stages), (ListNode *)stageItem);
108    return stageItem;
109}
110
111static int hookItemCompare(ListNode *node, ListNode *newNode)
112{
113    const HOOK_ITEM *hookItem;
114    const HOOK_ITEM *newItem;
115
116    hookItem = (const HOOK_ITEM *)node;
117    newItem = (const HOOK_ITEM *)newNode;
118    return (hookItem->info.prio - newItem->info.prio);
119}
120
/*
 * Lookup key used to detect an already registered hook: an existing item
 * matches only if prio, hook and hookCookie are all equal.
 */
struct HOOKITEM_COMPARE_VAL {
    int prio;
    OhosHook hook;
    void *hookCookie;
};
126
127static int hookItemCompareValue(ListNode *node, void *data)
128{
129    const HOOK_ITEM *hookItem;
130    struct HOOKITEM_COMPARE_VAL *compareVal = (struct HOOKITEM_COMPARE_VAL *)data;
131
132    hookItem = (const HOOK_ITEM *)node;
133    BEGET_CHECK(hookItem->info.prio == compareVal->prio, return (hookItem->info.prio - compareVal->prio));
134    if (hookItem->info.hook == compareVal->hook && hookItem->info.hookCookie == compareVal->hookCookie) {
135        return 0;
136    }
137    return -1;
138}
139
140// Add hook to stage list with prio ordered
141static int addHookToStage(HOOK_STAGE *hookStage, int prio, OhosHook hook, void *hookCookie)
142{
143    HOOK_ITEM *hookItem;
144    struct HOOKITEM_COMPARE_VAL compareVal;
145
146    // Check if exists
147    compareVal.prio = prio;
148    compareVal.hook = hook;
149    compareVal.hookCookie = hookCookie;
150    hookItem = (HOOK_ITEM *)OH_ListFind(&(hookStage->hooks), (void *)(&compareVal), hookItemCompareValue);
151    BEGET_CHECK(hookItem == NULL, return 0);
152
153    // Create new item
154    hookItem = (HOOK_ITEM *)malloc(sizeof(HOOK_ITEM));
155    BEGET_CHECK(hookItem != NULL, return -1);
156    hookItem->info.stage = hookStage->stage;
157    hookItem->info.prio = prio;
158    hookItem->info.hook = hook;
159    hookItem->info.hookCookie = hookCookie;
160    hookItem->stage = hookStage;
161
162    // Insert with order
163    OH_ListAddWithOrder(&(hookStage->hooks), (ListNode *)hookItem, hookItemCompare);
164    return 0;
165}
166
167int HookMgrAddEx(HOOK_MGR *hookMgr, const HOOK_INFO *hookInfo)
168{
169    HOOK_STAGE *stageItem;
170    BEGET_CHECK(hookInfo != NULL, return -1);
171    BEGET_CHECK(hookInfo->hook != NULL, return -1);
172
173    // Get HOOK_MGR
174    hookMgr = getHookMgr(hookMgr, true);
175    BEGET_CHECK(hookMgr != NULL, return -1);
176
177    // Get HOOK_STAGE list
178    stageItem = getHookStage(hookMgr, hookInfo->stage, true);
179    BEGET_CHECK(stageItem != NULL, return -1);
180
181    // Add hook to stage
182    return addHookToStage(stageItem, hookInfo->prio, hookInfo->hook, hookInfo->hookCookie);
183}
184
185int HookMgrAdd(HOOK_MGR *hookMgr, int stage, int prio, OhosHook hook)
186{
187    HOOK_INFO info;
188    info.stage = stage;
189    info.prio = prio;
190    info.hook = hook;
191    info.hookCookie = NULL;
192    return HookMgrAddEx(hookMgr, &info);
193}
194
195static int hookTraversalDelProc(ListNode *node, void *cookie)
196{
197    HOOK_ITEM *hookItem = (HOOK_ITEM *)node;
198
199    // Not equal, just return
200    BEGET_CHECK((void *)hookItem->info.hook == cookie, return 0);
201
202    // Remove from the list
203    OH_ListRemove(node);
204    // Destroy myself
205    free((void *)node);
206
207    return 0;
208}
209
210/*
211 * 删除钩子函数
212 * hook为NULL,表示删除该stage上的所有hooks
213 */
214void HookMgrDel(HOOK_MGR *hookMgr, int stage, OhosHook hook)
215{
216    HOOK_STAGE *stageItem;
217
218    // Get HOOK_MGR
219    hookMgr = getHookMgr(hookMgr, 0);
220    BEGET_CHECK(hookMgr != NULL, return);
221
222    // Get HOOK_STAGE list
223    stageItem = getHookStage(hookMgr, stage, false);
224    BEGET_CHECK(stageItem != NULL, return);
225
226    if (hook != NULL) {
227        OH_ListTraversal(&(stageItem->hooks), hook, hookTraversalDelProc, 0);
228        return;
229    }
230
231    // Remove from list
232    OH_ListRemove((ListNode *)stageItem);
233
234    // Destroy stage item
235    hookStageDestroy((ListNode *)stageItem);
236}
237
/*
 * Arguments handed through OH_ListTraversal to hookExecutionProc by HookMgrExecute.
 */
typedef struct tagHOOK_EXECUTION_ARGS {
    void *executionContext;            // opaque context passed to every hook and pre/post hook
    const HOOK_EXEC_OPTIONS *options;  // optional pre/post hooks and traversal flags; may be NULL
} HOOK_EXECUTION_ARGS;
242
243static int hookExecutionProc(ListNode *node, void *cookie)
244{
245    int ret;
246    HOOK_ITEM *hookItem = (HOOK_ITEM *)node;
247    HOOK_EXECUTION_ARGS *args = (HOOK_EXECUTION_ARGS *)cookie;
248
249    if ((args->options != NULL) && (args->options->preHook != NULL)) {
250        args->options->preHook(&hookItem->info, args->executionContext);
251    }
252    ret = hookItem->info.hook(&hookItem->info, args->executionContext);
253    if ((args->options != NULL) && (args->options->postHook != NULL)) {
254        args->options->postHook(&hookItem->info, args->executionContext, ret);
255    }
256
257    return ret;
258}
259
260/*
261 * 执行钩子函数
262 */
263int HookMgrExecute(HOOK_MGR *hookMgr, int stage, void *executionContext, const HOOK_EXEC_OPTIONS *options)
264{
265    unsigned int flags;
266    HOOK_STAGE *stageItem;
267    HOOK_EXECUTION_ARGS args;
268
269    // Get HOOK_MGR
270    hookMgr = getHookMgr(hookMgr, 0);
271    BEGET_CHECK(hookMgr != NULL, return -1)
272
273    // Get HOOK_STAGE list
274    stageItem = getHookStage(hookMgr, stage, false);
275    BEGET_CHECK(stageItem != NULL, return ERR_NO_HOOK_STAGE);
276
277    flags = 0;
278    if (options != NULL) {
279        flags = (unsigned int)(options->flags);
280    }
281
282    args.executionContext = executionContext;
283    args.options = options;
284
285    // Traversal all hooks in the specified stage
286    return OH_ListTraversal(&(stageItem->hooks), (void *)(&args), hookExecutionProc, flags);
287}
288
289HOOK_MGR *HookMgrCreate(const char *name)
290{
291    HOOK_MGR *ret;
292
293    BEGET_CHECK(name != NULL, return NULL);
294    ret = (HOOK_MGR *)malloc(sizeof(HOOK_MGR));
295    BEGET_CHECK(ret != NULL, return NULL);
296
297    ret->name = strdup(name);
298    if (ret->name == NULL) {
299        free((void *)ret);
300        return NULL;
301    }
302    OH_ListInit(&(ret->stages));
303    return ret;
304}
305
306void HookMgrDestroy(HOOK_MGR *hookMgr)
307{
308    hookMgr = getHookMgr(hookMgr, 0);
309    BEGET_CHECK(hookMgr != NULL, return);
310
311    OH_ListRemoveAll(&(hookMgr->stages), hookStageDestroy);
312
313    if (hookMgr == defaultHookMgr) {
314        defaultHookMgr = NULL;
315    }
316    if (hookMgr->name != NULL) {
317        free((void *)hookMgr->name);
318    }
319    free((void *)hookMgr);
320}
321
/*
 * Arguments handed through the nested stage/item traversals started by HookMgrTraversal.
 */
typedef struct tagHOOK_TRAVERSAL_ARGS {
    void *traversalCookie;        // opaque pointer forwarded to traversal
    OhosHookTraversal traversal;  // user callback invoked once per hook item
} HOOK_TRAVERSAL_ARGS;
326
327static int hookItemTraversal(ListNode *node, void *data)
328{
329    HOOK_ITEM *hookItem;
330    HOOK_TRAVERSAL_ARGS *stageArgs;
331
332    hookItem = (HOOK_ITEM *)node;
333    stageArgs = (HOOK_TRAVERSAL_ARGS *)data;
334
335    stageArgs->traversal(&(hookItem->info), stageArgs->traversalCookie);
336    return 0;
337}
338
339static int hookStageTraversal(ListNode *node, void *data)
340{
341    HOOK_STAGE *stageItem = (HOOK_STAGE *)node;
342    OH_ListTraversal(&(stageItem->hooks), data, hookItemTraversal, 0);
343    return 0;
344}
345
346/*
347 * 遍历所有的hooks
348 */
349void HookMgrTraversal(HOOK_MGR *hookMgr, void *traversalCookie, OhosHookTraversal traversal)
350{
351    HOOK_TRAVERSAL_ARGS stageArgs;
352
353    BEGET_CHECK(traversal != NULL, return);
354
355    hookMgr = getHookMgr(hookMgr, 0);
356    BEGET_CHECK(hookMgr != NULL, return);
357
358    // Prepare common args
359    stageArgs.traversalCookie = traversalCookie;
360    stageArgs.traversal = traversal;
361    OH_ListTraversal(&(hookMgr->stages), (void *)(&stageArgs), hookStageTraversal, 0);
362}
363
364/*
365 * 获取指定stage中hooks的个数
366 */
367int HookMgrGetHooksCnt(HOOK_MGR *hookMgr, int stage)
368{
369    HOOK_STAGE *stageItem;
370
371    hookMgr = getHookMgr(hookMgr, 0);
372    BEGET_CHECK(hookMgr != NULL, return 0);
373
374    // Get HOOK_STAGE list
375    stageItem = getHookStage(hookMgr, stage, false);
376    BEGET_CHECK(stageItem != NULL, return 0);
377
378    return OH_ListGetCnt(&(stageItem->hooks));
379}
380
381/*
382 * 获取指定stage中hooks的个数
383 */
384int HookMgrGetStagesCnt(HOOK_MGR *hookMgr)
385{
386    hookMgr = getHookMgr(hookMgr, 0);
387    BEGET_CHECK(hookMgr != NULL, return 0);
388
389    return OH_ListGetCnt(&(hookMgr->stages));
390}
391