1/*
2 * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15#include "ipc_unix_socket.h"
16
17#include <poll.h>
18#include <sys/socket.h>
19#include <sys/un.h>
20#include <unistd.h>
21
22#include "hhlog.h"
23
24namespace OHOS {
25namespace Developtools {
26namespace Hiebpf {
27IpcUnixSocketServer::IpcUnixSocketServer() {}
28
29IpcUnixSocketServer::~IpcUnixSocketServer()
30{
31    Stop();
32}
33
34bool IpcUnixSocketServer::Start(const std::string &pathname)
35{
36    CHECK_TRUE(serverFd_ == -1, false, "Unix Socket Server is running");
37
38    serverFd_ = socket(AF_UNIX, SOCK_STREAM, 0);
39    CHECK_TRUE(serverFd_ != -1, false, "create Unix Socket Server failed, %d: %s", errno, strerror(errno));
40
41    unlink(pathname.c_str());
42    struct sockaddr_un addr = {0};
43    addr.sun_family = AF_UNIX;
44    std::copy(pathname.c_str(), pathname.c_str() + pathname.size() + 1, addr.sun_path);
45    if (bind(serverFd_, (struct sockaddr*)&addr, sizeof(sockaddr_un)) != 0) {
46        HHLOGE(true, "bind failed, Unix Socket(%s), %d: %s", pathname.c_str(), errno, strerror(errno));
47        close(serverFd_);
48        return false;
49    }
50    if (listen(serverFd_, UNIX_SOCKET_LISTEN_COUNT) != 0) {
51        HHLOGE(true, "listen failed, Unix Socket(%s), %d: %s", pathname.c_str(), errno, strerror(errno));
52        close(serverFd_);
53        unlink(pathname.c_str());
54        return false;
55    }
56    pathName_ = pathname;
57
58    isRunning_ = true;
59    handleThread_ = std::thread([this] { this->HandleThreadLoop(); });
60    return true;
61}
62
63bool IpcUnixSocketServer::Stop()
64{
65    isRunning_ = false;
66    if (serverFd_ != -1) {
67        close(serverFd_);
68        serverFd_ = -1;
69    }
70    if (clientFd_ != -1) {
71        close(clientFd_);
72        clientFd_ = -1;
73    }
74    if (handleThread_.joinable()) {
75        handleThread_.join();
76    }
77    unlink(pathName_.c_str());
78    return true;
79}
80
81bool IpcUnixSocketServer::SendMessage(const void *buf, size_t size)
82{
83    CHECK_TRUE(clientFd_ != -1, false, "no available Unix Socket");
84
85    CHECK_TRUE(send(clientFd_, buf, size, 0) != -1, false,
86               "send failed, Unix Socket(%d) %zu bytes, %d: %s", clientFd_, size, errno, strerror(errno));
87    return true;
88}
89
90void IpcUnixSocketServer::HandleThreadLoop()
91{
92    while (isRunning_) {
93        struct pollfd pollFd {serverFd_, POLLIN, 0};
94        const int timeout = 1000;
95        int polled = TEMP_FAILURE_RETRY(poll(&pollFd, 1, timeout));
96        if (polled == 0) { // timeout
97            continue;
98        } else if (polled < 0 || !(pollFd.revents & POLLIN)) {
99            HHLOGE(true, "poll failed, Unix Socket(%d), %d: %s", serverFd_, errno, strerror(errno));
100            close(serverFd_);
101            serverFd_ = -1;
102            break;
103        }
104
105        clientFd_ = accept(serverFd_, nullptr, nullptr);
106        if (clientFd_ == -1) {
107            HHLOGE(true, "accept failed, Unix Socket(%d), %d: %s", serverFd_, errno, strerror(errno));
108            continue;
109        }
110
111        while (isRunning_ && clientFd_ != -1) {
112            uint8_t buf[UNIX_SOCKET_BUFFER_SIZE] = {0};
113            int recvSize = recv(clientFd_, buf, UNIX_SOCKET_BUFFER_SIZE, 0);
114            if (recvSize > 0) {
115                if (handleMessageFn_) {
116                    handleMessageFn_(buf, recvSize);
117                }
118                continue;
119            } else if (recvSize == 0) {
120                HHLOGE(true, "recv failed, peer has closed");
121            } else {
122                HHLOGE(true, "recv failed, Unix Socket(%d), %d: %s", clientFd_, errno, strerror(errno));
123            }
124            close(clientFd_);
125            clientFd_ = -1;
126        }
127    }
128}
129
130IpcUnixSocketClient::IpcUnixSocketClient() {}
131
132IpcUnixSocketClient::~IpcUnixSocketClient()
133{
134    Disconnect();
135}
136
137bool IpcUnixSocketClient::Connect(const std::string &pathname)
138{
139    CHECK_TRUE(sockFd_ == -1, false, "Unix Socket has connected");
140
141    sockFd_ = socket(AF_UNIX, SOCK_STREAM, 0);
142    CHECK_TRUE(sockFd_ != -1, false, "create Unix Socket Server failed, %d: %s", errno, strerror(errno));
143
144    struct sockaddr_un addr = {0};
145    addr.sun_family = AF_UNIX;
146    std::copy(pathname.c_str(), pathname.c_str() + pathname.size() + 1, addr.sun_path);
147    if (connect(sockFd_, (struct sockaddr*)&addr, sizeof(sockaddr_un)) == -1) {
148        HHLOGE(true, "connect failed, %d: %s", errno, strerror(errno));
149        sockFd_ = -1;
150        return false;
151    }
152
153    return true;
154}
155
156void IpcUnixSocketClient::Disconnect()
157{
158    if (sockFd_ != -1) {
159        close(sockFd_);
160        sockFd_ = -1;
161    }
162}
163
164bool IpcUnixSocketClient::SendMessage(const void *buf, size_t size)
165{
166    CHECK_TRUE(sockFd_ != -1, false, "Unix Socket disconnected");
167
168    if (send(sockFd_, buf, size, 0) != -1) {
169        return true;
170    }
171    HHLOGE(true, "send failed, Unix Socket(%d), %d: %s", sockFd_, errno, strerror(errno));
172    return false;
173}
174
175bool IpcUnixSocketClient::RecvMessage(void *buf, size_t &size, uint32_t timeout)
176{
177    CHECK_TRUE(sockFd_ != -1, false, "Unix Socket disconnected");
178
179    struct pollfd pollFd {sockFd_, POLLIN | POLLERR | POLLHUP, 0};
180    int polled = poll(&pollFd, 1, timeout);
181    if (polled == 0) { // timeout
182        size = 0;
183        return true;
184    } else if (polled < 0 || !(pollFd.revents & POLLIN)) {
185        HHLOGE(true, "poll failed, Unix Socket(%d), %d: %s", sockFd_, errno, strerror(errno));
186        return false;
187    }
188
189    int recvSize = recv(sockFd_, buf, size, 0);
190    if (recvSize > 0) {
191        size = static_cast<size_t>(recvSize);
192        return true;
193    } else if (recvSize == 0) {
194        HHLOGE(true, "recv failed, peer has closed");
195    } else {
196        HHLOGE(true, "recv failed, Unix Socket(%d), %d: %s", sockFd_, errno, strerror(errno));
197    }
198
199    return false;
200}
201} // namespace Hiebpf
202} // namespace Developtools
203} // namespace OHOS
204