1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
17 #define COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
18 
#include "event_manager.h"
#include "extra_options_base.h"
#include "net_address.h"
#include "socket_error.h"
#include "socket_remote_info.h"
#include "socket_state_base.h"
#include "tcp_connect_options.h"
#include "tcp_extra_options.h"
#include "tcp_send_options.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_context_server.h"
#include "tls_key.h"
#include "tls_socket.h"
#include <any>
#include <condition_variable>
#include <cstring>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <poll.h>
#include <string>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <vector>
44 
45 namespace OHOS {
46 namespace NetStack {
47 namespace TlsSocketServer {
/**
 * Upper bound on concurrently polled client connections. The server's pollfd array is sized
 * USER_LIMIT + 1 (presumably one extra slot for the listening socket — TODO confirm).
 * Declared inline (C++17) so every translation unit including this header shares one entity.
 */
inline constexpr int USER_LIMIT = 10;
49 using OnMessageCallback =
50     std::function<void(const int &socketFd, const std::string &data, const Socket::SocketRemoteInfo &remoteInfo)>;
51 using OnCloseCallback = std::function<void(const int &socketFd)>;
52 using OnConnectCallback = std::function<void(const int &socketFd, std::shared_ptr<EventManager> eventManager)>;
53 using ListenCallback = std::function<void(int32_t errorNumber)>;
/**
 * Parameters for TLSSocketServer::Send: the target client socket and the payload to transmit.
 */
class TLSServerSendOptions {
public:
    /**
     * Set the socket of the client the data is sent to
     * @param socketFd Communication descriptor of the target client
     */
    void SetSocket(const int &socketFd);

    /**
     * Set the data to send
     * @param data Payload to transmit
     */
    void SetSendData(const std::string &data);

    /**
     * Get the socket of the target client
     * @return The communication descriptor (-1 if SetSocket() was never called)
     */
    [[nodiscard]] const int &GetSocket() const;

    /**
     * Get the data to be sent
     * @return Payload to transmit
     */
    [[nodiscard]] const std::string &GetSendData() const;

private:
    // -1 is the conventional invalid descriptor; initializing it avoids reading an indeterminate
    // value when GetSocket() is called before SetSocket().
    int socketFd_ = -1;
    std::string data_;
};
84 
85 class TLSSocketServer {
86 public:
87     TLSSocketServer(const TLSSocketServer &) = delete;
88     TLSSocketServer(TLSSocketServer &&) = delete;
89 
90     TLSSocketServer &operator=(const TLSSocketServer &) = delete;
91     TLSSocketServer &operator=(TLSSocketServer &&) = delete;
92 
93     TLSSocketServer() = default;
94     ~TLSSocketServer();
95 
96     /**
97      * Create sockets, bind and listen waiting for clients to connect
98      * @param tlsListenOptions Bind the listening connection configuration
99      * @param callback callback to the caller if bind ok or not
100      */
101     void Listen(const TlsSocket::TLSConnectOptions &tlsListenOptions, const ListenCallback &callback);
102 
103     /**
104      * Send data through an established encrypted connection
105      * @param data data sent over an established encrypted connection
106      * @return whether the data is successfully sent to the server
107      */
108     bool Send(const TLSServerSendOptions &data, const TlsSocket::SendCallback &callback);
109 
110     /**
111      * Disconnect by releasing the socket when communicating
112      * @param socketFd The socket ID of the client
113      * @param callback callback to the caller
114      */
115     void Close(const int socketFd, const TlsSocket::CloseCallback &callback);
116 
117     /**
118      * Disconnect by releasing the socket when communicating
119      * @param callback callback to the caller
120      */
121     void Stop(const TlsSocket::CloseCallback &callback);
122 
123     /**
124      * Get the peer network address
125      * @param socketFd The socket ID of the client
126      * @param callback callback to the caller
127      */
128     void GetRemoteAddress(const int socketFd, const TlsSocket::GetRemoteAddressCallback &callback);
129 
130     /**
131      * Get the peer network address
132      * @param socketFd The socket ID of the client
133      * @param callback callback to the caller
134      */
135     void GetLocalAddress(const int socketFd, const TlsSocket::GetLocalAddressCallback &callback);
136 
137     /**
138      * Get the status of the current socket
139      * @param callback callback to the caller
140      */
141     void GetState(const TlsSocket::GetStateCallback &callback);
142 
143     /**
144      * Gets or sets the options associated with the current socket
145      * @param tcpExtraOptions options associated with the current socket
146      * @param callback callback to the caller
147      */
148     bool SetExtraOptions(const Socket::TCPExtraOptions &tcpExtraOptions,
149                          const TlsSocket::SetExtraOptionsCallback &callback);
150 
151     /**
152      *  Get a local digital certificate
153      * @param callback callback to the caller
154      */
155     void GetCertificate(const TlsSocket::GetCertificateCallback &callback);
156 
157     /**
158      * Get the peer digital certificate
159      * @param socketFd The socket ID of the client
160      * @param needChain need chain
161      * @param callback callback to the caller
162      */
163     void GetRemoteCertificate(const int socketFd, const TlsSocket::GetRemoteCertificateCallback &callback);
164 
165     /**
166      * Obtain the protocol used in communication
167      * @param callback callback to the caller
168      */
169     void GetProtocol(const TlsSocket::GetProtocolCallback &callback);
170 
171     /**
172      * Obtain the cipher suite used in communication
173      * @param socketFd The socket ID of the client
174      * @param callback callback to the caller
175      */
176     void GetCipherSuite(const int socketFd, const TlsSocket::GetCipherSuiteCallback &callback);
177 
178     /**
179      * Obtain the encryption algorithm used in the communication process
180      * @param socketFd The socket ID of the client
181      * @param callback callback to the caller
182      */
183     void GetSignatureAlgorithms(const int socketFd, const TlsSocket::GetSignatureAlgorithmsCallback &callback);
184 
185     /**
186      * Register the callback that is called when the connection is disconnected
187      * @param onCloseCallback callback invoked when disconnected
188      */
189 
190     /**
191      * Register the callback that is called when the connection is established
192      * @param onConnectCallback callback invoked when connection is established
193      */
194     void OnConnect(const OnConnectCallback &onConnectCallback);
195 
196     /**
197      * Register the callback that is called when an error occurs
198      * @param onErrorCallback callback invoked when an error occurs
199      */
200     void OnError(const TlsSocket::OnErrorCallback &onErrorCallback);
201 
202     /**
203      * Off Connect
204      */
205     void OffConnect();
206 
207     /**
208      * Off Error
209      */
210     void OffError();
211 
212     /**
213      * Get the socket file description of the server
214      */
215     int GetListenSocketFd();
216 
217     /**
218      * Set the current socket file description address of the server
219      */
220     void SetLocalAddress(const Socket::NetAddress &address);
221 
222     /**
223      * Get the current socket file description address of the server
224      */
225     Socket::NetAddress GetLocalAddress();
226 
227 public:
228     class Connection : public std::enable_shared_from_this<Connection> {
229     public:
230         ~Connection();
231         /**
232          * Establish an encrypted accept on the specified socket
233          * @param sock socket for establishing encrypted connection
234          * @param options some options required during tls accept
235          * @return whether the encrypted accept is successfully established
236          */
237         bool TlsAcceptToHost(int sock, const TlsSocket::TLSConnectOptions &options);
238 
239         /**
240          * Set the configuration items for establishing encrypted connections
241          * @param config configuration item when establishing encrypted connection
242          */
243         void SetTlsConfiguration(const TlsSocket::TLSConnectOptions &config);
244 
245         /**
246          * Set address information
247          */
248         void SetAddress(const Socket::NetAddress address);
249 
250         /**
251          * Set local address information
252          */
253         void SetLocalAddress(const Socket::NetAddress address);
254 
255         /**
256          * Send data through an established encrypted connection
257          * @param data data sent over an established encrypted connection
258          * @return whether the data is successfully sent to the server
259          */
260         bool Send(const std::string &data);
261 
262         /**
263          * Receive the data sent by the server through the established encrypted connection
264          * @param buffer receive the data sent by the server
265          * @param maxBufferSize the size of the data received from the server
266          * @return whether the data sent by the server is successfully received
267          */
268         int Recv(char *buffer, int maxBufferSize);
269 
270         /**
271          * Disconnect encrypted connection
272          * @return whether the encrypted connection was successfully disconnected
273          */
274         bool Close();
275 
276         /**
277          * Set the application layer negotiation protocol in the encrypted communication process
278          * @param alpnProtocols application layer negotiation protocol
279          * @return set whether the application layer negotiation protocol is successful during encrypted communication
280          */
281         bool SetAlpnProtocols(const std::vector<std::string> &alpnProtocols);
282 
283         /**
284          * Storage of server communication related network information
285          * @param remoteInfo communication related network information
286          */
287         void MakeRemoteInfo(Socket::SocketRemoteInfo &remoteInfo);
288 
289         /**
290          * Get configuration options for encrypted communication process
291          * @return configuration options for encrypted communication processes
292          */
293         [[nodiscard]] TlsSocket::TLSConfiguration GetTlsConfiguration() const;
294 
295         /**
296          * Obtain the cipher suite during encrypted communication
297          * @return crypto suite used in encrypted communication
298          */
299         [[nodiscard]] std::vector<std::string> GetCipherSuite() const;
300 
301         /**
302          * Obtain the peer certificate used in encrypted communication
303          * @return peer certificate used in encrypted communication
304          */
305         [[nodiscard]] std::string GetRemoteCertificate() const;
306 
307         /**
308          * Obtain the peer certificate used in encrypted communication
309          * @return peer certificate serialization data used in encrypted communication
310          */
311         [[nodiscard]] const TlsSocket::X509CertRawData &GetRemoteCertRawData() const;
312 
313         /**
314          * Obtain the certificate used in encrypted communication
315          * @return certificate serialization data used in encrypted communication
316          */
317         [[nodiscard]] const TlsSocket::X509CertRawData &GetCertificate() const;
318 
319         /**
320          * Get the encryption algorithm used in encrypted communication
321          * @return encryption algorithm used in encrypted communication
322          */
323         [[nodiscard]] std::vector<std::string> GetSignatureAlgorithms() const;
324 
325         /**
326          * Obtain the communication protocol used in encrypted communication
327          * @return communication protocol used in encrypted communication
328          */
329         [[nodiscard]] std::string GetProtocol() const;
330 
331         /**
332          * Set the information about the shared signature algorithm supported by peers during encrypted communication
333          * @return information about peer supported shared signature algorithms
334          */
335         [[nodiscard]] bool SetSharedSigals();
336 
337         /**
338          * Obtain the ssl used in encrypted communication
339          * @return SSL used in encrypted communication
340          */
341         [[nodiscard]] ssl_st *GetSSL() const;
342 
343         /**
344          * Get address information
345          * @return Returns the address information of the remote client
346          */
347         [[nodiscard]] Socket::NetAddress GetAddress() const;
348 
349         /**
350          * Get local address information
351          * @return Returns the address information of the local accept connect
352          */
353         [[nodiscard]] Socket::NetAddress GetLocalAddress() const;
354 
355         /**
356          * Get address information
357          * @return Returns the address information of the remote client
358          */
359         [[nodiscard]] int GetSocketFd() const;
360 
361         /**
362          * Get EventManager information
363          * @return Returns the address information of the remote client
364          */
365         [[nodiscard]] std::shared_ptr<EventManager> GetEventManager() const;
366 
367         void OnMessage(const OnMessageCallback &onMessageCallback);
368         /**
369          * Unregister the callback which is called when message is received
370          */
371         void OffMessage();
372 
373         void CallOnMessageCallback(int32_t socketFd, const std::string &data,
374                                    const Socket::SocketRemoteInfo &remoteInfo);
375 
376         void SetEventManager(std::shared_ptr<EventManager> eventManager);
377 
378         void SetClientID(int32_t clientID);
379 
380         [[nodiscard]] int GetClientID();
381 
382         void CallOnCloseCallback(const int32_t socketFd);
383         void OnClose(const OnCloseCallback &onCloseCallback);
384         OnCloseCallback onCloseCallback_;
385 
386         /**
387          * Off Close
388          */
389         void OffClose();
390 
391         /**
392          * Register the callback that is called when an error occurs
393          * @param onErrorCallback callback invoked when an error occurs
394          */
395         void OnError(const TlsSocket::OnErrorCallback &onErrorCallback);
396         /**
397          * Off Error
398          */
399         void OffError();
400 
401         void CallOnErrorCallback(int32_t err, const std::string &errString);
402 
403         TlsSocket::OnErrorCallback onErrorCallback_;
404 
405     private:
406         bool StartTlsAccept(const TlsSocket::TLSConnectOptions &options);
407         bool CreatTlsContext();
408         bool StartShakingHands(const TlsSocket::TLSConnectOptions &options);
409         bool GetRemoteCertificateFromPeer();
410         bool SetRemoteCertRawData();
411         std::string CheckServerIdentityLegal(const std::string &hostName, const X509 *x509Certificates);
412         std::string CheckServerIdentityLegal(const std::string &hostName, X509_EXTENSION *ext,
413                                              const X509 *x509Certificates);
414 
415     private:
416         ssl_st *ssl_ = nullptr;
417         X509 *peerX509_ = nullptr;
418         int32_t socketFd_ = 0;
419 
420         TlsSocket::TLSContextServer tlsContext_;
421         TlsSocket::TLSConfiguration connectionConfiguration_;
422         Socket::NetAddress address_;
423         Socket::NetAddress localAddress_;
424         TlsSocket::X509CertRawData remoteRawData_;
425 
426         std::string hostName_;
427         std::string remoteCert_;
428         std::string keyPass_;
429 
430         std::vector<std::string> signatureAlgorithms_;
431         std::unique_ptr<TlsSocket::TLSContextServer> tlsContextServerPointer_ = nullptr;
432 
433         std::shared_ptr<EventManager> eventManager_ = nullptr;
434         int32_t clientID_ = 0;
435         OnMessageCallback onMessageCallback_;
436     };
437 
438 private:
439     void SetLocalTlsConfiguration(const TlsSocket::TLSConnectOptions &config);
440     int RecvRemoteInfo(int socketFd, int index);
441     void RemoveConnect(int socketFd);
442     void AddConnect(int socketFd, std::shared_ptr<Connection> connection);
443     void CallListenCallback(int32_t err, ListenCallback callback);
444     void CallOnErrorCallback(int32_t err, const std::string &errString);
445 
446     void CallGetStateCallback(int32_t err, const Socket::SocketStateBase &state, TlsSocket::GetStateCallback callback);
447     void CallOnConnectCallback(const int32_t socketFd, std::shared_ptr<EventManager> eventManager);
448     void CallSendCallback(int32_t err, TlsSocket::SendCallback callback);
449     bool ExecBind(const Socket::NetAddress &address, const ListenCallback &callback);
450     void ExecAccept(const TlsSocket::TLSConnectOptions &tlsAcceptOptions, const ListenCallback &callback);
451     void MakeIpSocket(sa_family_t family);
452     void GetAddr(const Socket::NetAddress &address, sockaddr_in *addr4, sockaddr_in6 *addr6, sockaddr **addr,
453                  socklen_t *len);
454     static constexpr const size_t MAX_ERROR_LEN = 128;
455     static constexpr const size_t MAX_BUFFER_SIZE = 8192;
456 
457     void PollThread(const TlsSocket::TLSConnectOptions &tlsListenOptions);
458 
459 private:
460     std::mutex mutex_;
461     std::mutex connectMutex_;
462     int listenSocketFd_ = -1;
463     Socket::NetAddress address_;
464     Socket::NetAddress localAddress_;
465 
466     std::map<int, std::shared_ptr<Connection>> clientIdConnections_;
467     TlsSocket::TLSConfiguration TLSServerConfiguration_;
468 
469     OnConnectCallback onConnectCallback_;
470     TlsSocket::OnErrorCallback onErrorCallback_;
471 
472     bool GetTlsConnectionLocalAddress(int acceptSockFD, Socket::NetAddress &localAddress);
473     void ProcessTcpAccept(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientId);
474     void DropFdFromPollList(int &fd_index);
475     void InitPollList(int &listendFd);
476 
477     struct pollfd fds_[USER_LIMIT + 1];
478 
479     bool isRunning_;
480 
481 public:
482     std::shared_ptr<Connection> GetConnectionByClientID(int clientid);
483     int GetConnectionClientCount();
484 
485     std::shared_ptr<Connection> GetConnectionByClientEventManager(const EventManager *eventManager);
486     void CloseConnectionByEventManager(EventManager *eventManager);
487     void DeleteConnectionByEventManager(EventManager *eventManager);
488     void SetTlsConnectionSecureOptions(const TlsSocket::TLSConnectOptions &tlsListenOptions, int clientID,
489                                        int connectFD, std::shared_ptr<Connection> &connection);
490 };
491 } // namespace TlsSocketServer
492 } // namespace NetStack
493 } // namespace OHOS
494 
495 #endif // COMMUNICATIONNETSTACK_TLS_SERVER_SOCEKT_H
496