1/* 2 * Copyright (c) 2022 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#ifndef ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H 17#define ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H 18 19#include "web_socket_frame.h" 20 21#include <atomic> 22#include <functional> 23#include <shared_mutex> 24#include <type_traits> 25 26namespace OHOS::ArkCompiler::Toolchain { 27enum CloseStatusCode : uint16_t { 28 NO_STATUS_CODE = 0, 29 NORMAL = 1000, 30 SERVER_GO_AWAY = 1001, 31 PROTOCOL_ERROR = 1002, 32 UNACCEPTABLE_DATA = 1003, 33 INCONSISTENT_DATA = 1007, 34 POLICY_VIOLATION = 1008, 35 MESSAGE_TOO_BIG = 1009, 36 UNEXPECTED_ERROR = 1011, 37}; 38 39class WebSocketBase { 40public: 41 using CloseConnectionCallback = std::function<void()>; 42 using FailConnectionCallback = std::function<void()>; 43 44public: 45 static bool IsDecodeDisconnectMsg(const std::string& message); 46 47 WebSocketBase() = default; 48 virtual ~WebSocketBase() noexcept; 49 50 /** 51 * @brief Receive and decode a message. 52 * Must not be called concurrently on the same connection. 53 * Safe to call concurrently with `SendReply` and `Close`. 54 * Control frames are handled according to specification with an empty string as returned value, 55 * otherwise the method returns the decoded received message. 56 * Note that this method closes the connection after receiving invalid data. 57 * This event can be checked with `IsDecodeDisconnectMsg`. 58 */ 59 std::string Decode(); 60 61 /** 62 * @brief Send message on current connection. 63 * Safe to call concurrently with: `SendReply`, `Decode`, `Close`. 64 * Note that the connection is not closed on transmission failures. 65 * @param message text payload. 66 * @param frameType frame type, must be either TEXT, BINARY or CONTINUATION. 67 * @param isLast flag indicating whether the message is the final. 68 * @returns true on success, false otherwise. 69 */ 70 bool SendReply(const std::string& message, FrameType frameType = FrameType::TEXT, bool isLast = true) const; 71 72 /** 73 * @brief Check if connection is in `OPEN` state. 74 */ 75 bool IsConnected() const; 76 77 /** 78 * @brief Set callback for calling after normal connection close. 79 * Non thread safe. 80 */ 81 void SetCloseConnectionCallback(CloseConnectionCallback cb); 82 83 /** 84 * @brief Set callback for calling after closing connection on any failure. 85 * Non thread safe. 86 */ 87 void SetFailConnectionCallback(FailConnectionCallback cb); 88 89 /** 90 * @brief Send `CLOSE` frame and close the connection socket. 91 * Does nothing if connection was not in `OPEN` state. 92 * @param status close status code specified in sent frame. 93 * @returns true if connection was closed, false otherwise. 94 */ 95 bool CloseConnection(CloseStatusCode status); 96 97protected: 98 enum class ConnectionState : uint8_t { 99 CONNECTING, 100 OPEN, 101 CLOSING, 102 CLOSED, 103 }; 104 105 enum class ConnectionCloseReason: uint8_t { 106 FAIL, 107 CLOSE, 108 }; 109 110protected: 111 /** 112 * @brief Set `send` and `recv` timeout limits. 113 * @param fd socket to set timeout on. 114 * @param timeoutLimit timeout in seconds. If zero, function is no-op. 115 * @returns true on success, false otherwise. 116 */ 117 static bool SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit); 118 119 /** 120 * @brief Shutdown socket for sends and receives. 121 * Note that the implementation of this function is platform-specific, 122 * so there is no unified way to retrieve error code returned from system call. 123 * @param fd socket file descriptor. 124 * @returns zero on success, `-1` otherwise. 125 */ 126 static int ShutdownSocket(int32_t fd); 127 128 /** 129 * @brief Close the connection socket. 130 * Must be transition from `CLOSING` to `CLOSED` connection state. 131 * @param status close reason, depends which callback to execute. 132 */ 133 void CloseConnectionSocket(ConnectionCloseReason status); 134 135 /** 136 * @brief Execute user-provided callbacks before closing the connection socket. 137 */ 138 void OnConnectionClose(ConnectionCloseReason status); 139 140 int GetConnectionSocket() const; 141 void SetConnectionSocket(int socketFd); 142 std::shared_mutex &GetConnectionMutex(); 143 144 ConnectionState GetConnectionState() const; 145 ConnectionState SetConnectionState(ConnectionState newState); 146 bool CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState); 147 148 bool HandleDataFrame(WebSocketFrame& wsFrame) const; 149 bool HandleControlFrame(WebSocketFrame& wsFrame); 150 bool ReadPayload(WebSocketFrame& wsFrame) const; 151 void SendPongFrame(std::string payload) const; 152 void SendCloseFrame(CloseStatusCode status) const; 153 154 bool SendUnderLock(const std::string& message) const; 155 bool SendUnderLock(const char* buf, size_t totalLen) const; 156 bool RecvUnderLock(std::string& message) const; 157 bool RecvUnderLock(uint8_t* buf, size_t totalLen) const; 158 159 virtual bool ValidateIncomingFrame(const WebSocketFrame& wsFrame) const = 0; 160 virtual std::string CreateFrame(bool isLast, FrameType frameType) const = 0; 161 virtual std::string CreateFrame(bool isLast, FrameType frameType, const std::string& payload) const = 0; 162 virtual std::string CreateFrame(bool isLast, FrameType frameType, std::string&& payload) const = 0; 163 virtual bool DecodeMessage(WebSocketFrame& wsFrame) const = 0; 164 165protected: 166 static constexpr size_t HTTP_HANDSHAKE_MAX_LEN = 1024; 167 static constexpr int SOCKET_SUCCESS = 0; 168 169private: 170 std::atomic<ConnectionState> connectionState_ {ConnectionState::CLOSED}; 171 172 mutable std::shared_mutex connectionMutex_; 173 int connectionFd_ {-1}; 174 175 // Callbacks used during different stages of connection lifecycle. 176 CloseConnectionCallback closeCb_; 177 FailConnectionCallback failCb_; 178 179 static constexpr std::string_view DECODE_DISCONNECT_MSG = "disconnect"; 180}; 181} // namespace OHOS::ArkCompiler::Toolchain 182 183#endif // ARKCOMPILER_TOOLCHAIN_WEBSOCKET_WEBSOCKET_BASE_H 184