1/* 2 * Copyright (c) 2022 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include "common/log_wrapper.h" 17#include "define.h" 18#include "frame_builder.h" 19#include "network.h" 20#include "websocket_base.h" 21 22#include <mutex> 23 24namespace OHOS::ArkCompiler::Toolchain { 25static std::string ToString(CloseStatusCode status) 26{ 27 if (status == CloseStatusCode::NO_STATUS_CODE) { 28 return ""; 29 } 30 std::string result; 31 PushNumberPerByte(result, EnumToNumber(status)); 32 return result; 33} 34 35WebSocketBase::~WebSocketBase() noexcept 36{ 37 if (connectionFd_ != -1) { 38 LOGW("WebSocket connection is closed while destructing the object"); 39 close(connectionFd_); 40 // Reset directly in order to prevent static analyzer warnings. 41 connectionFd_ = -1; 42 } 43} 44 45// if the data is too large, it will be split into multiple frames, the first frame will be marked as 0x0 46// and the last frame will be marked as 0x1. 47// we just add the 'isLast' parameter to indicate whether it is the last frame. 48bool WebSocketBase::SendReply(const std::string& message, FrameType frameType, bool isLast) const 49{ 50 if (connectionState_.load() != ConnectionState::OPEN) { 51 LOGE("SendReply failed, websocket not connected"); 52 return false; 53 } 54 55 const auto frame = CreateFrame(isLast, frameType, message); 56 if (!SendUnderLock(frame)) { 57 LOGE("SendReply: send failed"); 58 return false; 59 } 60 return true; 61} 62 63/** 64 * The wired format of this data transmission section is described in detail through ABNFRFC5234. 65 * When receive the message, we should decode it according the spec. The structure is as follows: 66 * 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 67 * +-+-+-+-+-------+-+-------------+-------------------------------+ 68 * |F|R|R|R| opcode|M| Payload len | Extended payload length | 69 * |I|S|S|S| (4) |A| (7) | (16/64) | 70 * |N|V|V|V| |S| | (if payload len==126/127) | 71 * | |1|2|3| |K| | | 72 * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + 73 * | Extended payload length continued, if payload len == 127 | 74 * + - - - - - - - - - - - - - - - +-------------------------------+ 75 * | |Masking-key, if MASK set to 1 | 76 * +-------------------------------+-------------------------------+ 77 * | Masking-key (continued) | Payload Data | 78 * +-------------------------------- - - - - - - - - - - - - - - - + 79 * : Payload Data continued ... : 80 * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + 81 * | Payload Data continued ... | 82 * +---------------------------------------------------------------+ 83 */ 84 85bool WebSocketBase::ReadPayload(WebSocketFrame& wsFrame) const 86{ 87 if (wsFrame.payloadLen == WebSocketFrame::TWO_BYTES_LENTH_ENC) { 88 uint8_t recvbuf[WebSocketFrame::TWO_BYTES_LENTH] = {0}; 89 if (!RecvUnderLock(recvbuf, WebSocketFrame::TWO_BYTES_LENTH)) { 90 LOGE("ReadPayload: Recv payloadLen == 126 failed"); 91 return false; 92 } 93 wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::TWO_BYTES_LENTH); 94 } else if (wsFrame.payloadLen == WebSocketFrame::EIGHT_BYTES_LENTH_ENC) { 95 uint8_t recvbuf[WebSocketFrame::EIGHT_BYTES_LENTH] = {0}; 96 if (!RecvUnderLock(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH)) { 97 LOGE("ReadPayload: Recv payloadLen == 127 failed"); 98 return false; 99 } 100 wsFrame.payloadLen = NetToHostLongLong(recvbuf, WebSocketFrame::EIGHT_BYTES_LENTH); 101 } 102 return DecodeMessage(wsFrame); 103} 104 105bool WebSocketBase::HandleDataFrame(WebSocketFrame& wsFrame) const 106{ 107 if (wsFrame.opcode == EnumToNumber(FrameType::TEXT)) { 108 return ReadPayload(wsFrame); 109 } else { 110 LOGW("Received unsupported data frame, opcode = %{public}d", wsFrame.opcode); 111 } 112 return true; 113} 114 115bool WebSocketBase::HandleControlFrame(WebSocketFrame& wsFrame) 116{ 117 if (wsFrame.opcode == EnumToNumber(FrameType::PING)) { 118 // A Pong frame sent in response to a Ping frame must have identical 119 // "Application data" as found in the message body of the Ping frame 120 // being replied to. 121 // https://www.rfc-editor.org/rfc/rfc6455#section-5.5.3 122 if (!ReadPayload(wsFrame)) { 123 LOGE("Failed to read ping frame payload"); 124 return false; 125 } 126 SendPongFrame(wsFrame.payload); 127 } else if (wsFrame.opcode == EnumToNumber(FrameType::CLOSE)) { 128 // might read payload to response by echoing the status code 129 CloseConnection(CloseStatusCode::NO_STATUS_CODE); 130 } 131 return true; 132} 133 134std::string WebSocketBase::Decode() 135{ 136 if (auto state = connectionState_.load(); state != ConnectionState::OPEN) { 137 LOGE("Decode failed: websocket not connected, state = %{public}d", EnumToNumber(state)); 138 return ""; 139 } 140 141 uint8_t recvbuf[WebSocketFrame::HEADER_LEN] = {0}; 142 if (!RecvUnderLock(recvbuf, WebSocketFrame::HEADER_LEN)) { 143 LOGE("Decode failed, client websocket disconnect"); 144 CloseConnection(CloseStatusCode::UNEXPECTED_ERROR); 145 return std::string(DECODE_DISCONNECT_MSG); 146 } 147 WebSocketFrame wsFrame(recvbuf); 148 if (!ValidateIncomingFrame(wsFrame)) { 149 LOGE("Received websocket frame is invalid - header is %02x%02x", recvbuf[0], recvbuf[1]); 150 CloseConnection(CloseStatusCode::PROTOCOL_ERROR); 151 return std::string(DECODE_DISCONNECT_MSG); 152 } 153 154 if (IsControlFrame(wsFrame.opcode)) { 155 if (HandleControlFrame(wsFrame)) { 156 return wsFrame.payload; 157 } 158 } else if (HandleDataFrame(wsFrame)) { 159 return wsFrame.payload; 160 } 161 // Unexpected data, must close the connection. 162 CloseConnection(CloseStatusCode::PROTOCOL_ERROR); 163 return std::string(DECODE_DISCONNECT_MSG); 164} 165 166bool WebSocketBase::IsConnected() const 167{ 168 return connectionState_.load() == ConnectionState::OPEN; 169} 170 171void WebSocketBase::SetCloseConnectionCallback(CloseConnectionCallback cb) 172{ 173 closeCb_ = std::move(cb); 174} 175 176void WebSocketBase::SetFailConnectionCallback(FailConnectionCallback cb) 177{ 178 failCb_ = std::move(cb); 179} 180 181void WebSocketBase::OnConnectionClose(ConnectionCloseReason status) 182{ 183 if (status == ConnectionCloseReason::FAIL) { 184 if (failCb_) { 185 failCb_(); 186 } 187 } else if (status == ConnectionCloseReason::CLOSE) { 188 if (closeCb_) { 189 closeCb_(); 190 } 191 } 192} 193 194void WebSocketBase::CloseConnectionSocket(ConnectionCloseReason status) 195{ 196 OnConnectionClose(status); 197 198 { 199 // Shared lock due to other thread possibly hanging on `recv` with acquired shared lock. 200 std::shared_lock lock(connectionMutex_); 201 int err = ShutdownSocket(connectionFd_); 202 if (err != 0) { 203 LOGW("Failed to shutdown client socket, errno = %{public}d", errno); 204 } 205 } 206 { 207 // Unique lock due to close and write into `connectionFd_`. 208 // Note that `close` must be also done in critical section, 209 // otherwise the other thread can continue using the outdated and possibly reassigned file descriptor. 210 std::unique_lock lock(connectionMutex_); 211 close(connectionFd_); 212 // Reset directly in order to prevent static analyzer warnings. 213 connectionFd_ = -1; 214 } 215 216 auto expected = ConnectionState::CLOSING; 217 if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSED)) { 218 LOGE("In connection transition CLOSING->CLOSED got initial state = %{public}d", EnumToNumber(expected)); 219 } 220} 221 222void WebSocketBase::SendPongFrame(std::string payload) const 223{ 224 const auto frame = CreateFrame(true, FrameType::PONG, std::move(payload)); 225 if (!SendUnderLock(frame)) { 226 LOGE("Decode: Send pong frame failed"); 227 } 228} 229 230void WebSocketBase::SendCloseFrame(CloseStatusCode status) const 231{ 232 const auto frame = CreateFrame(true, FrameType::CLOSE, ToString(status)); 233 if (!SendUnderLock(frame)) { 234 LOGE("SendCloseFrame: Send close frame failed"); 235 } 236} 237 238bool WebSocketBase::CloseConnection(CloseStatusCode status) 239{ 240 auto expected = ConnectionState::OPEN; 241 if (!connectionState_.compare_exchange_strong(expected, ConnectionState::CLOSING)) { 242 // Concurrent connection close detected, do nothing. 243 return false; 244 } 245 246 LOGI("Close connection, status = %{public}d", static_cast<int>(status)); 247 SendCloseFrame(status); 248 // can close connection right after sending back close frame. 249 CloseConnectionSocket(ConnectionCloseReason::CLOSE); 250 return true; 251} 252 253int WebSocketBase::GetConnectionSocket() const 254{ 255 return connectionFd_; 256} 257 258void WebSocketBase::SetConnectionSocket(int socketFd) 259{ 260 connectionFd_ = socketFd; 261} 262 263std::shared_mutex &WebSocketBase::GetConnectionMutex() 264{ 265 return connectionMutex_; 266} 267 268WebSocketBase::ConnectionState WebSocketBase::GetConnectionState() const 269{ 270 return connectionState_.load(); 271} 272 273WebSocketBase::ConnectionState WebSocketBase::SetConnectionState(ConnectionState newState) 274{ 275 return connectionState_.exchange(newState); 276} 277 278bool WebSocketBase::CompareExchangeConnectionState(ConnectionState& expected, ConnectionState newState) 279{ 280 return connectionState_.compare_exchange_strong(expected, newState); 281} 282 283bool WebSocketBase::SendUnderLock(const std::string& message) const 284{ 285 std::shared_lock lock(connectionMutex_); 286 return Send(connectionFd_, message, 0); 287} 288 289bool WebSocketBase::SendUnderLock(const char* buf, size_t totalLen) const 290{ 291 std::shared_lock lock(connectionMutex_); 292 return Send(connectionFd_, buf, totalLen, 0); 293} 294 295bool WebSocketBase::RecvUnderLock(std::string& message) const 296{ 297 std::shared_lock lock(connectionMutex_); 298 return Recv(connectionFd_, message, 0); 299} 300 301bool WebSocketBase::RecvUnderLock(uint8_t* buf, size_t totalLen) const 302{ 303 std::shared_lock lock(connectionMutex_); 304 return Recv(connectionFd_, buf, totalLen, 0); 305} 306 307/* static */ 308bool WebSocketBase::IsDecodeDisconnectMsg(const std::string& message) 309{ 310 return message == DECODE_DISCONNECT_MSG; 311} 312 313#if !defined(OHOS_PLATFORM) 314/* static */ 315bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) 316{ 317 if (timeoutLimit > 0) { 318 struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0}; 319 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, 320 reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { 321 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno); 322 return false; 323 } 324 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, 325 reinterpret_cast<char *>(&timeout), sizeof(timeout)) != SOCKET_SUCCESS) { 326 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno); 327 return false; 328 } 329 } 330 return true; 331} 332#else 333/* static */ 334bool WebSocketBase::SetWebSocketTimeOut(int32_t fd, uint32_t timeoutLimit) 335{ 336 if (timeoutLimit > 0) { 337 struct timeval timeout = {static_cast<time_t>(timeoutLimit), 0}; 338 if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) { 339 LOGE("SetWebSocketTimeOut setsockopt SO_SNDTIMEO failed, errno = %{public}d", errno); 340 return false; 341 } 342 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != SOCKET_SUCCESS) { 343 LOGE("SetWebSocketTimeOut setsockopt SO_RCVTIMEO failed, errno = %{public}d", errno); 344 return false; 345 } 346 } 347 return true; 348} 349#endif 350 351#if defined(WINDOWS_PLATFORM) 352/* static */ 353int WebSocketBase::ShutdownSocket(int32_t fd) 354{ 355 return shutdown(fd, SD_BOTH); 356} 357#else 358/* static */ 359int WebSocketBase::ShutdownSocket(int32_t fd) 360{ 361 return shutdown(fd, SHUT_RDWR); 362} 363#endif 364} // namespace OHOS::ArkCompiler::Toolchain 365