1/*
2* Copyright (c) 2021 Huawei Device Co., Ltd.
3* Licensed under the Apache License, Version 2.0 (the "License");
4* you may not use this file except in compliance with the License.
5* You may obtain a copy of the License at
6*
7*     http://www.apache.org/licenses/LICENSE-2.0
8*
9* Unless required by applicable law or agreed to in writing, software
10* distributed under the License is distributed on an "AS IS" BASIS,
11* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12* See the License for the specific language governing permissions and
13* limitations under the License.
14*/
15
16#include "fd_holder_internal.h"
17#include <errno.h>
18#include <stdio.h>
19#include "beget_ext.h"
20#include "securec.h"
21
22#ifndef PAGE_SIZE
23#define PAGE_SIZE (4096U)
24#endif
25
26int BuildControlMessage(struct msghdr *msghdr,  int *fds, int fdCount, bool sendUcred)
27{
28    if (msghdr == NULL || (fdCount > 0 && fds == NULL)) {
29        BEGET_LOGE("Build control message with invalid parameter");
30        return -1;
31    }
32
33    if (fdCount > 0) {
34        msghdr->msg_controllen = CMSG_SPACE(sizeof(int) * fdCount);
35    } else {
36        msghdr->msg_controllen = 0;
37    }
38
39    if (sendUcred) {
40        msghdr->msg_controllen += CMSG_SPACE(sizeof(struct ucred));
41    }
42
43    msghdr->msg_control = calloc(1, ((msghdr->msg_controllen == 0) ? 1 : msghdr->msg_controllen));
44    BEGET_ERROR_CHECK(msghdr->msg_control != NULL, return -1, "Failed to build control message");
45
46    struct cmsghdr *cmsg = NULL;
47    cmsg = CMSG_FIRSTHDR(msghdr);
48    BEGET_ERROR_CHECK(cmsg != NULL, return -1, "Failed to build cmsg");
49
50    if (fdCount > 0) {
51        cmsg->cmsg_level = SOL_SOCKET;
52        cmsg->cmsg_type = SCM_RIGHTS;
53        cmsg->cmsg_len = CMSG_LEN(sizeof(int) * fdCount);
54        int ret = memcpy_s(CMSG_DATA(cmsg), cmsg->cmsg_len, fds, sizeof(int) * fdCount);
55        BEGET_ERROR_CHECK(ret == 0, free(msghdr->msg_control);
56            msghdr->msg_control = NULL;
57            return -1, "Control message is not valid");
58        // build ucred info
59        cmsg = CMSG_NXTHDR(msghdr, cmsg);
60    }
61
62    if (sendUcred) {
63        BEGET_ERROR_CHECK(cmsg != NULL, free(msghdr->msg_control);
64            msghdr->msg_control = NULL;
65            return -1, "Control message is not valid");
66
67        struct ucred *ucred;
68        cmsg->cmsg_level = SOL_SOCKET;
69        cmsg->cmsg_type = SCM_CREDENTIALS;
70        cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred));
71        ucred = (struct ucred*) CMSG_DATA(cmsg);
72        ucred->pid = getpid();
73        ucred->uid = getuid();
74        ucred->gid = getgid();
75    }
76    return 0;
77}
78
79STATIC int *GetFdsFromMsg(size_t *outFdCount, pid_t *requestPid, struct msghdr msghdr)
80{
81    if ((msghdr.msg_flags) & MSG_TRUNC) {
82        BEGET_LOGE("Message was truncated when receiving fds");
83        return NULL;
84    }
85
86    struct cmsghdr *cmsg = NULL;
87    int *fds = NULL;
88    size_t fdCount = 0;
89    for (cmsg = CMSG_FIRSTHDR(&msghdr); cmsg != NULL; cmsg = CMSG_NXTHDR(&msghdr, cmsg)) {
90        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
91            fds = (int*)CMSG_DATA(cmsg);
92            fdCount = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
93            BEGET_ERROR_CHECK(fdCount <= MAX_HOLD_FDS, return NULL, "Too many fds returned.");
94        }
95        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
96            cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
97            // Ignore credentials
98            if (requestPid != NULL) {
99                struct ucred *ucred = (struct ucred*)CMSG_DATA(cmsg);
100                *requestPid = ucred->pid;
101            }
102            continue;
103        }
104    }
105    int *outFds = NULL;
106    if (fds != NULL && fdCount > 0) {
107        outFds = calloc(fdCount + 1, sizeof(int));
108        BEGET_ERROR_CHECK(outFds != NULL, return NULL, "Failed to allocate memory for fds");
109        BEGET_ERROR_CHECK(memcpy_s(outFds, sizeof(int) * (fdCount + 1), fds, sizeof(int) * fdCount) == 0,
110            free(outFds); return NULL, "Failed to copy fds");
111    }
112    *outFdCount = fdCount;
113    return outFds;
114}
115
116// This function will allocate memory to store FDs
117// Remember to delete when not used anymore.
118int *ReceiveFds(int sock, struct iovec iovec, size_t *outFdCount, bool nonblock, pid_t *requestPid)
119{
120    CMSG_BUFFER_TYPE(CMSG_SPACE(sizeof(struct ucred)) +
121         CMSG_SPACE(sizeof(int) * MAX_HOLD_FDS)) control;
122
123    BEGET_ERROR_CHECK(sizeof(control) <= PAGE_SIZE, return NULL, "Too many fds, out of memory");
124
125    struct msghdr msghdr = {
126        .msg_iov = &iovec,
127        .msg_iovlen = 1,
128        .msg_control = &control,
129        .msg_controllen = sizeof(control),
130        .msg_flags = 0,
131    };
132
133    int flags = MSG_CMSG_CLOEXEC | MSG_TRUNC;
134    if (nonblock) {
135        flags |= MSG_DONTWAIT;
136    }
137    ssize_t rc = TEMP_FAILURE_RETRY(recvmsg(sock, &msghdr, flags));
138    BEGET_ERROR_CHECK(rc >= 0, return NULL, "Failed to get fds from remote, err = %d", errno);
139    return GetFdsFromMsg(outFdCount, requestPid, msghdr);
140}