1/*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H
17#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H
18
19#include "web_socket_frame.h"
20
21#include <atomic>
22#include <functional>
23#include <shared_mutex>
24#include <type_traits>
25
26namespace OHOS::ArkCompiler::Toolchain {
/**
 * @brief Status codes carried by a WebSocket `CLOSE` frame.
 * The numeric values in the 1000-1011 range coincide with the close codes
 * defined by RFC 6455, section 7.4.1.
 * Kept as an unscoped enum with a 16-bit underlying type, so enumerators are
 * visible directly in the enclosing namespace.
 */
enum CloseStatusCode : uint16_t {
    // Not an RFC 6455 code. NOTE(review): presumably a sentinel meaning
    // "no status code attached to the frame" - confirm against the sender.
    NO_STATUS_CODE = 0,

    // 1000: normal closure, the purpose of the connection has been fulfilled.
    NORMAL = 1000,

    // 1001: the endpoint is going away (e.g. the server is shutting down).
    SERVER_GO_AWAY = 1001,

    // 1002: the connection is terminated because of a protocol error.
    PROTOCOL_ERROR = 1002,

    // 1003: the endpoint received a type of data it cannot accept.
    UNACCEPTABLE_DATA = 1003,

    // 1007: message data is not consistent with the message type.
    INCONSISTENT_DATA = 1007,

    // 1008: the received message violates the endpoint's policy.
    POLICY_VIOLATION = 1008,

    // 1009: the received message is too big to be processed.
    MESSAGE_TOO_BIG = 1009,

    // 1011: an unexpected condition prevented fulfilling the request.
    UNEXPECTED_ERROR = 1011,
};
38
39class WebSocketBase {
40public:
41    using CloseConnectionCallback = std::function<void()>;
42    using FailConnectionCallback = std::function<void()>;
43
44public:
45    static bool IsDecodeDisconnectMsg(const std::string& message);
46
47    WebSocketBase() = default;
48    virtual ~WebSocketBase() noexcept;
49
50    /**
51     * @brief Receive and decode a message.
52     * Must not be called concurrently on the same connection.
53     * Safe to call concurrently with `SendReply` and `Close`.
54     * Control frames are handled according to specification with an empty string as returned value,
55     * otherwise the method returns the decoded received message.
56     * Note that this method closes the connection after receiving invalid data.
57     * This event can be checked with `IsDecodeDisconnectMsg`.
58     */
59    std::string Decode();
60
61    /**
62     * @brief Send message on current connection.
63     * Safe to call concurrently with: `SendReply`, `Decode`, `Close`.
64     * Note that the connection is not closed on transmission failures.
65     * @param message text payload.
66     * @param frameType frame type, must be either TEXT, BINARY or CONTINUATION.
67     * @param isLast flag indicating whether the message is the final.
68     * @returns true on success, false otherwise.
69     */
70    bool SendReply(const std::string& message, FrameType frameType = FrameType::TEXT, bool isLast = true) const;
71
72    /**
73     * @brief Check if connection is in `OPEN` state.
74     */
75    bool IsConnected() const;
76
77    /**
78     * @brief Set callback for calling after normal connection close.
79     * Non thread safe.
80     */
81    void SetCloseConnectionCallback(CloseConnectionCallback cb);
82
83    /**
84     * @brief Set callback for calling after closing connection on any failure.
85     * Non thread safe.
86     */
87    void SetFailConnectionCallback(FailConnectionCallback cb);
88
89    /**
90     * @brief Send `CLOSE` frame and close the connection socket.
91     * Does nothing if connection was not in `OPEN` state.
92     * @param status close status code specified in sent frame.
93     * @returns true if connection was closed, false otherwise.
94     */
95    bool CloseConnection(CloseStatusCode status);
96
97protected:
98    enum class ConnectionState : uint8_t {
99        CONNECTING,
100        OPEN,
101        CLOSING,
102        CLOSED,
103    };
104
105    enum class ConnectionCloseReason: uint8_t {
106        FAIL,
107        CLOSE,
108    };
109
110protected:
111    /**
112     * @brief Set `send` and `recv` timeout limits.
113     * @param fd socket to set timeout on.
114     * @param timeoutLimit timeout in seconds. If zero, function is no-op.
115     * @returns true on success, false otherwise.
116     */
117    static bool SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit);
118
119    /**
120     * @brief Shutdown socket for sends and receives.
121     * Note that the implementation of this function is platform-specific,
122     * so there is no unified way to retrieve error code returned from system call.
123     * @param fd socket file descriptor.
124     * @returns zero on success, `-1` otherwise.
125     */
126    static int ShutdownSocket(int32_t fd);
127
128    /**
129     * @brief Close the connection socket.
130     * Must be transition from `CLOSING` to `CLOSED` connection state.
131     * @param status close reason, depends which callback to execute.
132     */
133    void CloseConnectionSocket(ConnectionCloseReason status);
134
135    /**
136     * @brief Execute user-provided callbacks before closing the connection socket.
137     */
138    void OnConnectionClose(ConnectionCloseReason status);
139
140    int GetConnectionSocket() const;
141    void SetConnectionSocket(int socketFd);
142    std::shared_mutex &GetConnectionMutex();
143
144    ConnectionState GetConnectionState() const;
145    ConnectionState SetConnectionState(ConnectionState newState);
146    bool CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState);
147
148    bool HandleDataFrame(WebSocketFrame& wsFrame) const;
149    bool HandleControlFrame(WebSocketFrame& wsFrame);
150    bool ReadPayload(WebSocketFrame& wsFrame) const;
151    void SendPongFrame(std::string payload) const;
152    void SendCloseFrame(CloseStatusCode status) const;
153
154    bool SendUnderLock(const std::string& message) const;
155    bool SendUnderLock(const char* buf, size_t totalLen) const;
156    bool RecvUnderLock(std::string& message) const;
157    bool RecvUnderLock(uint8_t* buf, size_t totalLen) const;
158
159    virtual bool ValidateIncomingFrame(const WebSocketFrame& wsFrame) const = 0;
160    virtual std::string CreateFrame(bool isLast, FrameType frameType) const = 0;
161    virtual std::string CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const = 0;
162    virtual std::string CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const = 0;
163    virtual bool DecodeMessage(WebSocketFrame& wsFrame) const = 0;
164
165protected:
166    static constexpr size_t HTTP_HANDSHAKE_MAX_LEN = 1024;
167    static constexpr int SOCKET_SUCCESS = 0;
168
169private:
170    std::atomic<ConnectionState> connectionState_ {ConnectionState::CLOSED};
171
172    mutable std::shared_mutex connectionMutex_;
173    int connectionFd_ {-1};
174
175    // Callbacks used during different stages of connection lifecycle.
176    CloseConnectionCallback closeCb_;
177    FailConnectionCallback failCb_;
178
179    static constexpr std::string_view DECODE_DISCONNECT_MSG = "disconnect";
180};
181} // namespace OHOS::ArkCompiler::Toolchain
182
183#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H
184