1 /* 2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H 17 #define COMMUNICATIONNETSTACK_TLS_SOCEKT_H 18 19 #include <any> 20 #include <condition_variable> 21 #include <cstring> 22 #include <functional> 23 #include <map> 24 #include <thread> 25 #include <tuple> 26 #include <unistd.h> 27 #include <vector> 28 29 #include "extra_options_base.h" 30 #include "net_address.h" 31 #include "socket_error.h" 32 #include "socket_remote_info.h" 33 #include "socket_state_base.h" 34 #include "tcp_connect_options.h" 35 #include "tcp_extra_options.h" 36 #include "tcp_send_options.h" 37 #include "tls.h" 38 #include "tls_certificate.h" 39 #include "tls_configuration.h" 40 #include "tls_context.h" 41 #include "tls_key.h" 42 43 namespace OHOS { 44 namespace NetStack { 45 namespace TlsSocket { 46 47 using BindCallback = std::function<void(int32_t errorNumber)>; 48 using ConnectCallback = std::function<void(int32_t errorNumber)>; 49 using SendCallback = std::function<void(int32_t errorNumber)>; 50 using CloseCallback = std::function<void(int32_t errorNumber)>; 51 using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 52 using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>; 53 using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>; 54 using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>; 55 using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 56 using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>; 57 using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>; 58 using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>; 59 using GetSignatureAlgorithmsCallback = 60 std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>; 61 62 using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>; 63 using OnConnectCallback = std::function<void(void)>; 64 using OnCloseCallback = std::function<void(void)>; 65 using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>; 66 67 using CheckServerIdentity = 68 std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>; 69 70 constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http1.1"; 71 constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2"; 72 73 constexpr size_t MAX_ERR_LEN = 1024; 74 75 /** 76 * Parameters required during communication 77 */ 78 class TLSSecureOptions { 79 public: 80 TLSSecureOptions() = default; 81 ~TLSSecureOptions() = default; 82 83 TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions); 84 TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions); 85 /** 86 * Set root CA Chain to verify the server cert 87 * @param caChain root certificate chain used to validate server certificates 88 */ 89 void SetCaChain(const std::vector<std::string> &caChain); 90 91 /** 92 * Set digital certificate for server verification 93 * @param cert digital certificate sent to the server to verify validity 94 */ 95 void SetCert(const std::string &cert); 96 97 /** 98 * Set key to decrypt server data 99 * @param keyChain key used to decrypt server data 100 */ 101 void SetKey(const SecureData &key); 102 103 /** 104 * Set the password to read the private key 105 * @param keyPass read the password of the private key 106 */ 107 void SetKeyPass(const SecureData &keyPass); 108 109 /** 110 * Set the protocol used in communication 111 * @param protocolChain protocol version number used 112 */ 113 void SetProtocolChain(const std::vector<std::string> &protocolChain); 114 115 /** 116 * Whether the peer cipher suite is preferred for communication 117 * @param useRemoteCipherPrefer whether the peer cipher suite is preferred 118 */ 119 void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer); 120 121 /** 122 * Encryption algorithm used in communication 123 * @param signatureAlgorithms encryption algorithm e.g: rsa 124 */ 125 void SetSignatureAlgorithms(const std::string &signatureAlgorithms); 126 127 /** 128 * Crypto suite used in communication 129 * @param cipherSuite cipher suite e.g:AES256-SHA256 130 */ 131 void SetCipherSuite(const std::string &cipherSuite); 132 133 /** 134 * Set a revoked certificate 135 * @param crlChain certificate Revocation List 136 */ 137 void SetCrlChain(const std::vector<std::string> &crlChain); 138 139 /** 140 * Get root CA Chain to verify the server cert 141 * @return root CA chain 142 */ 143 [[nodiscard]] const std::vector<std::string> &GetCaChain() const; 144 145 /** 146 * Obtain a certificate to send to the server for checking 147 * @return digital certificate obtained 148 */ 149 [[nodiscard]] const std::string &GetCert() const; 150 151 /** 152 * Obtain the private key in the communication process 153 * @return private key during communication 154 */ 155 [[nodiscard]] const SecureData &GetKey() const; 156 157 /** 158 * Get the password to read the private key 159 * @return read the password of the private key 160 */ 161 [[nodiscard]] const SecureData &GetKeyPass() const; 162 163 /** 164 * Get the protocol of the communication process 165 * @return protocol of communication process 166 */ 167 [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const; 168 169 /** 170 * Is the remote cipher suite being used for communication 171 * @return is use Remote Cipher Prefer 172 */ 173 [[nodiscard]] bool UseRemoteCipherPrefer() const; 174 175 /** 176 * Obtain the encryption algorithm used in the communication process 177 * @return encryption algorithm used in communication 178 */ 179 [[nodiscard]] const std::string &GetSignatureAlgorithms() const; 180 181 /** 182 * Obtain the cipher suite used in communication 183 * @return crypto suite used in communication 184 */ 185 [[nodiscard]] const std::string &GetCipherSuite() const; 186 187 /** 188 * Get revoked certificate chain 189 * @return revoked certificate chain 190 */ 191 [[nodiscard]] const std::vector<std::string> &GetCrlChain() const; 192 193 void SetVerifyMode(VerifyMode verifyMode); 194 195 [[nodiscard]] VerifyMode GetVerifyMode() const; 196 197 private: 198 std::vector<std::string> caChain_; 199 std::string cert_; 200 SecureData key_; 201 SecureData keyPass_; 202 std::vector<std::string> protocolChain_; 203 bool useRemoteCipherPrefer_ = false; 204 std::string signatureAlgorithms_; 205 std::string cipherSuite_; 206 std::vector<std::string> crlChain_; 207 VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE; 208 }; 209 210 /** 211 * Some options required during tls connection 212 */ 213 class TLSConnectOptions { 214 public: 215 friend class TLSSocketExec; 216 /** 217 * Communication parameters required for connection establishment 218 * @param address communication parameters during connection 219 */ 220 void SetNetAddress(const Socket::NetAddress &address); 221 222 /** 223 * Parameters required during communication 224 * @param tlsSecureOptions certificate and other relevant parameters 225 */ 226 void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions); 227 228 /** 229 * Set the callback function to check the validity of the server 230 * @param checkServerIdentity callback function passed in by API caller 231 */ 232 void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity); 233 234 /** 235 * Set application layer protocol negotiation 236 * @param alpnProtocols application layer protocol negotiation 237 */ 238 void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 239 240 /** 241 * Set whether to skip remote validation 242 * @param skipRemoteValidation flag to choose whether to skip validation 243 */ 244 void SetSkipRemoteValidation(bool skipRemoteValidation); 245 246 /** 247 * Obtain the network address of the communication process 248 * @return network address 249 */ 250 [[nodiscard]] Socket::NetAddress GetNetAddress() const; 251 252 /** 253 * Obtain the parameters required in the communication process 254 * @return certificate and other relevant parameters 255 */ 256 [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const; 257 258 /** 259 * Get the check server ID callback function passed in by the API caller 260 * @return check the server identity callback function 261 */ 262 [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const; 263 264 /** 265 * Obtain the application layer protocol negotiation in the communication process 266 * @return application layer protocol negotiation 267 */ 268 [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const; 269 270 /** 271 * Get the choice of whether to skip remote validaion 272 * @return skipRemoteValidaion result 273 */ 274 [[nodiscard]] bool GetSkipRemoteValidation() const; 275 276 private: 277 Socket::NetAddress address_; 278 TLSSecureOptions tlsSecureOptions_; 279 CheckServerIdentity checkServerIdentity_; 280 std::vector<std::string> alpnProtocols_; 281 bool skipRemoteValidation_ = false; 282 }; 283 284 /** 285 * TLS socket interface class 286 */ 287 class TLSSocket { 288 public: 289 TLSSocket(const TLSSocket &) = delete; 290 TLSSocket(TLSSocket &&) = delete; 291 292 TLSSocket &operator=(const TLSSocket &) = delete; 293 TLSSocket &operator=(TLSSocket &&) = delete; 294 295 TLSSocket() = default; 296 ~TLSSocket() = default; 297 TLSSocket(int sockFd)298 explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {} 299 300 /** 301 * Create a socket and bind to the address specified by address 302 * @param address ip address 303 * @param callback callback to the caller if bind ok or not 304 */ 305 void Bind(Socket::NetAddress &address, const BindCallback &callback); 306 307 /** 308 * Establish a secure connection based on the created socket 309 * @param tlsConnectOptions some options required during tls connection 310 * @param callback callback to the caller if connect ok or not 311 */ 312 void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback); 313 314 /** 315 * Send data based on the created socket 316 * @param tcpSendOptions some options required during tcp data transmission 317 * @param callback callback to the caller if send ok or not 318 */ 319 void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback); 320 321 /** 322 * Disconnect by releasing the socket when communicating 323 * @param callback callback to the caller 324 */ 325 void Close(const CloseCallback &callback); 326 327 /** 328 * Get the peer network address 329 * @param callback callback to the caller 330 */ 331 void GetRemoteAddress(const GetRemoteAddressCallback &callback); 332 333 /** 334 * Get the status of the current socket 335 * @param callback callback to the caller 336 */ 337 void GetState(const GetStateCallback &callback); 338 339 /** 340 * Gets or sets the options associated with the current socket 341 * @param tcpExtraOptions options associated with the current socket 342 * @param callback callback to the caller 343 */ 344 void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback); 345 346 /** 347 * Get a local digital certificate 348 * @param callback callback to the caller 349 */ 350 void GetCertificate(const GetCertificateCallback &callback); 351 352 /** 353 * Get the peer digital certificate 354 * @param needChain need chain 355 * @param callback callback to the caller 356 */ 357 void GetRemoteCertificate(const GetRemoteCertificateCallback &callback); 358 359 /** 360 * Obtain the protocol used in communication 361 * @param callback callback to the caller 362 */ 363 void GetProtocol(const GetProtocolCallback &callback); 364 365 /** 366 * Obtain the cipher suite used in communication 367 * @param callback callback to the caller 368 */ 369 void GetCipherSuite(const GetCipherSuiteCallback &callback); 370 371 /** 372 * Obtain the encryption algorithm used in the communication process 373 * @param callback callback to the caller 374 */ 375 void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback); 376 377 /** 378 * Register a callback which is called when message is received 379 * @param onMessageCallback callback which is called when message is received 380 */ 381 void OnMessage(const OnMessageCallback &onMessageCallback); 382 383 /** 384 * Register the callback that is called when the connection is established 385 * @param onConnectCallback callback invoked when connection is established 386 */ 387 void OnConnect(const OnConnectCallback &onConnectCallback); 388 389 /** 390 * Register the callback that is called when the connection is disconnected 391 * @param onCloseCallback callback invoked when disconnected 392 */ 393 void OnClose(const OnCloseCallback &onCloseCallback); 394 395 /** 396 * Register the callback that is called when an error occurs 397 * @param onErrorCallback callback invoked when an error occurs 398 */ 399 void OnError(const OnErrorCallback &onErrorCallback); 400 401 /** 402 * Unregister the callback which is called when message is received 403 */ 404 void OffMessage(); 405 406 /** 407 * Off Connect 408 */ 409 void OffConnect(); 410 411 /** 412 * Off Close 413 */ 414 void OffClose(); 415 416 /** 417 * Off Error 418 */ 419 void OffError(); 420 421 /** 422 * Get the socket file description of the server 423 */ 424 int GetSocketFd(); 425 426 /** 427 * Set the current socket file description address of the server 428 */ 429 void SetLocalAddress(const Socket::NetAddress &address); 430 431 /** 432 * Get the current socket file description address of the server 433 */ 434 Socket::NetAddress GetLocalAddress(); 435 436 bool GetCloseState(); 437 438 void SetCloseState(bool flag); 439 440 std::mutex &GetCloseLock(); 441 private: 442 class TLSSocketInternal final { 443 public: 444 TLSSocketInternal() = default; 445 ~TLSSocketInternal() = default; 446 447 /** 448 * Establish an encrypted connection on the specified socket 449 * @param sock socket for establishing encrypted connection 450 * @param options some options required during tls connection 451 * @param isExtSock socket fd is originated from external source when constructing tls socket 452 * @return whether the encrypted connection is successfully established 453 */ 454 bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock); 455 456 /** 457 * Set the configuration items for establishing encrypted connections 458 * @param config configuration item when establishing encrypted connection 459 */ 460 void SetTlsConfiguration(const TLSConnectOptions &config); 461 462 /** 463 * Send data through an established encrypted connection 464 * @param data data sent over an established encrypted connection 465 * @return whether the data is successfully sent to the server 466 */ 467 bool Send(const std::string &data); 468 469 /** 470 * Receive the data sent by the server through the established encrypted connection 471 * @param buffer receive the data sent by the server 472 * @param maxBufferSize the size of the data received from the server 473 * @return whether the data sent by the server is successfully received 474 */ 475 int Recv(char *buffer, int maxBufferSize); 476 477 /** 478 * Disconnect encrypted connection 479 * @return whether the encrypted connection was successfully disconnected 480 */ 481 bool Close(); 482 483 /** 484 * Set the application layer negotiation protocol in the encrypted communication process 485 * @param alpnProtocols application layer negotiation protocol 486 * @return set whether the application layer negotiation protocol is successful during encrypted communication 487 */ 488 bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols); 489 490 /** 491 * Storage of server communication related network information 492 * @param remoteInfo communication related network information 493 */ 494 void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo); 495 496 /** 497 * convert the code to ssl error code 498 * @return the value for ssl error code. 499 */ 500 int ConvertSSLError(void); 501 502 /** 503 * Get configuration options for encrypted communication process 504 * @return configuration options for encrypted communication processes 505 */ 506 [[nodiscard]] TLSConfiguration GetTlsConfiguration() const; 507 508 /** 509 * Obtain the cipher suite during encrypted communication 510 * @return crypto suite used in encrypted communication 511 */ 512 [[nodiscard]] std::vector<std::string> GetCipherSuite() const; 513 514 /** 515 * Obtain the peer certificate used in encrypted communication 516 * @return peer certificate used in encrypted communication 517 */ 518 [[nodiscard]] std::string GetRemoteCertificate() const; 519 520 /** 521 * Obtain the peer certificate used in encrypted communication 522 * @return peer certificate serialization data used in encrypted communication 523 */ 524 [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const; 525 526 /** 527 * Obtain the certificate used in encrypted communication 528 * @return certificate serialization data used in encrypted communication 529 */ 530 [[nodiscard]] const X509CertRawData &GetCertificate() const; 531 532 /** 533 * Get the encryption algorithm used in encrypted communication 534 * @return encryption algorithm used in encrypted communication 535 */ 536 [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const; 537 538 /** 539 * Obtain the communication protocol used in encrypted communication 540 * @return communication protocol used in encrypted communication 541 */ 542 [[nodiscard]] std::string GetProtocol() const; 543 544 /** 545 * Set the information about the shared signature algorithm supported by peers during encrypted communication 546 * @return information about peer supported shared signature algorithms 547 */ 548 [[nodiscard]] bool SetSharedSigals(); 549 550 /** 551 * Obtain the ssl used in encrypted communication 552 * @return SSL used in encrypted communication 553 */ 554 [[nodiscard]] ssl_st *GetSSL(); 555 556 private: 557 bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd); 558 bool StartTlsConnected(const TLSConnectOptions &options); 559 bool CreatTlsContext(); 560 bool StartShakingHands(const TLSConnectOptions &options); 561 bool GetRemoteCertificateFromPeer(); 562 bool SetRemoteCertRawData(); 563 bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize); 564 std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates); 565 std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext, 566 const X509 *x509Certificates); 567 568 private: 569 std::mutex mutexForSsl_; 570 ssl_st *ssl_ = nullptr; 571 X509 *peerX509_ = nullptr; 572 uint16_t port_ = 0; 573 sa_family_t family_ = 0; 574 int32_t socketDescriptor_ = 0; 575 576 TLSContext tlsContext_; 577 TLSConfiguration configuration_; 578 Socket::NetAddress address_; 579 X509CertRawData remoteRawData_; 580 581 std::string hostName_; 582 std::string remoteCert_; 583 584 std::vector<std::string> signatureAlgorithms_; 585 std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr; 586 }; 587 588 private: 589 TLSSocketInternal tlsSocketInternal_; 590 591 static std::string MakeAddressString(sockaddr *addr); 592 593 static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr, 594 socklen_t *len); 595 596 void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo); 597 void CallOnConnectCallback(); 598 void CallOnCloseCallback(); 599 void CallOnErrorCallback(int32_t err, const std::string &errString); 600 601 void CallBindCallback(int32_t err, BindCallback callback); 602 void CallConnectCallback(int32_t err, ConnectCallback callback); 603 void CallSendCallback(int32_t err, SendCallback callback); 604 void CallCloseCallback(int32_t err, CloseCallback callback); 605 void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address, 606 GetRemoteAddressCallback callback); 607 void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback); 608 void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback); 609 void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback); 610 void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert, 611 GetRemoteCertificateCallback callback); 612 void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback); 613 void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite, 614 GetCipherSuiteCallback callback); 615 void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms, 616 GetSignatureAlgorithmsCallback callback); 617 618 int ReadMessage(); 619 void StartReadMessage(); 620 621 void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback); 622 void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback); 623 624 [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const; 625 [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const; 626 627 void MakeIpSocket(sa_family_t family); 628 629 template<class T> DealCallback(int32_t err, T &callback)630 void DealCallback(int32_t err, T &callback) 631 { 632 T func = nullptr; 633 { 634 std::lock_guard<std::mutex> lock(mutex_); 635 if (callback) { 636 func = callback; 637 } 638 } 639 640 if (func) { 641 func(err); 642 } 643 } 644 645 private: 646 static constexpr const size_t MAX_ERROR_LEN = 128; 647 static constexpr const size_t MAX_BUFFER_SIZE = 8192; 648 649 OnMessageCallback onMessageCallback_; 650 OnConnectCallback onConnectCallback_; 651 OnCloseCallback onCloseCallback_; 652 OnErrorCallback onErrorCallback_; 653 654 std::mutex mutex_; 655 std::mutex recvMutex_; 656 std::mutex cvMutex_; 657 bool isRunning_ = false; 658 bool isRunOver_ = true; 659 std::condition_variable cvSslFree_; 660 int sockFd_ = -1; 661 bool isExtSock_ = false; 662 Socket::NetAddress localAddress_; 663 bool isClosed = false; 664 std::mutex mutexForClose_; 665 }; 666 } // namespace TlsSocket 667 } // namespace NetStack 668 } // namespace OHOS 669 670 #endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H 671