1/*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "common/log_wrapper.h"
17#include "define.h"
18#include "frame_builder.h"
19#include "network.h"
20#include "websocket_base.h"
21
22#include <mutex>
23
24namespace OHOS::ArkCompiler::Toolchain {
25static std::string ToString(CloseStatusCode status)
26{
27    if (status == CloseStatusCode::NO_STATUS_CODE) {
28        return "";
29    }
30    std::string result;
31    PushNumberPerByte(result, EnumToNumber(status));
32    return result;
33}
34
35WebSocketBase::~WebSocketBase() noexcept
36{
37    if (connectionFd_ != -1) {
38        LOGW("WebSocket connection is closed while destructing the object");
39        close(connectionFd_);
40        // Reset directly in order to prevent static analyzer warnings.
41        connectionFd_ = -1;
42    }
43}
44
45// if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0
46// and the last frame will be marked as 0x1.
47// we just add the 'isLast' parameter to indicate whether it is the last frame.
48bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const
49{
50    if (connectionState_.load() != ConnectionState::OPEN) {
51        LOGE("SendReply failed, websocket not connected");
52        return false;
53    }
54
55    const auto frame = CreateFrame(isLast, frameType, message);
56    if (!SendUnderLock(frame)) {
57        LOGE("SendReply: send failed");
58        return false;
59    }
60    return true;
61}
62
63/**
  *  The wire format of the data-transfer part (the base framing protocol) is described in detail
  *  by the ABNF (RFC 5234) given in RFC 6455, section 5.2.
  *  A received message must be decoded according to that specification. The frame structure is as follows:
66  *     0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
67  *    +-+-+-+-+-------+-+-------------+-------------------------------+
68  *    |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
69  *    |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
70  *    |N|V|V|V|       |S|             |   (if payload len==126/127)   |
71  *    | |1|2|3|       |K|             |                               |
72  *    +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
73  *    |     Extended payload length continued, if payload len == 127  |
74  *    + - - - - - - - - - - - - - - - +-------------------------------+
75  *    |                               |Masking-key, if MASK set to 1  |
76  *    +-------------------------------+-------------------------------+
77  *    | Masking-key (continued)       |          Payload Data         |
78  *    +-------------------------------- - - - - - - - - - - - - - - - +
79  *    :                     Payload Data continued ...                :
80  *    + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
81  *    |                     Payload Data continued ...                |
82  *    +---------------------------------------------------------------+
83  */
84
85bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame) const
86{
87    if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) {
88        uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0};
89        if (!RecvUnderLock(recvbuf, WebSocketFrame::TWO_BYTES_LENTH)) {
90            LOGE("ReadPayload: Recv payloadLen == 126 failed");
91            return false;
92        }
93        wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH);
94    } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) {
95        uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0};
96        if (!RecvUnderLock(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH)) {
97            LOGE("ReadPayload: Recv payloadLen == 127 failed");
98            return false;
99        }
100        wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH);
101    }
102    return DecodeMessage(wsFrame);
103}
104
105bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame) const
106{
107    if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) {
108        return ReadPayload(wsFrame);
109    } else {
110        LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode);
111    }
112    return true;
113}
114
115bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame)
116{
117    if (wsFrame.opcode == EnumToNumber(FrameType::PING)) {
118        // A Pong frame sent in response to a Ping frame must have identical
119        // "Application data" as found in the message body of the Ping frame
120        // being replied to.
121        // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3
122        if (!ReadPayload(wsFrame)) {
123            LOGE("Failed to read ping frame payload");
124            return false;
125        }
126        SendPongFrame(wsFrame.payload);
127    } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) {
128        // might read payload to response by echoing the status code
129        CloseConnection(CloseStatusCode::NO_STATUS_CODE);
130    }
131    return true;
132}
133
134std::string WebSocketBase::Decode()
135{
136    if (auto state = connectionState_.load(); state != ConnectionState::OPEN) {
137        LOGE("Decode failed: websocket not connected, state = %{public}d", EnumToNumber(state));
138        return "";
139    }
140
141    uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0};
142    if (!RecvUnderLock(recvbuf, WebSocketFrame::HEADER_LEN)) {
143        LOGE("Decode failed, client websocket disconnect");
144        CloseConnection(CloseStatusCode::UNEXPECTED_ERROR);
145        return std::string(DECODE_DISCONNECT_MSG);
146    }
147    WebSocketFrame wsFrame(recvbuf);
148    if (!ValidateIncomingFrame(wsFrame)) {
149        LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]);
150        CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
151        return std::string(DECODE_DISCONNECT_MSG);
152    }
153
154    if (IsControlFrame(wsFrame.opcode)) {
155        if (HandleControlFrame(wsFrame)) {
156            return wsFrame.payload;
157        }
158    } else if (HandleDataFrame(wsFrame)) {
159        return wsFrame.payload;
160    }
161    // Unexpected data, must close the connection.
162    CloseConnection(CloseStatusCode::PROTOCOL_ERROR);
163    return std::string(DECODE_DISCONNECT_MSG);
164}
165
166bool WebSocketBase::IsConnected() const
167{
168    return connectionState_.load() == ConnectionState::OPEN;
169}
170
171void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb)
172{
173    closeCb_ = std::move(cb);
174}
175
176void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb)
177{
178    failCb_ = std::move(cb);
179}
180
181void WebSocketBase::OnConnectionClose(ConnectionCloseReason status)
182{
183    if (status == ConnectionCloseReason::FAIL) {
184        if (failCb_) {
185            failCb_();
186        }
187    } else if (status == ConnectionCloseReason::CLOSE) {
188        if (closeCb_) {
189            closeCb_();
190        }
191    }
192}
193
/**
 * Notifies the user, tears down the connection socket, and moves the state CLOSING -> CLOSED.
 *
 * Steps, in order:
 *  1. OnConnectionClose(status) runs the callback matching `status`.
 *  2. The socket is shut down while holding a shared lock on connectionMutex_.
 *  3. The descriptor is closed and reset to -1 while holding a unique lock.
 *  4. The state is switched from CLOSING to CLOSED. If the state was not CLOSING, an error is
 *     logged and the state is left unchanged (a failed compare-exchange does not store).
 *
 * @param status reason passed to OnConnectionClose.
 */
