1 /* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 18 19 #include "event_manager.h" 20 #include "extra_options_base.h" 21 #include "net_address.h" 22 #include "socket_error.h" 23 #include "socket_remote_info.h" 24 #include "socket_state_base.h" 25 #include "tcp_connect_options.h" 26 #include "tcp_extra_options.h" 27 #include "tcp_send_options.h" 28 #include "tls.h" 29 #include "tls_certificate.h" 30 #include "tls_configuration.h" 31 #include "tls_context_server.h" 32 #include "tls_key.h" 33 #include "tls_socket.h" 34 #include <any> 35 #include <condition_variable> 36 #include <cstring> 37 #include <functional> 38 #include <map> 39 #include <poll.h> 40 #include <thread> 41 #include <tuple> 42 #include <unistd.h> 43 #include <vector> 44 45 namespace OHOS { 46 namespace NetStack { 47 namespace TlsSocketServer { 48 constexpr int USER_LIMIT = 10; 49 using OnMessageCallback = 50 std::function<void(const int &socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 51 using OnCloseCallback = std::function<void(const int &socketFd)>; 52 using OnConnectCallback = std::function<void(const int &socketFd, std::shared_ptr<EventManager> eventManager)>; 53 using ListenCallback = std::function<void(int32_t errorNumber)>; 54 class TLSServerSendOptions { 55 public: 56 /** 57 * Set the socket ID to be transmitted 58 * @param socketFd Communication descriptor 59 */ 60 void SetSocket(const int &socketFd); 61 62 /** 63 * Set the data to send 64 * @param data Send data 65 */ 66 void SetSendData(const std::string &data); 67 68 /** 69 * Get the socket ID 70 * @return Gets the communication descriptor 71 */ 72 [[nodiscard]] const int &GetSocket() const; 73 74 /** 75 * Gets the data sent 76 * @return Send data 77 */ 78 [[nodiscard]] const std::string &GetSendData() const; 79 80 private: 81 int socketFd_; 82 std::string data_; 83 }; 84 85 class TLSSocketServer { 86 public: 87 TLSSocketServer(const TLSSocketServer &) = delete; 88 TLSSocketServer(TLSSocketServer &&) = delete; 89 90 TLSSocketServer &operator=(const TLSSocketServer &) = delete; 91 TLSSocketServer &operator=(TLSSocketServer &&) = delete; 92 93 TLSSocketServer() = default; 94 ~TLSSocketServer(); 95 96 /** 97 * Create sockets, bind and listen waiting for clients to connect 98 * @param tlsListenOptions Bind the listening connection configuration 99 * @param callback callback to the caller if bind ok or not 100 */ 101 void Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback); 102 103 /** 104 * Send data through an established encrypted connection 105 * @param data data sent over an established encrypted connection 106 * @return whether the data is successfully sent to the server 107 */ 108 bool Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback); 109 110 /** 111 * Disconnect by releasing the socket when communicating 112 * @param socketFd The socket ID of the client 113 * @param callback callback to the caller 114 */ 115 void Close(const int socketFd, const TlsSocket::CloseCallback &callback); 116 117 /** 118 * Disconnect by releasing the socket when communicating 119 * @param callback callback to the caller 120 */ 121 void Stop(const TlsSocket::CloseCallback &callback); 122 123 /** 124 * Get the peer network address 125 * @param socketFd The socket ID of the client 126 * @param callback callback to the caller 127 */ 128 void GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback); 129 130 /** 131 * Get the peer network address 132 * @param socketFd The socket ID of the client 133 * @param callback callback to the caller 134 */ 135 void GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback); 136 137 /** 138 * Get the status of the current socket 139 * @param callback callback to the caller 140 */ 141 void GetState(const TlsSocket::GetStateCallback &callback); 142 143 /** 144 * Gets or sets the options associated with the current socket 145 * @param tcpExtraOptions options associated with the current socket 146 * @param callback callback to the caller 147 */ 148 bool SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, 149 const TlsSocket::SetExtraOptionsCallback &callback); 150 151 /** 152 * Get a local digital certificate 153 * @param callback callback to the caller 154 */ 155 void GetCertificate(const TlsSocket::GetCertificateCallback &callback); 156 157 /** 158 * Get the peer digital certificate 159 * @param socketFd The socket ID of the client 160 * @param needChain need chain 161 * @param callback callback to the caller 162 */ 163 void GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback); 164 165 /** 166 * Obtain the protocol used in communication 167 * @param callback callback to the caller 168 */ 169 void GetProtocol(const TlsSocket::GetProtocolCallback &callback); 170 171 /** 172 * Obtain the cipher suite used in communication 173 * @param socketFd The socket ID of the client 174 * @param callback callback to the caller 175 */ 176 void GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback); 177 178 /** 179 * Obtain the encryption algorithm used in the communication process 180 * @param socketFd The socket ID of the client 181 * @param callback callback to the caller 182 */ 183 void GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback); 184 185 /** 186 * Register the callback that is called when the connection is disconnected 187 * @param onCloseCallback callback invoked when disconnected 188 */ 189 190 /** 191 * Register the callback that is called when the connection is established 192 * @param onConnectCallback callback invoked when connection is established 193 */ 194 void OnConnect(const OnConnectCallback &onConnectCallback); 195 196 /** 197 * Register the callback that is called when an error occurs 198 * @param onErrorCallback callback invoked when an error occurs 199 */ 200 void OnError(const TlsSocket::OnErrorCallback &onErrorCallback); 201 202 /** 203 * Off Connect 204 */ 205 void OffConnect(); 206 207 /** 208 * Off Error 209 */ 210 void OffError(); 211 212 /** 213 * Get the socket file description of the server 214 */ 215 int GetListenSocketFd(); 216 217 /** 218 * Set the current socket file description address of the server 219 */ 220 void SetLocalAddress(const Socket::NetAddress &address); 221 222 /** 223 * Get the current socket file description address of the server 224 */ 225 Socket::NetAddress GetLocalAddress(); 226 227 public: 228 class Connection : public std::enable_shared_from_this<Connection> { 229 public: 230 ~Connection(); 231 /** 232 * Establish an encrypted accept on the specified socket 233 * @param sock socket for establishing encrypted connection 234 * @param options some options required during tls accept 235 * @return whether the encrypted accept is successfully established 236 */ 237 bool TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options); 238 239 /** 240 * Set the configuration items for establishing encrypted connections 241 * @param config configuration item when establishing encrypted connection 242 */ 243 void SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config); 244 245 /** 246 * Set address information 247 */ 248 void SetAddress(const Socket::NetAddress address); 249 250 /** 251 * Set local address information 252 */ 253 void SetLocalAddress(const Socket::NetAddress address); 254 255 /** 256 * Send data through an established encrypted connection 257 * @param data data sent over an established encrypted connection 258 * @return whether the data is successfully sent to the server 259 */ 260 bool Send(const std::string &data); 261 262 /** 263 * Receive the data sent by the server through the established encrypted connection 264 * @param buffer receive the data sent by the server 265 * @param maxBufferSize the size of the data received from the server 266 * @return whether the data sent by the server is successfully received 267 */ 268 int Recv(char *buffer, int maxBufferSize); 269 270 /** 271 * Disconnect encrypted connection 272 * @return whether the encrypted connection was successfully disconnected 273 */ 274 bool Close(); 275 276 /** 277 * Set the application layer negotiation protocol in the encrypted communication process 278 * @param alpnProtocols application layer negotiation protocol 279 * @return set whether the application layer negotiation protocol is successful during encrypted communication 280 */ 281 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 282 283 /** 284 * Storage of server communication related network information 285 * @param remoteInfo communication related network information 286 */ 287 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 288 289 /** 290 * Get configuration options for encrypted communication process 291 * @return configuration options for encrypted communication processes 292 */ 293 [[nodiscard]] TlsSocket::TLSConfiguration GetTlsConfiguration() const; 294 295 /** 296 * Obtain the cipher suite during encrypted communication 297 * @return crypto suite used in encrypted communication 298 */ 299 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 300 301 /** 302 * Obtain the peer certificate used in encrypted communication 303 * @return peer certificate used in encrypted communication 304 */ 305 [[nodiscard]] std::string GetRemoteCertificate() const; 306 307 /** 308 * Obtain the peer certificate used in encrypted communication 309 * @return peer certificate serialization data used in encrypted communication 310 */ 311 [[nodiscard]] const TlsSocket::X509CertRawData &GetRemoteCertRawData() const; 312 313 /** 314 * Obtain the certificate used in encrypted communication 315 * @return certificate serialization data used in encrypted communication 316 */ 317 [[nodiscard]] const TlsSocket::X509CertRawData &GetCertificate() const; 318 319 /** 320 * Get the encryption algorithm used in encrypted communication 321 * @return encryption algorithm used in encrypted communication 322 */ 323 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 324 325 /** 326 * Obtain the communication protocol used in encrypted communication 327 * @return communication protocol used in encrypted communication 328 */ 329 [[nodiscard]] std::string GetProtocol() const; 330 331 /** 332 * Set the information about the shared signature algorithm supported by peers during encrypted communication 333 * @return information about peer supported shared signature algorithms 334 */ 335 [[nodiscard]] bool SetSharedSigals(); 336 337 /** 338 * Obtain the ssl used in encrypted communication 339 * @return SSL used in encrypted communication 340 */ 341 [[nodiscard]] ssl_st *GetSSL() const; 342 343 /** 344 * Get address information 345 * @return Returns the address information of the remote client 346 */ 347 [[nodiscard]] Socket::NetAddress GetAddress() const; 348 349 /** 350 * Get local address information 351 * @return Returns the address information of the local accept connect 352 */ 353 [[nodiscard]] Socket::NetAddress GetLocalAddress() const; 354 355 /** 356 * Get address information 357 * @return Returns the address information of the remote client 358 */ 359 [[nodiscard]] int GetSocketFd() const; 360 361 /** 362 * Get EventManager information 363 * @return Returns the address information of the remote client 364 */ 365 [[nodiscard]] std::shared_ptr<EventManager> GetEventManager() const; 366 367 void OnMessage(const OnMessageCallback &onMessageCallback); 368 /** 369 * Unregister the callback which is called when message is received 370 */ 371 void OffMessage(); 372 373 void CallOnMessageCallback(int32_t socketFd, const std::string &data, 374 const Socket::SocketRemoteInfo &remoteInfo); 375 376 void SetEventManager(std::shared_ptr<EventManager> eventManager); 377 378 void SetClientID(int32_t clientID); 379 380 [[nodiscard]] int GetClientID(); 381 382 void CallOnCloseCallback(const int32_t socketFd); 383 void OnClose(const OnCloseCallback &onCloseCallback); 384 OnCloseCallback onCloseCallback_; 385 386 /** 387 * Off Close 388 */ 389 void OffClose(); 390 391 /** 392 * Register the callback that is called when an error occurs 393 * @param onErrorCallback callback invoked when an error occurs 394 */ 395 void OnError(const TlsSocket::OnErrorCallback &onErrorCallback); 396 /** 397 * Off Error 398 */ 399 void OffError(); 400 401 void CallOnErrorCallback(int32_t err, const std::string &errString); 402 403 TlsSocket::OnErrorCallback onErrorCallback_; 404 405 private: 406 bool StartTlsAccept(const TlsSocket::TLSConnectOptions &options); 407 bool CreatTlsContext(); 408 bool StartShakingHands(const TlsSocket::TLSConnectOptions &options); 409 bool GetRemoteCertificateFromPeer(); 410 bool SetRemoteCertRawData(); 411 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 412 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 413 const X509 *x509Certificates); 414 415 private: 416 ssl_st *ssl_ = nullptr; 417 X509 *peerX509_ = nullptr; 418 int32_t socketFd_ = 0; 419 420 TlsSocket::TLSContextServer tlsContext_; 421 TlsSocket::TLSConfiguration connectionConfiguration_; 422 Socket::NetAddress address_; 423 Socket::NetAddress localAddress_; 424 TlsSocket::X509CertRawData remoteRawData_; 425 426 std::string hostName_; 427 std::string remoteCert_; 428 std::string keyPass_; 429 430 std::vector<std::string> signatureAlgorithms_; 431 std::unique_ptr<TlsSocket::TLSContextServer> tlsContextServerPointer_ = nullptr; 432 433 std::shared_ptr<EventManager> eventManager_ = nullptr; 434 int32_t clientID_ = 0; 435 OnMessageCallback onMessageCallback_; 436 }; 437 438 private: 439 void SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config); 440 int RecvRemoteInfo(int socketFd, int index); 441 void RemoveConnect(int socketFd); 442 void AddConnect(int socketFd, std::shared_ptr<Connection> connection); 443 void CallListenCallback(int32_t err, ListenCallback callback); 444 void CallOnErrorCallback(int32_t err, const std::string &errString); 445 446 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback); 447 void CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager); 448 void CallSendCallback(int32_t err, TlsSocket::SendCallback callback); 449 bool ExecBind(const Socket::NetAddress &address, const ListenCallback &callback); 450 void ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback); 451 void MakeIpSocket(sa_family_t family); 452 void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 453 socklen_t *len); 454 static constexpr const size_t MAX_ERROR_LEN = 128; 455 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 456 457 void PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions); 458 459 private: 460 std::mutex mutex_; 461 std::mutex connectMutex_; 462 int listenSocketFd_ = -1; 463 Socket::NetAddress address_; 464 Socket::NetAddress localAddress_; 465 466 std::map<int, std::shared_ptr<Connection>> clientIdConnections_; 467 TlsSocket::TLSConfiguration TLSServerConfiguration_; 468 469 OnConnectCallback onConnectCallback_; 470 TlsSocket::OnErrorCallback onErrorCallback_; 471 472 bool GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress); 473 void ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientId); 474 void DropFdFromPollList(int &fd_index); 475 void InitPollList(int &listendFd); 476 477 struct pollfd fds_[USER_LIMIT + 1]; 478 479 bool isRunning_; 480 481 public: 482 std::shared_ptr<Connection> GetConnectionByClientID(int clientid); 483 int GetConnectionClientCount(); 484 485 std::shared_ptr<Connection> GetConnectionByClientEventManager(const EventManager *eventManager); 486 void CloseConnectionByEventManager(EventManager *eventManager); 487 void DeleteConnectionByEventManager(EventManager *eventManager); 488 void SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID, 489 int connectFD, std::shared_ptr<Connection> &connection); 490 }; 491 } // namespace TlsSocketServer 492 } // namespace NetStack 493 } // namespace OHOS 494 495 #endif // COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H 496