1/* 2* Copyright (c) 2021 Huawei Device Co., Ltd. 3* Licensed under the Apache License, Version 2.0 (the "License"); 4* you may not use this file except in compliance with the License. 5* You may obtain a copy of the License at 6* 7* http://www.apache.org/licenses/LICENSE-2.0 8* 9* Unless required by applicable law or agreed to in writing, software 10* distributed under the License is distributed on an "AS IS" BASIS, 11* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12* See the License for the specific language governing permissions and 13* limitations under the License. 14*/ 15 16#include "fd_holder_internal.h" 17#include <errno.h> 18#include <stdio.h> 19#include "beget_ext.h" 20#include "securec.h" 21 22#ifndef PAGE_SIZE 23#define PAGE_SIZE (4096U) 24#endif 25 26int BuildControlMessage(struct msghdr *msghdr, int *fds, int fdCount, bool sendUcred) 27{ 28 if (msghdr == NULL || (fdCount > 0 && fds == NULL)) { 29 BEGET_LOGE("Build control message with invalid parameter"); 30 return -1; 31 } 32 33 if (fdCount > 0) { 34 msghdr->msg_controllen = CMSG_SPACE(sizeof(int) * fdCount); 35 } else { 36 msghdr->msg_controllen = 0; 37 } 38 39 if (sendUcred) { 40 msghdr->msg_controllen += CMSG_SPACE(sizeof(struct ucred)); 41 } 42 43 msghdr->msg_control = calloc(1, ((msghdr->msg_controllen == 0) ? 1 : msghdr->msg_controllen)); 44 BEGET_ERROR_CHECK(msghdr->msg_control != NULL, return -1, "Failed to build control message"); 45 46 struct cmsghdr *cmsg = NULL; 47 cmsg = CMSG_FIRSTHDR(msghdr); 48 BEGET_ERROR_CHECK(cmsg != NULL, return -1, "Failed to build cmsg"); 49 50 if (fdCount > 0) { 51 cmsg->cmsg_level = SOL_SOCKET; 52 cmsg->cmsg_type = SCM_RIGHTS; 53 cmsg->cmsg_len = CMSG_LEN(sizeof(int) * fdCount); 54 int ret = memcpy_s(CMSG_DATA(cmsg), cmsg->cmsg_len, fds, sizeof(int) * fdCount); 55 BEGET_ERROR_CHECK(ret == 0, free(msghdr->msg_control); 56 msghdr->msg_control = NULL; 57 return -1, "Control message is not valid"); 58 // build ucred info 59 cmsg = CMSG_NXTHDR(msghdr, cmsg); 60 } 61 62 if (sendUcred) { 63 BEGET_ERROR_CHECK(cmsg != NULL, free(msghdr->msg_control); 64 msghdr->msg_control = NULL; 65 return -1, "Control message is not valid"); 66 67 struct ucred *ucred; 68 cmsg->cmsg_level = SOL_SOCKET; 69 cmsg->cmsg_type = SCM_CREDENTIALS; 70 cmsg->cmsg_len = CMSG_LEN(sizeof(struct ucred)); 71 ucred = (struct ucred*) CMSG_DATA(cmsg); 72 ucred->pid = getpid(); 73 ucred->uid = getuid(); 74 ucred->gid = getgid(); 75 } 76 return 0; 77} 78 79STATIC int *GetFdsFromMsg(size_t *outFdCount, pid_t *requestPid, struct msghdr msghdr) 80{ 81 if ((msghdr.msg_flags) & MSG_TRUNC) { 82 BEGET_LOGE("Message was truncated when receiving fds"); 83 return NULL; 84 } 85 86 struct cmsghdr *cmsg = NULL; 87 int *fds = NULL; 88 size_t fdCount = 0; 89 for (cmsg = CMSG_FIRSTHDR(&msghdr); cmsg != NULL; cmsg = CMSG_NXTHDR(&msghdr, cmsg)) { 90 if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { 91 fds = (int*)CMSG_DATA(cmsg); 92 fdCount = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); 93 BEGET_ERROR_CHECK(fdCount <= MAX_HOLD_FDS, return NULL, "Too many fds returned."); 94 } 95 if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS && 96 cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) { 97 // Ignore credentials 98 if (requestPid != NULL) { 99 struct ucred *ucred = (struct ucred*)CMSG_DATA(cmsg); 100 *requestPid = ucred->pid; 101 } 102 continue; 103 } 104 } 105 int *outFds = NULL; 106 if (fds != NULL && fdCount > 0) { 107 outFds = calloc(fdCount + 1, sizeof(int)); 108 BEGET_ERROR_CHECK(outFds != NULL, return NULL, "Failed to allocate memory for fds"); 109 BEGET_ERROR_CHECK(memcpy_s(outFds, sizeof(int) * (fdCount + 1), fds, sizeof(int) * fdCount) == 0, 110 free(outFds); return NULL, "Failed to copy fds"); 111 } 112 *outFdCount = fdCount; 113 return outFds; 114} 115 116// This function will allocate memory to store FDs 117// Remember to delete when not used anymore. 118int *ReceiveFds(int sock, struct iovec iovec, size_t *outFdCount, bool nonblock, pid_t *requestPid) 119{ 120 CMSG_BUFFER_TYPE(CMSG_SPACE(sizeof(struct ucred)) + 121 CMSG_SPACE(sizeof(int) * MAX_HOLD_FDS)) control; 122 123 BEGET_ERROR_CHECK(sizeof(control) <= PAGE_SIZE, return NULL, "Too many fds, out of memory"); 124 125 struct msghdr msghdr = { 126 .msg_iov = &iovec, 127 .msg_iovlen = 1, 128 .msg_control = &control, 129 .msg_controllen = sizeof(control), 130 .msg_flags = 0, 131 }; 132 133 int flags = MSG_CMSG_CLOEXEC | MSG_TRUNC; 134 if (nonblock) { 135 flags |= MSG_DONTWAIT; 136 } 137 ssize_t rc = TEMP_FAILURE_RETRY(recvmsg(sock, &msghdr, flags)); 138 BEGET_ERROR_CHECK(rc >= 0, return NULL, "Failed to get fds from remote, err = %d", errno); 139 return GetFdsFromMsg(outFdCount, requestPid, msghdr); 140}