void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status)
{
    OnConnectionClose(status);

    {
        // Shared lock due to other thread possibly hanging on `recv` with acquired shared lock.
        // NOTE(review): presumably the shutdown is what unblocks that `recv`, so the other thread
        // releases its shared lock before the unique lock below is requested — confirm per platform.
        std::shared_lock lock(connectionMutex_);
        int err = ShutdownSocket(connectionFd_);
        if (err != 0) {
            LOGW("Failed to shutdown client socket, errno = %{public}d", errno);
        }
    }
    {
        // Unique lock due to close and write into `connectionFd_`.
        // Note that `close` must be also done in critical section,
        // otherwise the other thread can continue using the outdated and possibly reassigned file descriptor.
        std::unique_lock lock(connectionMutex_);
        close(connectionFd_);
        // Reset directly in order to prevent static analyzer warnings.
        connectionFd_ = -1;
    }

    auto expected = ConnectionState::CLOSING;
    if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSED)) {
        LOGE("In connection transition CLOSING->CLOSED got initial state = %{public}d", EnumToNumber(expected));
    }
}
221
222void WebSocketBase::SendPongFrame(std::string payload) const
223{
224    const auto frame = CreateFrame(true, FrameType::PONG, std::move(payload));
225    if (!SendUnderLock(frame)) {
226        LOGE("Decode: Send pong frame failed");
227    }
228}
229
230void WebSocketBase::SendCloseFrame(CloseStatusCode status) const
231{
232    const auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status));
233    if (!SendUnderLock(frame)) {
234        LOGE("SendCloseFrame: Send close frame failed");
235    }
236}
237
238bool WebSocketBase::CloseConnection(CloseStatusCode status)
239{
240    auto expected = ConnectionState::OPEN;
241    if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSING)) {
242        // Concurrent connection close detected, do nothing.
243        return false;
244    }
245
246    LOGI("Close connection, status = %{public}d", static_cast<int>(status));
247    SendCloseFrame(status);
248    // can close connection right after sending back close frame.
249    CloseConnectionSocket(ConnectionCloseReason::CLOSE);
250    return true;
251}
252
253int WebSocketBase::GetConnectionSocket() const
254{
255    return connectionFd_;
256}
257
258void WebSocketBase::SetConnectionSocket(int socketFd)
259{
260    connectionFd_ = socketFd;
261}
262
263std::shared_mutex &WebSocketBase::GetConnectionMutex()
264{
265    return connectionMutex_;
266}
267
268WebSocketBase::ConnectionState WebSocketBase::GetConnectionState() const
269{
270    return connectionState_.load();
271}
272
273WebSocketBase::ConnectionState WebSocketBase::SetConnectionState(ConnectionState newState)
274{
275    return connectionState_.exchange(newState);
276}
277
278bool WebSocketBase::CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState)
279{
280    return connectionState_.compare_exchange_strong(expected, newState);
281}
282
283bool WebSocketBase::SendUnderLock(const std::string& message) const
284{
285    std::shared_lock lock(connectionMutex_);
286    return Send(connectionFd_, message, 0);
287}
288
289bool WebSocketBase::SendUnderLock(const char* buf, size_t totalLen) const
290{
291    std::shared_lock lock(connectionMutex_);
292    return Send(connectionFd_, buf, totalLen, 0);
293}
294
295bool WebSocketBase::RecvUnderLock(std::string& message) const
296{
297    std::shared_lock lock(connectionMutex_);
298    return Recv(connectionFd_, message, 0);
299}
300
301bool WebSocketBase::RecvUnderLock(uint8_t* buf, size_t totalLen) const
302{
303    std::shared_lock lock(connectionMutex_);
304    return Recv(connectionFd_, buf, totalLen, 0);
305}
306
307/* static */
308bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message)
309{
310    return message == DECODE_DISCONNECT_MSG;
311}
312
313#if !defined(OHOS_PLATFORM)
314/* static */
315bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
316{
317    if (timeoutLimit > 0) {
318        struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
319        if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO,
320            reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
321            LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
322            return false;
323        }
324        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO,
325            reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) {
326            LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
327            return false;
328        }
329    }
330    return true;
331}
332#else
333/* static */
334bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit)
335{
336    if (timeoutLimit > 0) {
337        struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0};
338        if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
339            LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno);
340            return false;
341        }
342        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) {
343            LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno);
344            return false;
345        }
346    }
347    return true;
348}
349#endif
350
351#if defined(WINDOWS_PLATFORM)
352/* static */
353int WebSocketBase::ShutdownSocket(int32_t fd)
354{
355    return shutdown(fd, SD_BOTH);
356}
357#else
358/* static */
359int WebSocketBase::ShutdownSocket(int32_t fd)
360{
361    return shutdown(fd, SHUT_RDWR);
362}
363#endif
364} // namespace OHOS::ArkCompiler::Toolchain
365