1/*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef COMMUNICATIONNETSTACK_TLS_SOCEKT_H
17#define COMMUNICATIONNETSTACK_TLS_SOCEKT_H
18
#include <any>
#include <condition_variable>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <vector>
28
29#include "extra_options_base.h"
30#include "net_address.h"
31#include "socket_error.h"
32#include "socket_remote_info.h"
33#include "socket_state_base.h"
34#include "tcp_connect_options.h"
35#include "tcp_extra_options.h"
36#include "tcp_send_options.h"
37#include "tls.h"
38#include "tls_certificate.h"
39#include "tls_configuration.h"
40#include "tls_context.h"
41#include "tls_key.h"
42
43namespace OHOS {
44namespace NetStack {
45namespace TlsSocket {
46
47using BindCallback = std::function<void(int32_t errorNumber)>;
48using ConnectCallback = std::function<void(int32_t errorNumber)>;
49using SendCallback = std::function<void(int32_t errorNumber)>;
50using CloseCallback = std::function<void(int32_t errorNumber)>;
51using GetRemoteAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
52using GetLocalAddressCallback = std::function<void(int32_t errorNumber, const Socket::NetAddress &address)>;
53using GetStateCallback = std::function<void(int32_t errorNumber, const Socket::SocketStateBase &state)>;
54using SetExtraOptionsCallback = std::function<void(int32_t errorNumber)>;
55using GetCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
56using GetRemoteCertificateCallback = std::function<void(int32_t errorNumber, const X509CertRawData &cert)>;
57using GetProtocolCallback = std::function<void(int32_t errorNumber, const std::string &protocol)>;
58using GetCipherSuiteCallback = std::function<void(int32_t errorNumber, const std::vector<std::string> &suite)>;
59using GetSignatureAlgorithmsCallback =
60    std::function<void(int32_t errorNumber, const std::vector<std::string> &algorithms)>;
61
62using OnMessageCallback = std::function<void(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
63using OnConnectCallback = std::function<void(void)>;
64using OnCloseCallback = std::function<void(void)>;
65using OnErrorCallback = std::function<void(int32_t errorNumber, const std::string &errorString)>;
66
67using CheckServerIdentity =
68    std::function<void(const std::string &hostName, const std::vector<std::string> &x509Certificates)>;
69
/*
 * ALPN protocol identifiers. The spelling must match the IANA "TLS Application-Layer
 * Protocol Negotiation (ALPN) Protocol IDs" registry byte-for-byte, otherwise a peer will
 * never select the protocol. HTTP/1.1 is registered as "http/1.1" (RFC 7301, section 6).
 */
inline constexpr const char *ALPN_PROTOCOLS_HTTP_1_1 = "http/1.1";
// HTTP/2 over TLS (RFC 9113, section 3.2).
inline constexpr const char *ALPN_PROTOCOLS_HTTP_2 = "h2";

