1/*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include <fcntl.h>
17#include <unistd.h>
18#include <sys/types.h>
19#include <sys/socket.h>
20
21#include "beget_ext.h"
22#include "control_fd.h"
23#include "init_utils.h"
24#include "securec.h"
25
// Service state: the listening server task (serverTask) and the list of
// connected clients (head), each client being a CmdTask linked via its item.
static CmdService g_cmdService;
// Event loop the service is registered on; only the first non-NULL loop
// passed to CmdServiceInit is kept.
static LoopHandle g_controlFdLoop = NULL;

// Handler that executes a received command inside the forked child
// (installed by CmdServiceInit). Non-static; presumably referenced outside
// this file (e.g. by tests) — confirm before making it static.
CallbackControlFdProcess g_controlFdFunc = NULL;
30
31static void OnClose(const TaskHandle task)
32{
33    CmdTask *agent = (CmdTask *)LE_GetUserData(task);
34    BEGET_ERROR_CHECK(agent != NULL, return, "[control_fd] Can not get agent");
35    OH_ListRemove(&agent->item);
36    OH_ListInit(&agent->item);
37}
38
39CONTROL_FD_STATIC int CheckSocketPermission(const TaskHandle task)
40{
41    struct ucred uc = {-1, -1, -1};
42    socklen_t len = sizeof(uc);
43    if (getsockopt(LE_GetSocketFd(task), SOL_SOCKET, SO_PEERCRED, &uc, &len) < 0) {
44        BEGET_LOGE("Failed to get socket option. err = %d", errno);
45        return -1;
46    }
47    // Only root is permitted to use control fd of init.
48    if (uc.uid != 0) { // non-root user
49        errno = EPERM;
50        return -1;
51    }
52    return 0;
53}
54
55CONTROL_FD_STATIC void CmdOnRecvMessage(const TaskHandle task, const uint8_t *buffer, uint32_t buffLen)
56{
57    if (buffer == NULL) {
58        return;
59    }
60    CmdTask *agent = (CmdTask *)LE_GetUserData(task);
61    BEGET_ERROR_CHECK(agent != NULL, return, "[control_fd] Can not get agent");
62
63    // parse msg to exec
64    CmdMessage *msg = (CmdMessage *)buffer;
65    if ((msg->type >= ACTION_MAX) || (msg->cmd[0] == '\0') || (msg->ptyName[0] == '\0')) {
66        BEGET_LOGE("[control_fd] Received msg is invalid");
67        return;
68    }
69
70    BEGET_ERROR_CHECK(CheckSocketPermission(task) >= 0, return, "Check socket permission failed, err = %d", errno);
71
72#ifndef STARTUP_INIT_TEST
73    agent->pid = fork();
74    if (agent->pid == 0) {
75        OpenConsole();
76        char *realPath = GetRealPath(msg->ptyName);
77        BEGET_ERROR_CHECK(realPath != NULL, _exit(1), "Failed get realpath, err=%d", errno);
78        int n = strncmp(realPath, "/dev/pts/", strlen("/dev/pts/"));
79        BEGET_ERROR_CHECK(n == 0, free(realPath); _exit(1), "pts path %s is invaild", realPath);
80        int fd = open(realPath, O_RDWR);
81        free(realPath);
82        BEGET_ERROR_CHECK(fd >= 0, _exit(1), "Failed open %s, err=%d", msg->ptyName, errno);
83        (void)dup2(fd, STDIN_FILENO);
84        (void)dup2(fd, STDOUT_FILENO);
85        (void)dup2(fd, STDERR_FILENO); // Redirect fd to 0, 1, 2
86        (void)close(fd);
87        if (g_controlFdFunc != NULL) {
88            g_controlFdFunc(msg->type, msg->cmd, NULL);
89        }
90        _exit(0);
91    } else if (agent->pid < 0) {
92        BEGET_LOGE("[control_fd] Failed to fork child process, err = %d", errno);
93    }
94#endif
95    return;
96}
97
98CONTROL_FD_STATIC int SendMessage(LoopHandle loop, TaskHandle task, const char *message)
99{
100    if (message == NULL) {
101        BEGET_LOGE("[control_fd] Invalid parameter");
102        return -1;
103    }
104    BufferHandle handle = NULL;
105    uint32_t bufferSize = strlen(message) + 1;
106    handle = LE_CreateBuffer(loop, bufferSize);
107    char *buff = (char *)LE_GetBufferInfo(handle, NULL, &bufferSize);
108    BEGET_ERROR_CHECK(buff != NULL, return -1, "[control_fd] Failed get buffer info");
109    int ret = memcpy_s(buff, bufferSize, message, strlen(message) + 1);
110    BEGET_ERROR_CHECK(ret == 0, LE_FreeBuffer(g_controlFdLoop, task, handle);
111        return -1, "[control_fd] Failed memcpy_s err=%d", errno);
112    LE_STATUS status = LE_Send(loop, task, handle, strlen(message) + 1);
113    BEGET_ERROR_CHECK(status == LE_SUCCESS, return -1, "[control_fd] Failed le send msg");
114    return 0;
115}
116
117CONTROL_FD_STATIC int CmdOnIncommingConnect(const LoopHandle loop, const TaskHandle server)
118{
119    TaskHandle client = NULL;
120    LE_StreamInfo info = {};
121#ifndef STARTUP_INIT_TEST
122    info.baseInfo.flags = TASK_STREAM | TASK_PIPE | TASK_CONNECT;
123#else
124    info.baseInfo.flags = TASK_STREAM | TASK_PIPE | TASK_CONNECT | TASK_TEST;
125#endif
126    info.baseInfo.close = OnClose;
127    info.baseInfo.userDataSize = sizeof(CmdTask);
128    info.disConnectComplete = NULL;
129    info.sendMessageComplete = NULL;
130    info.recvMessage = CmdOnRecvMessage;
131    int ret = LE_AcceptStreamClient(g_controlFdLoop, server, &client, &info);
132    BEGET_ERROR_CHECK(ret == 0, return -1, "[control_fd] Failed accept stream")
133    CmdTask *agent = (CmdTask *)LE_GetUserData(client);
134    BEGET_ERROR_CHECK(agent != NULL, return -1, "[control_fd] Invalid agent");
135    agent->task = client;
136    OH_ListInit(&agent->item);
137    ret = SendMessage(g_controlFdLoop, agent->task, "connect success.");
138    BEGET_ERROR_CHECK(ret == 0, return -1, "[control_fd] Failed send msg");
139    OH_ListAddTail(&g_cmdService.head, &agent->item);
140    return 0;
141}
142
143void CmdServiceInit(const char *socketPath, CallbackControlFdProcess func, LoopHandle loop)
144{
145    if ((socketPath == NULL) || (func == NULL) || (loop == NULL)) {
146        BEGET_LOGE("[control_fd] Invalid parameter");
147        return;
148    }
149    OH_ListInit(&g_cmdService.head);
150    LE_StreamServerInfo info = {};
151    info.baseInfo.flags = TASK_STREAM | TASK_SERVER | TASK_PIPE;
152    info.server = (char *)socketPath;
153    info.socketId = -1;
154    info.baseInfo.close = NULL;
155    info.disConnectComplete = NULL;
156    info.incommingConnect = CmdOnIncommingConnect;
157    info.sendMessageComplete = NULL;
158    info.recvMessage = NULL;
159    g_controlFdFunc = func;
160    if (g_controlFdLoop == NULL) {
161        g_controlFdLoop = loop;
162    }
163    (void)LE_CreateStreamServer(g_controlFdLoop, &g_cmdService.serverTask, &info);
164}
165
166static int ClientTraversalProc(ListNode *node, void *data)
167{
168    CmdTask *info = ListEntry(node, CmdTask, item);
169    int pid = *(int *)data;
170    return pid - info->pid;
171}
172
173void CmdServiceProcessDelClient(pid_t pid)
174{
175    ListNode *node = OH_ListFind(&g_cmdService.head, (void *)&pid, ClientTraversalProc);
176    if (node != NULL) {
177        CmdTask *agent = ListEntry(node, CmdTask, item);
178        OH_ListRemove(&agent->item);
179        OH_ListInit(&agent->item);
180        LE_CloseTask(g_controlFdLoop, agent->task);
181    }
182}
183
184static void CmdServiceDestroyProc(ListNode *node)
185{
186    if (node == NULL) {
187        return;
188    }
189    CmdTask *agent = ListEntry(node, CmdTask, item);
190    OH_ListRemove(&agent->item);
191    OH_ListInit(&agent->item);
192    LE_CloseTask(g_controlFdLoop, agent->task);
193}
194
195void CmdServiceProcessDestroyClient(void)
196{
197    OH_ListRemoveAll(&g_cmdService.head, CmdServiceDestroyProc);
198    LE_StopLoop(g_controlFdLoop);
199}
200