1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <algorithm>
17 #include <cinttypes>
18 #include <chrono>
19 #include <thread>
20 
21 #include "common.h"
22 #include "tmessenger.h"
23 
24 namespace OHOS {
// Delay, in seconds, that the request-handling thread sleeps before it builds and sends a response.
static constexpr uint32_t WAIT_RESP_TIME { 1 };
26 
Encode() const27 std::string Request::Encode() const
28 {
29     return std::to_string(static_cast<int32_t>(cmd_));
30 }
31 
Decode(const std::string &data)32 std::shared_ptr<Request> Request::Decode(const std::string &data)
33 {
34     if (data.empty()) {
35         LOGE("the data is empty");
36         return nullptr;
37     }
38 
39     Cmd cmd = static_cast<Cmd>(std::stoi(data));
40     if (cmd < Cmd::QUERY_RESULT || cmd > Cmd::QUERY_RESULT) {
41         LOGE("invalid cmd=%d", static_cast<int32_t>(cmd));
42         return nullptr;
43     }
44     return std::make_shared<Request>(cmd);
45 }
46 
Encode() const47 std::string Response::Encode() const
48 {
49     std::string data = std::to_string(isEncrypt_ ? 1 : 0);
50     return data + SEPARATOR + recvData_;
51 }
52 
Decode(const std::string &data)53 std::shared_ptr<Response> Response::Decode(const std::string &data)
54 {
55     if (data.empty()) {
56         LOGE("the data is empty");
57         return nullptr;
58     }
59 
60     size_t pos = data.find(SEPARATOR);
61     if (pos == std::string::npos) {
62         LOGE("can not find separator in the string data");
63         return nullptr;
64     }
65 
66     int32_t isEncryptVal = static_cast<int32_t>(std::stoi(data.substr(0, pos)));
67     bool isEncrypt = (isEncryptVal == 1);
68     std::string recvData = data.substr(pos + 1);
69 
70     return std::make_shared<Response>(isEncrypt, recvData);
71 }
72 
~Message()73 Message::~Message()
74 {
75     if (msgType_ == MsgType::MSG_SEQ && request != nullptr) {
76         delete request;
77     }
78     if (msgType_ == MsgType::MSG_RSP && response != nullptr) {
79         delete response;
80     }
81 }
82 
Encode() const83 std::string Message::Encode() const
84 {
85     std::string data = std::to_string(static_cast<int32_t>(msgType_));
86     switch (msgType_) {
87         case MsgType::MSG_SEQ:
88             return request == nullptr ? "" : data + SEPARATOR + request->Encode();
89         case MsgType::MSG_RSP:
90             return response == nullptr ? "" : data + SEPARATOR + response->Encode();
91         default:
92             LOGE("invalid msgType=%d", static_cast<int32_t>(msgType_));
93             return "";
94     }
95 }
96 
Decode(const std::string &data)97 std::shared_ptr<Message> Message::Decode(const std::string &data)
98 {
99     size_t pos = data.find(SEPARATOR);
100     if (pos == std::string::npos) {
101         return nullptr;
102     }
103 
104     MsgType msgType = static_cast<MsgType>(std::stoi(data.substr(0, pos)));
105     switch (msgType) {
106         case MsgType::MSG_SEQ: {
107             std::shared_ptr<Request> req = Request::Decode(data.substr(pos + 1));
108             if (req == nullptr) {
109                 return nullptr;
110             }
111             return std::make_shared<Message>(*req);
112         }
113         case MsgType::MSG_RSP: {
114             std::shared_ptr<Response> rsp = Response::Decode(data.substr(pos + 1));
115             if (rsp == nullptr) {
116                 return nullptr;
117             }
118             return std::make_shared<Message>(*rsp);
119         }
120         default:
121             LOGE("invalid msgType=%d", static_cast<int32_t>(msgType));
122             return nullptr;
123     }
124 }
125 
Open( const std::string &pkgName, const std::string &myName, const std::string &peerName, bool isServer)126 int32_t TMessenger::Open(
127     const std::string &pkgName, const std::string &myName, const std::string &peerName, bool isServer)
128 {
129     isServer_ = isServer;
130     return isServer_ ? StartListen(pkgName, myName) : StartConnect(pkgName, myName, peerName);
131 }
132 
Close()133 void TMessenger::Close()
134 {
135     if (socket_ > 0) {
136         Shutdown(socket_);
137         socket_ = -1;
138     }
139 
140     if (listenSocket_ > 0) {
141         Shutdown(listenSocket_);
142         listenSocket_ = -1;
143     }
144 
145     pkgName_.clear();
146     myName_.clear();
147     peerName_.clear();
148     peerNetworkId_.clear();
149     msgList_.clear();
150 }
151 
StartListen(const std::string &pkgName, const std::string &myName)152 int32_t TMessenger::StartListen(const std::string &pkgName, const std::string &myName)
153 {
154     if (listenSocket_ > 0) {
155         return SOFTBUS_OK;
156     }
157 
158     SocketInfo info = {
159         .pkgName = (char *)(pkgName.c_str()),
160         .name = (char *)(myName.c_str()),
161     };
162     int32_t socket = Socket(info);
163     if (socket <= 0) {
164         LOGE("failed to create socket, ret=%d", socket);
165         return socket;
166     }
167     LOGI("create listen socket=%d", socket);
168 
169     QosTV qosInfo[] = {
170         {.qos = QOS_TYPE_MIN_BW,       .value = 80  },
171         { .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
172         { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
173     };
174     static ISocketListener listener = {
175         .OnBind = TMessenger::OnBind,
176         .OnMessage = TMessenger::OnMessage,
177         .OnShutdown = TMessenger::OnShutdown,
178     };
179 
180     int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
181     if (ret != SOFTBUS_OK) {
182         LOGE("failed to listen, socket=%d", socket);
183         Shutdown(socket);
184         return ret;
185     }
186     listenSocket_ = socket;
187     pkgName_ = pkgName;
188     myName_ = myName;
189     return SOFTBUS_OK;
190 }
191 
StartConnect(const std::string &pkgName, const std::string &myName, const std::string &peerName)192 int32_t TMessenger::StartConnect(const std::string &pkgName, const std::string &myName, const std::string &peerName)
193 {
194     if (socket_ > 0) {
195         return SOFTBUS_OK;
196     }
197 
198     SocketInfo info = {
199         .pkgName = const_cast<char *>(pkgName.c_str()),
200         .name = const_cast<char *>(myName.c_str()),
201         .peerName = const_cast<char *>(peerName.c_str()),
202         .peerNetworkId = nullptr,
203         .dataType = DATA_TYPE_MESSAGE,
204     };
205     info.peerNetworkId = OHOS::WaitOnLineAndGetNetWorkId();
206 
207     int32_t socket = Socket(info);
208     if (socket <= 0) {
209         LOGE("failed to create socket, ret=%d", socket);
210         return socket;
211     }
212     LOGI("create bind socket=%d", socket);
213 
214     QosTV qosInfo[] = {
215         {.qos = QOS_TYPE_MIN_BW,       .value = 80  },
216         { .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
217         { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
218     };
219 
220     static ISocketListener listener = {
221         .OnMessage = OnMessage,
222         .OnShutdown = OnShutdown,
223     };
224 
225     int32_t ret = Bind(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
226     if (ret != SOFTBUS_OK) {
227         LOGE("failed to bind, socket=%d, ret=%d", socket, ret);
228         Shutdown(socket);
229         return ret;
230     }
231 
232     pkgName_ = pkgName;
233     myName_ = myName;
234     peerNetworkId_ = info.peerNetworkId;
235     peerName_ = peerName;
236     socket_ = socket;
237     return SOFTBUS_OK;
238 }
239 
OnBind(int32_t socket, PeerSocketInfo info)240 void TMessenger::OnBind(int32_t socket, PeerSocketInfo info)
241 {
242     TMessenger::GetInstance().SetConnectSocket(socket, info);
243 }
244 
OnMessage(int32_t socket, const void *data, uint32_t dataLen)245 void TMessenger::OnMessage(int32_t socket, const void *data, uint32_t dataLen)
246 {
247     std::string result(static_cast<const char *>(data), dataLen);
248     TMessenger::GetInstance().OnMessageRecv(result);
249 }
250 
OnShutdown(int32_t socket, ShutdownReason reason)251 void TMessenger::OnShutdown(int32_t socket, ShutdownReason reason)
252 {
253     TMessenger::GetInstance().CloseSocket(socket);
254 }
255 
SetConnectSocket(int32_t socket, PeerSocketInfo info)256 void TMessenger::SetConnectSocket(int32_t socket, PeerSocketInfo info)
257 {
258     if (socket_ > 0) {
259         return;
260     }
261 
262     socket_ = socket;
263     peerName_ = info.name;
264     peerNetworkId_ = info.networkId;
265 }
266 
OnMessageRecv(const std::string &result)267 void TMessenger::OnMessageRecv(const std::string &result)
268 {
269     std::shared_ptr<Message> msg = Message::Decode(result);
270     if (msg == nullptr) {
271         LOGE("receive invalid message");
272         return;
273     }
274 
275     switch (msg->msgType_) {
276         case Message::MsgType::MSG_SEQ: {
277             OnRequest();
278             break;
279         }
280         case Message::MsgType::MSG_RSP: {
281             std::unique_lock<std::mutex> lock(recvMutex_);
282             msgList_.push_back(msg);
283             lock.unlock();
284             recvCond_.notify_one();
285             break;
286         }
287         default:
288             break;
289     }
290 }
291 
OnRequest()292 void TMessenger::OnRequest()
293 {
294     std::thread t([&] {
295         std::this_thread::sleep_for(std::chrono::seconds(WAIT_RESP_TIME));
296         std::shared_ptr<Response> resp = onQuery_();
297         Message msg { *resp };
298         int32_t ret = Send(msg);
299         if (ret != SOFTBUS_OK) {
300             LOGE("failed to send response");
301         }
302     });
303     t.detach();
304 }
305 
CloseSocket(int32_t socket)306 void TMessenger::CloseSocket(int32_t socket)
307 {
308     if (socket_ == socket) {
309         Shutdown(socket_);
310         socket_ = -1;
311     }
312 }
313 
QueryResult(uint32_t timeout)314 std::shared_ptr<Response> TMessenger::QueryResult(uint32_t timeout)
315 {
316     Request req { Request::Cmd::QUERY_RESULT };
317     Message msg { req };
318     int32_t ret = Send(msg);
319     if (ret != SOFTBUS_OK) {
320         LOGE("failed to query result, ret=%d", ret);
321         return nullptr;
322     }
323 
324     return WaitResponse(timeout);
325 }
326 
Send(const Message &msg)327 int32_t TMessenger::Send(const Message &msg)
328 {
329     std::string data = msg.Encode();
330     if (data.empty()) {
331         LOGE("the data is empty");
332         return SOFTBUS_MEM_ERR;
333     }
334 
335     int32_t ret = SendMessage(socket_, data.c_str(), data.size());
336     if (ret != SOFTBUS_OK) {
337         LOGE("failed to send message, socket=%d, ret=%d", socket_, ret);
338     }
339     return ret;
340 }
341 
WaitResponse(uint32_t timeout)342 std::shared_ptr<Response> TMessenger::WaitResponse(uint32_t timeout)
343 {
344     std::unique_lock<std::mutex> lock(recvMutex_);
345     std::shared_ptr<Response> rsp = nullptr;
346     if (recvCond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
347             rsp = GetMessageFromRecvList(Message::MsgType::MSG_RSP);
348             return rsp != nullptr;
349         })) {
350         return rsp;
351     }
352     LOGE("no result received");
353     return nullptr;
354 }
355 
GetMessageFromRecvList(Message::MsgType type)356 std::shared_ptr<Response> TMessenger::GetMessageFromRecvList(Message::MsgType type)
357 {
358     auto it = std::find_if(msgList_.begin(), msgList_.end(), [type] (const std::shared_ptr<Message> &it) {
359         return it->msgType_ == type;
360     });
361 
362     if (it == msgList_.end() || *it == nullptr) {
363         return nullptr;
364     }
365 
366     const Response *rsp = (*it)->response;
367     if (rsp == nullptr) {
368         msgList_.erase(it);
369         return nullptr;
370     }
371 
372     std::shared_ptr<Response> resp = std::make_shared<Response>(*rsp);
373     msgList_.erase(it);
374     return resp;
375 }
376 
RegisterOnQuery(TMessenger::OnQueryCallback callback)377 void TMessenger::RegisterOnQuery(TMessenger::OnQueryCallback callback)
378 {
379     onQuery_ = callback;
380 }
381 } // namespace OHOS
382