1/*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "local_socketpair.h"
#include "hilog/log.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <fcntl.h>
#include <cerrno>
#include <climits>
#include <string>
#include <scoped_bytrace.h>
25
26namespace OHOS {
27using namespace OHOS::HiviewDFX;
28#undef LOG_DOMAIN
29#define LOG_DOMAIN 0xD001400
30
31#undef LOG_TAG
32#define LOG_TAG "LocalSocketPair"
33#define LOGD(fmt, ...) HILOG_DEBUG(LOG_CORE, fmt, ##__VA_ARGS__)
34#define LOGW(fmt, ...) HILOG_WARN(LOG_CORE, fmt, ##__VA_ARGS__)
35#define LOGE(fmt, ...) HILOG_ERROR(LOG_CORE, fmt, ##__VA_ARGS__)
namespace {
// Number of descriptors produced by a single socketpair() call.
constexpr int32_t SOCKET_PAIR_SIZE{2};
// Sentinel meaning "this member holds no descriptor".
constexpr int32_t INVALID_FD{-1};
// SendData() return codes: the send would block (EAGAIN) vs. any other send failure.
constexpr int32_t ERRNO_EAGAIN{-1};
constexpr int32_t ERRNO_OTHER{-2};
// Live-channel count above which CreateChannel() emits a warning.
constexpr int32_t LEAK_FD_CNT{200};
// Number of live channels: incremented by CreateChannel(), decremented by the destructor.
// NOTE(review): plain int32_t with no synchronization; confirm instances are never
// created/destroyed concurrently from different threads.
int32_t g_fdCnt{0};
}  // namespace
44
45LocalSocketPair::LocalSocketPair()
46    : sendFd_(INVALID_FD), receiveFd_(INVALID_FD)
47{
48}
49
50LocalSocketPair::~LocalSocketPair()
51{
52    LOGD("%{public}s close socketpair, sendFd : %{public}d, receiveFd : %{public}d", __func__, sendFd_, receiveFd_);
53    if ((sendFd_ != INVALID_FD) || (receiveFd_ != INVALID_FD)) {
54        g_fdCnt--;
55    }
56    CloseFd(sendFd_);
57    CloseFd(receiveFd_);
58}
59
60int32_t LocalSocketPair::SetSockopt(size_t sendSize, size_t receiveSize, int32_t* socketPair, int32_t socketPairSize)
61{
62    for (int i = 0; i < socketPairSize; ++i) {
63        int32_t ret = setsockopt(socketPair[i], SOL_SOCKET, SO_SNDBUF, &sendSize, sizeof(sendSize));
64        if (ret != 0) {
65            CloseFd(socketPair[0]);
66            CloseFd(socketPair[1]);
67            LOGE("%{public}s setsockopt socketpair %{public}d sendbuffer size failed", __func__, i);
68            return -1;
69        }
70        ret = setsockopt(socketPair[i], SOL_SOCKET, SO_RCVBUF, &receiveSize, sizeof(receiveSize));
71        if (ret != 0) {
72            CloseFd(socketPair[0]);
73            CloseFd(socketPair[1]);
74            LOGE("%{public}s setsockopt socketpair %{public}d receivebuffer size failed", __func__, i);
75            return -1;
76        }
77        ret = fcntl(socketPair[i], F_SETFL, O_NONBLOCK);
78        if (ret != 0) {
79            CloseFd(socketPair[0]);
80            CloseFd(socketPair[1]);
81            LOGE("%{public}s fcntl socketpair %{public}d nonblock failed", __func__, i);
82            return -1;
83        }
84    }
85    return 0;
86}
87
88int32_t LocalSocketPair::CreateChannel(size_t sendSize, size_t receiveSize)
89{
90    if ((sendFd_ != INVALID_FD) || (receiveFd_ != INVALID_FD)) {
91        LOGD("%{public}s already create socketpair", __func__);
92        return 0;
93    }
94
95    int32_t socketPair[SOCKET_PAIR_SIZE] = { 0 };
96    if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, socketPair) != 0) {
97        ScopedBytrace func("Create socketpair failed, errno = " + std::to_string(errno));
98        LOGE("%{public}s create socketpair failed", __func__);
99        return -1;
100    }
101    if (socketPair[0] == 0 || socketPair[1] == 0) {
102        int32_t unusedFds[SOCKET_PAIR_SIZE] = {socketPair[0], socketPair[1]};
103        int32_t err = socketpair(AF_UNIX, SOCK_SEQPACKET, 0, socketPair);
104        CloseFd(unusedFds[0]);
105        CloseFd(unusedFds[1]);
106        if (err != 0) {
107            ScopedBytrace func2("Create socketpair failed for the second time, errno = " + std::to_string(errno));
108            LOGE("%{public}s create socketpair failed", __func__);
109            return -1;
110        }
111    }
112
113    // set socket attr
114    int32_t ret = SetSockopt(sendSize, receiveSize, socketPair, SOCKET_PAIR_SIZE);
115    if (ret != 0) {
116        return ret;
117    }
118    sendFd_ = socketPair[0];
119    receiveFd_ = socketPair[1];
120    if ((sendFd_ <= 0) || (receiveFd_ <= 0)) {
121        LOGE("%{public}s socketpair invalid fd, sendFd_:%{public}d, receiveFd_:%{public}d",
122            __func__, sendFd_, receiveFd_);
123        return -1;
124    }
125    LOGD("%{public}s create socketpair success, receiveFd_ : %{public}d, sendFd_ : %{public}d", __func__, receiveFd_,
126        sendFd_);
127    g_fdCnt++;
128    if (g_fdCnt > LEAK_FD_CNT) {
129        LOGW("%{public}s fdCnt: %{public}d", __func__, g_fdCnt);
130    }
131
132    return 0;
133}
134
135int32_t LocalSocketPair::SendData(const void *vaddr, size_t size)
136{
137    if (vaddr == nullptr || sendFd_ < 0) {
138        LOGE("%{public}s failed, param is invalid", __func__);
139        return -1;
140    }
141    ssize_t length = TEMP_FAILURE_RETRY(send(sendFd_, vaddr, size, MSG_DONTWAIT | MSG_NOSIGNAL));
142    if (length < 0) {
143        int errnoRecord = errno;
144        ScopedBytrace func("SocketPair SendData failed, errno = " + std::to_string(errnoRecord) +
145                            ", sendFd_ = " + std::to_string(sendFd_) + ", receiveFd_ = " + std::to_string(receiveFd_) +
146                            ", length = " + std::to_string(length));
147        LOGD("%{public}s send failed:%{public}d, length = %{public}d", __func__, errnoRecord,
148            static_cast<int32_t>(length));
149        if (errnoRecord == EAGAIN) {
150            return ERRNO_EAGAIN;
151        } else {
152            return ERRNO_OTHER;
153        }
154    }
155    return length;
156}
157
158int32_t LocalSocketPair::ReceiveData(void *vaddr, size_t size)
159{
160    if (vaddr == nullptr || (receiveFd_ < 0)) {
161        LOGE("%{public}s failed, vaddr is null or receiveFd_ invalid", __func__);
162        return -1;
163    }
164    ssize_t length;
165    do {
166        length = recv(receiveFd_, vaddr, size, MSG_DONTWAIT);
167    } while (errno == EINTR);
168    if (length < 0) {
169        ScopedBytrace func("SocketPair ReceiveData failed errno = " + std::to_string(errno) +
170                            ", sendFd_ = " + std::to_string(sendFd_) + ", receiveFd_ = " + std::to_string(receiveFd_) +
171                            ", length = " + std::to_string(length));
172        return -1;
173    }
174    return length;
175}
176
177// internal interface
178int32_t LocalSocketPair::SendFdToBinder(MessageParcel &data, int32_t &fd)
179{
180    if (fd < 0) {
181        return -1;
182    }
183    // need dup???
184    bool result = data.WriteFileDescriptor(fd);
185    if (!result) {
186        return -1;
187    }
188    return 0;
189}
190
191int32_t LocalSocketPair::SendToBinder(MessageParcel &data)
192{
193    return SendFdToBinder(data, sendFd_);
194}
195
196int32_t LocalSocketPair::ReceiveToBinder(MessageParcel &data)
197{
198    return SendFdToBinder(data, receiveFd_);
199}
200
201// internal interface
202void LocalSocketPair::CloseFd(int32_t &fd)
203{
204    if (fd != INVALID_FD) {
205        close(fd);
206        fd = INVALID_FD;
207    }
208}
209}
210
211