1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "socket_client.h"

#include <utility>

#include "event_handler.h"

#include "devicestatus_define.h"
#include "intention_identity.h"
#include "socket_params.h"
#include "time_cost_chk.h"
24
25#undef LOG_TAG
26#define LOG_TAG "SocketClient"
27
28namespace OHOS {
29namespace Msdp {
30namespace DeviceStatus {
namespace {
// Thread name passed to AppExecFwk::EventRunner::Create() when a SocketClient
// builds its event handler.
const std::string THREAD_NAME = "os_ClientEventHandler";
} // namespace
34
35SocketClient::SocketClient(std::shared_ptr<ITunnelClient> tunnel)
36    : tunnel_(tunnel)
37{
38    auto runner = AppExecFwk::EventRunner::Create(THREAD_NAME);
39    eventHandler_ = std::make_shared<AppExecFwk::EventHandler>(runner);
40}
41
42bool SocketClient::RegisterEvent(MessageId id, std::function<int32_t(const StreamClient&, NetPacket&)> callback)
43{
44    std::lock_guard guard(lock_);
45    auto [_, inserted] = callbacks_.emplace(id, callback);
46    return inserted;
47}
48
49void SocketClient::Start()
50{
51    CALL_DEBUG_ENTER;
52    Reconnect();
53}
54
// Stops the client. Currently a no-op.
// NOTE(review): this does not remove the fd listener, reset socket_, or cancel
// reconnection tasks posted by Reconnect()/OnDisconnected(). Those tasks
// capture `this`, so confirm that the destructor or owner performs teardown
// before this object is destroyed.
void SocketClient::Stop()
{}
57
58bool SocketClient::Connect()
59{
60    CALL_DEBUG_ENTER;
61    if (socket_ != nullptr) {
62        return true;
63    }
64    auto socket = SocketConnection::Connect(
65        [this] { return this->Socket(); },
66        [this](NetPacket &pkt) { this->OnPacket(pkt); },
67        [this] { this->OnDisconnected(); });
68    CHKPF(socket);
69    CHKPF(eventHandler_);
70    auto errCode = eventHandler_->AddFileDescriptorListener(socket->GetFd(),
71        AppExecFwk::FILE_DESCRIPTOR_INPUT_EVENT, socket, "DeviceStatusTask");
72    if (errCode != ERR_OK) {
73        FI_HILOGE("AddFileDescriptorListener(%{public}d) failed (%{public}u)", socket->GetFd(), errCode);
74        return false;
75    }
76    socket_ = socket;
77    FI_HILOGD("SocketClient started successfully");
78    return true;
79}
80
81int32_t SocketClient::Socket()
82{
83    CALL_DEBUG_ENTER;
84    std::shared_ptr<ITunnelClient> tunnel = tunnel_.lock();
85    CHKPR(tunnel, RET_ERR);
86    AllocSocketPairParam param { GetProgramName(), CONNECT_MODULE_TYPE_FI_CLIENT };
87    AllocSocketPairReply reply;
88
89    int32_t ret = tunnel->Control(Intention::SOCKET, SocketAction::SOCKET_ACTION_CONNECT, param, reply);
90    if (ret != RET_OK) {
91        FI_HILOGE("ITunnelClient::Control fail");
92        return -1;
93    }
94    FI_HILOGD("Connected to intention service (%{public}d)", reply.socketFd);
95    return reply.socketFd;
96}
97
98void SocketClient::OnPacket(NetPacket &pkt)
99{
100    CALL_DEBUG_ENTER;
101    std::lock_guard guard(lock_);
102    OnMsgHandler(*this, pkt);
103}
104
105void SocketClient::OnDisconnected()
106{
107    CALL_DEBUG_ENTER;
108    std::lock_guard guard(lock_);
109    if (socket_ != nullptr) {
110        eventHandler_->RemoveFileDescriptorListener(socket_->GetFd());
111        eventHandler_->RemoveAllEvents();
112        socket_.reset();
113    }
114    if (!eventHandler_->PostTask([this] { this->Reconnect(); }, CLIENT_RECONNECT_COOLING_TIME)) {
115        FI_HILOGE("Failed to post reconnection task");
116    }
117}
118
119void SocketClient::Reconnect()
120{
121    std::lock_guard guard(lock_);
122    if (Connect()) {
123        return;
124    }
125    if (!eventHandler_->PostTask([this] { this->Reconnect(); }, CLIENT_RECONNECT_COOLING_TIME)) {
126        FI_HILOGE("Failed to post reconnection task");
127    }
128}
129
130void SocketClient::OnMsgHandler(const StreamClient &client, NetPacket &pkt)
131{
132    CALL_DEBUG_ENTER;
133    MessageId id = pkt.GetMsgId();
134    TimeCostChk chk("SocketClient::OnMsgHandler", "overtime 300(us)", MAX_OVER_TIME, id);
135    auto iter = callbacks_.find(id);
136    if (iter == callbacks_.end()) {
137        FI_HILOGE("Unknown msg id:%{public}d", id);
138        return;
139    }
140    int32_t ret = iter->second(client, pkt);
141    if (ret < 0) {
142        FI_HILOGE("Msg handling failed, id:%{public}d, ret:%{public}d", id, ret);
143    }
144}
145} // namespace DeviceStatus
146} // namespace Msdp
147} // namespace OHOS