// Error-message length limit. NOTE(review): usage is not visible in this header; presumably a buffer size.
inline constexpr size_t MAX_ERR_LEN = 1024;
74
75/**
76 * Parameters required during communication
77 */
78class TLSSecureOptions {
79public:
80    TLSSecureOptions() = default;
81    ~TLSSecureOptions() = default;
82
83    TLSSecureOptions(const TLSSecureOptions &tlsSecureOptions);
84    TLSSecureOptions &operator=(const TLSSecureOptions &tlsSecureOptions);
85    /**
86     * Set root CA Chain to verify the server cert
87     * @param caChain root certificate chain used to validate server certificates
88     */
89    void SetCaChain(const std::vector<std::string> &caChain);
90
91    /**
92     * Set digital certificate for server verification
93     * @param cert digital certificate sent to the server to verify validity
94     */
95    void SetCert(const std::string &cert);
96
97    /**
98     * Set key to decrypt server data
99     * @param keyChain key used to decrypt server data
100     */
101    void SetKey(const SecureData &key);
102
103    /**
104     * Set the password to read the private key
105     * @param keyPass read the password of the private key
106     */
107    void SetKeyPass(const SecureData &keyPass);
108
109    /**
110     * Set the protocol used in communication
111     * @param protocolChain protocol version number used
112     */
113    void SetProtocolChain(const std::vector<std::string> &protocolChain);
114
115    /**
116     * Whether the peer cipher suite is preferred for communication
117     * @param useRemoteCipherPrefer whether the peer cipher suite is preferred
118     */
119    void SetUseRemoteCipherPrefer(bool useRemoteCipherPrefer);
120
121    /**
122     * Encryption algorithm used in communication
123     * @param signatureAlgorithms encryption algorithm e.g: rsa
124     */
125    void SetSignatureAlgorithms(const std::string &signatureAlgorithms);
126
127    /**
128     * Crypto suite used in communication
129     * @param cipherSuite cipher suite e.g:AES256-SHA256
130     */
131    void SetCipherSuite(const std::string &cipherSuite);
132
133    /**
134     * Set a revoked certificate
135     * @param crlChain certificate Revocation List
136     */
137    void SetCrlChain(const std::vector<std::string> &crlChain);
138
139    /**
140     * Get root CA Chain to verify the server cert
141     * @return root CA chain
142     */
143    [[nodiscard]] const std::vector<std::string> &GetCaChain() const;
144
145    /**
146     * Obtain a certificate to send to the server for checking
147     * @return digital certificate obtained
148     */
149    [[nodiscard]] const std::string &GetCert() const;
150
151    /**
152     * Obtain the private key in the communication process
153     * @return private key during communication
154     */
155    [[nodiscard]] const SecureData &GetKey() const;
156
157    /**
158     * Get the password to read the private key
159     * @return read the password of the private key
160     */
161    [[nodiscard]] const SecureData &GetKeyPass() const;
162
163    /**
164     * Get the protocol of the communication process
165     * @return protocol of communication process
166     */
167    [[nodiscard]] const std::vector<std::string> &GetProtocolChain() const;
168
169    /**
170     * Is the remote cipher suite being used for communication
171     * @return is use Remote Cipher Prefer
172     */
173    [[nodiscard]] bool UseRemoteCipherPrefer() const;
174
175    /**
176     * Obtain the encryption algorithm used in the communication process
177     * @return encryption algorithm used in communication
178     */
179    [[nodiscard]] const std::string &GetSignatureAlgorithms() const;
180
181    /**
182     * Obtain the cipher suite used in communication
183     * @return crypto suite used in communication
184     */
185    [[nodiscard]] const std::string &GetCipherSuite() const;
186
187    /**
188     * Get revoked certificate chain
189     * @return revoked certificate chain
190     */
191    [[nodiscard]] const std::vector<std::string> &GetCrlChain() const;
192
193    void SetVerifyMode(VerifyMode verifyMode);
194
195    [[nodiscard]] VerifyMode GetVerifyMode() const;
196
197private:
198    std::vector<std::string> caChain_;
199    std::string cert_;
200    SecureData key_;
201    SecureData keyPass_;
202    std::vector<std::string> protocolChain_;
203    bool useRemoteCipherPrefer_ = false;
204    std::string signatureAlgorithms_;
205    std::string cipherSuite_;
206    std::vector<std::string> crlChain_;
207    VerifyMode TLSVerifyMode_ = VerifyMode::ONE_WAY_MODE;
208};
209
210/**
211 * Some options required during tls connection
212 */
213class TLSConnectOptions {
214public:
215    friend class TLSSocketExec;
216    /**
217     * Communication parameters required for connection establishment
218     * @param address communication parameters during connection
219     */
220    void SetNetAddress(const Socket::NetAddress &address);
221
222    /**
223     * Parameters required during communication
224     * @param tlsSecureOptions certificate and other relevant parameters
225     */
226    void SetTlsSecureOptions(TLSSecureOptions &tlsSecureOptions);
227
228    /**
229     * Set the callback function to check the validity of the server
230     * @param checkServerIdentity callback function passed in by API caller
231     */
232    void SetCheckServerIdentity(const CheckServerIdentity &checkServerIdentity);
233
234    /**
235     * Set application layer protocol negotiation
236     * @param alpnProtocols application layer protocol negotiation
237     */
238    void SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
239
240    /**
241     * Set whether to skip remote validation
242     * @param skipRemoteValidation flag to choose whether to skip validation
243     */
244    void SetSkipRemoteValidation(bool skipRemoteValidation);
245
246    /**
247     * Obtain the network address of the communication process
248     * @return network address
249     */
250    [[nodiscard]] Socket::NetAddress GetNetAddress() const;
251
252    /**
253     * Obtain the parameters required in the communication process
254     * @return certificate and other relevant parameters
255     */
256    [[nodiscard]] TLSSecureOptions GetTlsSecureOptions() const;
257
258    /**
259     * Get the check server ID callback function passed in by the API caller
260     * @return check the server identity callback function
261     */
262    [[nodiscard]] CheckServerIdentity GetCheckServerIdentity() const;
263
264    /**
265     * Obtain the application layer protocol negotiation in the communication process
266     * @return application layer protocol negotiation
267     */
268    [[nodiscard]] const std::vector<std::string> &GetAlpnProtocols() const;
269
270    /**
271     * Get the choice of whether to skip remote validaion
272     * @return skipRemoteValidaion result
273     */
274    [[nodiscard]] bool GetSkipRemoteValidation() const;
275
276private:
277    Socket::NetAddress address_;
278    TLSSecureOptions tlsSecureOptions_;
279    CheckServerIdentity checkServerIdentity_;
280    std::vector<std::string> alpnProtocols_;
281    bool skipRemoteValidation_ = false;
282};
283
284/**
285 * TLS socket interface class
286 */
287class TLSSocket {
288public:
289    TLSSocket(const TLSSocket &) = delete;
290    TLSSocket(TLSSocket &&) = delete;
291
292    TLSSocket &operator=(const TLSSocket &) = delete;
293    TLSSocket &operator=(TLSSocket &&) = delete;
294
295    TLSSocket() = default;
296    ~TLSSocket() = default;
297
298    explicit TLSSocket(int sockFd): sockFd_(sockFd), isExtSock_(true) {}
299
300    /**
301     * Create a socket and bind to the address specified by address
302     * @param address ip address
303     * @param callback callback to the caller if bind ok or not
304     */
305    void Bind(Socket::NetAddress &address, const BindCallback &callback);
306
307    /**
308     * Establish a secure connection based on the created socket
309     * @param tlsConnectOptions some options required during tls connection
310     * @param callback callback to the caller if connect ok or not
311     */
312    void Connect(TLSConnectOptions &tlsConnectOptions, const ConnectCallback &callback);
313
314    /**
315     * Send data based on the created socket
316     * @param tcpSendOptions  some options required during tcp data transmission
317     * @param callback callback to the caller if send ok or not
318     */
319    void Send(const Socket::TCPSendOptions &tcpSendOptions, const SendCallback &callback);
320
321    /**
322     * Disconnect by releasing the socket when communicating
323     * @param callback callback to the caller
324     */
325    void Close(const CloseCallback &callback);
326
327    /**
328     * Get the peer network address
329     * @param callback callback to the caller
330     */
331    void GetRemoteAddress(const GetRemoteAddressCallback &callback);
332
333    /**
334     * Get the status of the current socket
335     * @param callback callback to the caller
336     */
337    void GetState(const GetStateCallback &callback);
338
339    /**
340     * Gets or sets the options associated with the current socket
341     * @param tcpExtraOptions options associated with the current socket
342     * @param callback callback to the caller
343     */
344    void SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions, const SetExtraOptionsCallback &callback);
345
346    /**
347     *  Get a local digital certificate
348     * @param callback callback to the caller
349     */
350    void GetCertificate(const GetCertificateCallback &callback);
351
352    /**
353     * Get the peer digital certificate
354     * @param needChain need chain
355     * @param callback callback to the caller
356     */
357    void GetRemoteCertificate(const GetRemoteCertificateCallback &callback);
358
359    /**
360     * Obtain the protocol used in communication
361     * @param callback callback to the caller
362     */
363    void GetProtocol(const GetProtocolCallback &callback);
364
365    /**
366     * Obtain the cipher suite used in communication
367     * @param callback callback to the caller
368     */
369    void GetCipherSuite(const GetCipherSuiteCallback &callback);
370
371    /**
372     * Obtain the encryption algorithm used in the communication process
373     * @param callback callback to the caller
374     */
375    void GetSignatureAlgorithms(const GetSignatureAlgorithmsCallback &callback);
376
377    /**
378     * Register a callback which is called when message is received
379     * @param onMessageCallback callback which is called when message is received
380     */
381    void OnMessage(const OnMessageCallback &onMessageCallback);
382
383    /**
384     * Register the callback that is called when the connection is established
385     * @param onConnectCallback callback invoked when connection is established
386     */
387    void OnConnect(const OnConnectCallback &onConnectCallback);
388
389    /**
390     * Register the callback that is called when the connection is disconnected
391     * @param onCloseCallback callback invoked when disconnected
392     */
393    void OnClose(const OnCloseCallback &onCloseCallback);
394
395    /**
396     * Register the callback that is called when an error occurs
397     * @param onErrorCallback callback invoked when an error occurs
398     */
399    void OnError(const OnErrorCallback &onErrorCallback);
400
401    /**
402     * Unregister the callback which is called when message is received
403     */
404    void OffMessage();
405
406    /**
407     * Off Connect
408     */
409    void OffConnect();
410
411    /**
412     * Off Close
413     */
414    void OffClose();
415
416    /**
417     * Off Error
418     */
419    void OffError();
420
421    /**
422     * Get the socket file description of the server
423     */
424    int GetSocketFd();
425
426    /**
427     * Set the current socket file description address of the server
428     */
429    void SetLocalAddress(const Socket::NetAddress &address);
430
431    /**
432     * Get the current socket file description address of the server
433     */
434    Socket::NetAddress GetLocalAddress();
435
436    bool GetCloseState();
437
438    void SetCloseState(bool flag);
439
440    std::mutex &GetCloseLock();
441private:
442    class TLSSocketInternal final {
443    public:
444        TLSSocketInternal() = default;
445        ~TLSSocketInternal() = default;
446
447        /**
448         * Establish an encrypted connection on the specified socket
449         * @param sock socket for establishing encrypted connection
450         * @param options some options required during tls connection
451         * @param isExtSock socket fd is originated from external source when constructing tls socket
452         * @return whether the encrypted connection is successfully established
453         */
454        bool TlsConnectToHost(int sock, const TLSConnectOptions &options, bool isExtSock);
455
456        /**
457         * Set the configuration items for establishing encrypted connections
458         * @param config configuration item when establishing encrypted connection
459         */
460        void SetTlsConfiguration(const TLSConnectOptions &config);
461
462        /**
463         * Send data through an established encrypted connection
464         * @param data data sent over an established encrypted connection
465         * @return whether the data is successfully sent to the server
466         */
467        bool Send(const std::string &data);
468
469        /**
470         * Receive the data sent by the server through the established encrypted connection
471         * @param buffer receive the data sent by the server
472         * @param maxBufferSize the size of the data received from the server
473         * @return whether the data sent by the server is successfully received
474         */
475        int Recv(char *buffer, int maxBufferSize);
476
477        /**
478         * Disconnect encrypted connection
479         * @return whether the encrypted connection was successfully disconnected
480         */
481        bool Close();
482
483        /**
484         * Set the application layer negotiation protocol in the encrypted communication process
485         * @param alpnProtocols application layer negotiation protocol
486         * @return set whether the application layer negotiation protocol is successful during encrypted communication
487         */
488        bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
489
490        /**
491         * Storage of server communication related network information
492         * @param remoteInfo communication related network information
493         */
494        void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
495
496        /**
497         * convert the code to ssl error code
498         * @return the value for ssl error code.
499         */
500        int ConvertSSLError(void);
501
502        /**
503         * Get configuration options for encrypted communication process
504         * @return configuration options for encrypted communication processes
505         */
506        [[nodiscard]] TLSConfiguration GetTlsConfiguration() const;
507
508        /**
509         * Obtain the cipher suite during encrypted communication
510         * @return crypto suite used in encrypted communication
511         */
512        [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
513
514        /**
515         * Obtain the peer certificate used in encrypted communication
516         * @return peer certificate used in encrypted communication
517         */
518        [[nodiscard]] std::string GetRemoteCertificate() const;
519
520        /**
521         * Obtain the peer certificate used in encrypted communication
522         * @return peer certificate serialization data used in encrypted communication
523         */
524        [[nodiscard]] const X509CertRawData &GetRemoteCertRawData() const;
525
526        /**
527         * Obtain the certificate used in encrypted communication
528         * @return certificate serialization data used in encrypted communication
529         */
530        [[nodiscard]] const X509CertRawData &GetCertificate() const;
531
532        /**
533         * Get the encryption algorithm used in encrypted communication
534         * @return encryption algorithm used in encrypted communication
535         */
536        [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
537
538        /**
539         * Obtain the communication protocol used in encrypted communication
540         * @return communication protocol used in encrypted communication
541         */
542        [[nodiscard]] std::string GetProtocol() const;
543
544        /**
545         * Set the information about the shared signature algorithm supported by peers during encrypted communication
546         * @return information about peer supported shared signature algorithms
547         */
548        [[nodiscard]] bool SetSharedSigals();
549
550        /**
551         * Obtain the ssl used in encrypted communication
552         * @return SSL used in encrypted communication
553         */
554        [[nodiscard]] ssl_st *GetSSL();
555
556    private:
557        bool SendRetry(ssl_st *ssl, const char *curPos, size_t curSendSize, int sockfd);
558        bool StartTlsConnected(const TLSConnectOptions &options);
559        bool CreatTlsContext();
560        bool StartShakingHands(const TLSConnectOptions &options);
561        bool GetRemoteCertificateFromPeer();
562        bool SetRemoteCertRawData();
563        bool PollSend(int sockfd, ssl_st *ssl, const char *pdata, int sendSize);
564        std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
565        std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
566                                             const X509 *x509Certificates);
567
568    private:
569        std::mutex mutexForSsl_;
570        ssl_st *ssl_ = nullptr;
571        X509 *peerX509_ = nullptr;
572        uint16_t port_ = 0;
573        sa_family_t family_ = 0;
574        int32_t socketDescriptor_ = 0;
575
576        TLSContext tlsContext_;
577        TLSConfiguration configuration_;
578        Socket::NetAddress address_;
579        X509CertRawData remoteRawData_;
580
581        std::string hostName_;
582        std::string remoteCert_;
583
584        std::vector<std::string> signatureAlgorithms_;
585        std::unique_ptr<TLSContext> tlsContextPointer_ = nullptr;
586    };
587
588private:
589    TLSSocketInternal tlsSocketInternal_;
590
591    static std::string MakeAddressString(sockaddr *addr);
592
593    static void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
594                        socklen_t *len);
595
596    void CallOnMessageCallback(const std::string &data, const Socket::SocketRemoteInfo &remoteInfo);
597    void CallOnConnectCallback();
598    void CallOnCloseCallback();
599    void CallOnErrorCallback(int32_t err, const std::string &errString);
600
601    void CallBindCallback(int32_t err, BindCallback callback);
602    void CallConnectCallback(int32_t err, ConnectCallback callback);
603    void CallSendCallback(int32_t err, SendCallback callback);
604    void CallCloseCallback(int32_t err, CloseCallback callback);
605    void CallGetRemoteAddressCallback(int32_t err, const Socket::NetAddress &address,
606                                      GetRemoteAddressCallback callback);
607    void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, GetStateCallback callback);
608    void CallSetExtraOptionsCallback(int32_t err, SetExtraOptionsCallback callback);
609    void CallGetCertificateCallback(int32_t err, const X509CertRawData &cert, GetCertificateCallback callback);
610    void CallGetRemoteCertificateCallback(int32_t err, const X509CertRawData &cert,
611                                          GetRemoteCertificateCallback callback);
612    void CallGetProtocolCallback(int32_t err, const std::string &protocol, GetProtocolCallback callback);
613    void CallGetCipherSuiteCallback(int32_t err, const std::vector<std::string> &suite,
614                                    GetCipherSuiteCallback callback);
615    void CallGetSignatureAlgorithmsCallback(int32_t err, const std::vector<std::string> &algorithms,
616                                            GetSignatureAlgorithmsCallback callback);
617
618    int ReadMessage();
619    void StartReadMessage();
620
621    void GetIp4RemoteAddress(const GetRemoteAddressCallback &callback);
622    void GetIp6RemoteAddress(const GetRemoteAddressCallback &callback);
623
624    [[nodiscard]] bool SetBaseOptions(const Socket::ExtraOptionsBase &option) const;
625    [[nodiscard]] bool SetExtraOptions(const Socket::TCPExtraOptions &option) const;
626
627    void MakeIpSocket(sa_family_t family);
628
629    template<class T>
630    void DealCallback(int32_t err, T &callback)
631    {
632        T func = nullptr;
633        {
634            std::lock_guard<std::mutex> lock(mutex_);
635            if (callback) {
636                func = callback;
637            }
638        }
639
640        if (func) {
641            func(err);
642        }
643    }
644
645private:
646    static constexpr const size_t MAX_ERROR_LEN = 128;
647    static constexpr const size_t MAX_BUFFER_SIZE = 8192;
648
649    OnMessageCallback onMessageCallback_;
650    OnConnectCallback onConnectCallback_;
651    OnCloseCallback onCloseCallback_;
652    OnErrorCallback onErrorCallback_;
653
654    std::mutex mutex_;
655    std::mutex recvMutex_;
656    std::mutex cvMutex_;
657    bool isRunning_ = false;
658    bool isRunOver_ = true;
659    std::condition_variable cvSslFree_;
660    int sockFd_ = -1;
661    bool isExtSock_ = false;
662    Socket::NetAddress localAddress_;
663    bool isClosed = false;
664    std::mutex mutexForClose_;
665};
666} // namespace TlsSocket
667} // namespace NetStack
668} // namespace OHOS
669
670#endif // COMMUNICATIONNETSTACK_TLS_SOCEKT_H
671