1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "stream_server.h"
17
#include <cinttypes>
#include <list>
#include <utility>
20
21#include <sys/socket.h>
22
23#include "devicestatus_service.h"
24#include "fi_log.h"
25
26#undef LOG_TAG
27#define LOG_TAG "StreamServer"
28
29namespace OHOS {
30namespace Msdp {
31namespace DeviceStatus {
32
33StreamServer::~StreamServer()
34{
35    CALL_DEBUG_ENTER;
36    UdsStop();
37}
38
39void StreamServer::UdsStop()
40{
41    if (epollFd_ != -1) {
42        if (close(epollFd_) < 0) {
43            FI_HILOGE("Close epoll fd failed, error:%{public}s, epollFd_:%{public}d", strerror(errno), epollFd_);
44        }
45        epollFd_ = -1;
46    }
47
48    for (const auto &item : sessionss_) {
49        item.second->Close();
50    }
51    sessionss_.clear();
52}
53
54int32_t StreamServer::GetClientFd(int32_t pid) const
55{
56    auto it = idxPids_.find(pid);
57    if (it == idxPids_.end()) {
58        return INVALID_FD;
59    }
60    return it->second;
61}
62
63int32_t StreamServer::GetClientPid(int32_t fd) const
64{
65    auto it = sessionss_.find(fd);
66    if (it == sessionss_.end()) {
67        return INVALID_PID;
68    }
69    return it->second->GetPid();
70}
71
72bool StreamServer::SendMsg(int32_t fd, NetPacket &pkt)
73{
74    if (fd < 0) {
75        FI_HILOGE("The fd is less than 0");
76        return false;
77    }
78    auto ses = GetSession(fd);
79    if (ses == nullptr) {
80        FI_HILOGE("The fd:%{public}d not found, The message was discarded, errCode:%{public}d",
81            fd, SESSION_NOT_FOUND);
82        return false;
83    }
84    return ses->SendMsg(pkt);
85}
86
87void StreamServer::Multicast(const std::vector<int32_t> &fdList, NetPacket &pkt)
88{
89    for (const auto &item : fdList) {
90        SendMsg(item, pkt);
91    }
92}
93
94int32_t StreamServer::AddSocketPairInfo(const std::string &programName, int32_t moduleType, int32_t uid, int32_t pid,
95    int32_t &serverFd, int32_t &toReturnClientFd, int32_t &tokenType)
96{
97    CALL_DEBUG_ENTER;
98    int32_t sockFds[2] = { -1 };
99
100    if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockFds) != 0) {
101        FI_HILOGE("Call socketpair failed, errno:%{public}d", errno);
102        return RET_ERR;
103    }
104    serverFd = sockFds[0];
105    toReturnClientFd = sockFds[1];
106    if (serverFd < 0 || toReturnClientFd < 0) {
107        FI_HILOGE("Call fcntl failed, errno:%{public}d", errno);
108        return RET_ERR;
109    }
110    int32_t setSockOptResult = SetSockOpt(serverFd, toReturnClientFd, tokenType);
111    if (RET_OK != setSockOptResult) {
112        return setSockOptResult;
113    }
114    SessionPtr sess = nullptr;
115    sess = std::make_shared<StreamSession>(programName, moduleType, serverFd, uid, pid);
116    sess->SetTokenType(tokenType);
117    if (!AddSession(sess)) {
118        FI_HILOGE("AddSession fail errCode:%{public}d", ADD_SESSION_FAIL);
119        return CloseFd(serverFd, toReturnClientFd);
120    }
121    if (AddEpoll(EPOLL_EVENT_SOCKET, serverFd) != RET_OK) {
122        FI_HILOGE("epoll_ctl EPOLL_CTL_ADD failed, errCode:%{public}d", EPOLL_MODIFY_FAIL);
123        return CloseFd(serverFd, toReturnClientFd);
124    }
125    OnConnected(sess);
126    return RET_OK;
127}
128
129int32_t StreamServer::SetSockOpt(int32_t &serverFd, int32_t &toReturnClientFd, int32_t &tokenType)
130{
131    CALL_DEBUG_ENTER;
132    static constexpr size_t bufferSize = 32 * 1024;
133    static constexpr size_t nativeBufferSize = 64 * 1024;
134
135    if (setsockopt(serverFd, SOL_SOCKET, SO_SNDBUF, &bufferSize, sizeof(bufferSize)) != 0) {
136        FI_HILOGE("setsockopt serverFd failed, errno:%{public}d", errno);
137        return CloseFd(serverFd, toReturnClientFd);
138    }
139    if (setsockopt(serverFd, SOL_SOCKET, SO_RCVBUF, &bufferSize, sizeof(bufferSize)) != 0) {
140        FI_HILOGE("setsockopt serverFd failed, errno:%{public}d", errno);
141        return CloseFd(serverFd, toReturnClientFd);
142    }
143    if (tokenType == TokenType::TOKEN_NATIVE) {
144        if (setsockopt(toReturnClientFd, SOL_SOCKET, SO_SNDBUF, &nativeBufferSize, sizeof(nativeBufferSize)) != 0) {
145            FI_HILOGE("setsockopt toReturnClientFd failed, errno:%{public}d", errno);
146            return CloseFd(serverFd, toReturnClientFd);
147        }
148        if (setsockopt(toReturnClientFd, SOL_SOCKET, SO_RCVBUF, &nativeBufferSize, sizeof(nativeBufferSize)) != 0) {
149            FI_HILOGE("setsockopt toReturnClientFd failed, errno:%{public}d", errno);
150            return CloseFd(serverFd, toReturnClientFd);
151        }
152    } else {
153        if (setsockopt(toReturnClientFd, SOL_SOCKET, SO_SNDBUF, &bufferSize, sizeof(bufferSize)) != 0) {
154            FI_HILOGE("setsockopt toReturnClientFd failed, errno:%{public}d", errno);
155            return CloseFd(serverFd, toReturnClientFd);
156        }
157        if (setsockopt(toReturnClientFd, SOL_SOCKET, SO_RCVBUF, &bufferSize, sizeof(bufferSize)) != 0) {
158            FI_HILOGE("setsockopt toReturnClientFd failed, errno:%{public}d", errno);
159            return CloseFd(serverFd, toReturnClientFd);
160        }
161    }
162    return RET_OK;
163}
164
165int32_t StreamServer::CloseFd(int32_t &serverFd, int32_t &toReturnClientFd)
166{
167    if (close(serverFd) < 0) {
168        FI_HILOGE("Close server fd failed, error:%{public}s, serverFd:%{public}d", strerror(errno), serverFd);
169    }
170    serverFd = -1;
171    if (close(toReturnClientFd) < 0) {
172        FI_HILOGE("Close fd failed, error:%{public}s, toReturnClientFd:%{public}d", strerror(errno), toReturnClientFd);
173    }
174    toReturnClientFd = -1;
175    return RET_ERR;
176}
177
178void StreamServer::SetRecvFun(MsgServerFunCallback fun)
179{
180    recvFun_ = fun;
181}
182
183void StreamServer::ReleaseSession(int32_t fd, epoll_event &ev)
184{
185    auto secPtr = GetSession(fd);
186    if (secPtr != nullptr) {
187        OnDisconnected(secPtr);
188        DelSession(fd);
189    }
190    if (ev.data.ptr) {
191        free(ev.data.ptr);
192        ev.data.ptr = nullptr;
193    }
194    if (auto it = circleBufs_.find(fd); it != circleBufs_.end()) {
195        circleBufs_.erase(it);
196    }
197    auto DeviceStatusService = DeviceStatus::DelayedSpSingleton<DeviceStatus::DeviceStatusService>::GetInstance();
198    DeviceStatusService->DelEpoll(EPOLL_EVENT_SOCKET, fd);
199    if (close(fd) < 0) {
200        FI_HILOGE("Close fd failed, error:%{public}s, fd:%{public}d", strerror(errno), fd);
201    }
202}
203
204void StreamServer::OnPacket(int32_t fd, NetPacket &pkt)
205{
206    auto sess = GetSession(fd);
207    CHKPV(sess);
208    recvFun_(sess, pkt);
209}
210
211void StreamServer::OnEpollRecv(int32_t fd, epoll_event &ev)
212{
213    if (fd < 0) {
214        FI_HILOGE("Invalid fd:%{public}d", fd);
215        return;
216    }
217    auto& buf = circleBufs_[fd];
218    char szBuf[MAX_PACKET_BUF_SIZE] = { 0 };
219    for (int32_t i = 0; i < MAX_RECV_LIMIT; i++) {
220        ssize_t size = recv(fd, szBuf, MAX_PACKET_BUF_SIZE, MSG_DONTWAIT | MSG_NOSIGNAL);
221        if (size > 0) {
222            if (!buf.Write(szBuf, size)) {
223                FI_HILOGW("Write data failed, size:%{public}zd", size);
224            }
225            OnReadPackets(buf, [this, fd](NetPacket &pkt) { this->OnPacket(fd, pkt); });
226        } else if (size < 0) {
227            if (errno == EAGAIN || errno == EINTR || errno == EWOULDBLOCK) {
228                FI_HILOGD("Continue for errno EAGAIN|EINTR|EWOULDBLOCK size:%{public}zd errno:%{public}d",
229                    size, errno);
230                continue;
231            }
232            FI_HILOGE("Recv return %{public}zd, errno:%{public}d", size, errno);
233            break;
234        } else {
235            FI_HILOGE("The client side disconnect with the server, size:0, errno:%{public}d", errno);
236            ReleaseSession(fd, ev);
237            break;
238        }
239        if (static_cast<size_t>(size) < MAX_PACKET_BUF_SIZE) {
240            break;
241        }
242    }
243}
244
245void StreamServer::OnEpollEvent(epoll_event &ev)
246{
247    CHKPV(ev.data.ptr);
248    int32_t fd = *static_cast<int32_t*>(ev.data.ptr);
249    if (fd < 0) {
250        FI_HILOGE("The fd less than 0, errCode:%{public}d", PARAM_INPUT_INVALID);
251        return;
252    }
253    if ((ev.events & EPOLLERR) || (ev.events & EPOLLHUP)) {
254        FI_HILOGI("EPOLLERR or EPOLLHUP, fd:%{public}d, ev.events:0x%{public}x", fd, ev.events);
255        ReleaseSession(fd, ev);
256    } else if (ev.events & EPOLLIN) {
257        OnEpollRecv(fd, ev);
258    }
259}
260
261void StreamServer::DumpSession(const std::string &title)
262{
263    FI_HILOGD("in %{public}s:%{public}s", __func__, title.c_str());
264    int32_t i = 0;
265    for (auto &[key, value] : sessionss_) {
266        CHKPV(value);
267        i++;
268    }
269}
270
271SessionPtr StreamServer::GetSession(int32_t fd) const
272{
273    auto it = sessionss_.find(fd);
274    if (it == sessionss_.end()) {
275        FI_HILOGE("Session not found, fd:%{public}d", fd);
276        return nullptr;
277    }
278    CHKPP(it->second);
279    return it->second->GetSharedPtr();
280}
281
282SessionPtr StreamServer::GetSessionByPid(int32_t pid) const
283{
284    int32_t fd = GetClientFd(pid);
285    if (fd <= 0) {
286        FI_HILOGE("Session not found, pid:%{public}d", pid);
287        return nullptr;
288    }
289    return GetSession(fd);
290}
291
292bool StreamServer::AddSession(SessionPtr ses)
293{
294    CHKPF(ses);
295    FI_HILOGI("pid:%{public}d, fd:%{public}d", ses->GetPid(), ses->GetFd());
296    int32_t fd = ses->GetFd();
297    if (fd < 0) {
298        FI_HILOGE("The fd is less than 0");
299        return false;
300    }
301    int32_t pid = ses->GetPid();
302    if (pid <= 0) {
303        FI_HILOGE("Get process failed");
304        return false;
305    }
306    if (sessionss_.size() > MAX_SESSION_ALARM) {
307        FI_HILOGE("Too many clients, Warning Value:%{public}zu, Current Value:%{public}zu",
308            MAX_SESSION_ALARM, sessionss_.size());
309        return false;
310    }
311    DumpSession("AddSession");
312    idxPids_[pid] = fd;
313    sessionss_[fd] = ses;
314    FI_HILOGI("Add session end");
315    return true;
316}
317
318void StreamServer::DelSession(int32_t fd)
319{
320    CALL_DEBUG_ENTER;
321    FI_HILOGI("fd:%{public}d", fd);
322    if (fd < 0) {
323        FI_HILOGE("The fd less than 0, errCode:%{public}d", PARAM_INPUT_INVALID);
324        return;
325    }
326    int32_t pid = GetClientPid(fd);
327    if (pid > 0) {
328        idxPids_.erase(pid);
329    }
330    auto it = sessionss_.find(fd);
331    if (it != sessionss_.end()) {
332        NotifySessionDeleted(it->second);
333        sessionss_.erase(it);
334    }
335    DumpSession("DelSession");
336}
337
338void StreamServer::AddSessionDeletedCallback(int32_t pid, std::function<void(SessionPtr)> callback)
339{
340    CALL_DEBUG_ENTER;
341    auto it = callbacks_.find(pid);
342    if (it != callbacks_.end()) {
343        FI_HILOGW("Deleted session already exists");
344        return;
345    }
346    callbacks_[pid] = callback;
347}
348
349void StreamServer::NotifySessionDeleted(SessionPtr ses)
350{
351    CALL_DEBUG_ENTER;
352    auto it = callbacks_.find(ses->GetPid());
353    if (it != callbacks_.end()) {
354        it->second(ses);
355        callbacks_.erase(it);
356    }
357}
358} // namespace DeviceStatus
359} // namespace Msdp
360} // namespace OHOS
361