1/*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include <fstream>
17#include <gtest/gtest.h>
18#include <iostream>
19#include <openssl/rsa.h>
20#include <openssl/ssl.h>
21#include <sstream>
22#include <string>
23#include <string_view>
24#include <unistd.h>
25#include <vector>
26
27#include "net_address.h"
28#include "secure_data.h"
29#include "socket_error.h"
30#include "socket_state_base.h"
31#include "tls.h"
32#include "tls_certificate.h"
33#include "tls_configuration.h"
34#include "tls_key.h"
35#include "tls_socket_server.h"
36#include "tls_socket.h"
37
38namespace OHOS {
39namespace NetStack {
40namespace TlsSocketServer {
41namespace {
// Path probed by CheckCaFileExistence(); every test in this file returns early when it is not accessible.
const std::string_view CA_DER = "/data/ClientCert/ca.crt";
// File whose content is read as the address string in getRemoteAddressInterface.
const std::string_view IP_ADDRESS = "/data/Ip/address.txt";
// File whose content is parsed with std::atoi as the port number in getRemoteAddressInterface.
const std::string_view PORT = "/data/Ip/port.txt";
45
46inline bool CheckCaFileExistence(const char *function)
47{
48    if (access(CA_DER.data(), 0)) {
49        std::cout << "CA file does not exist! (" << function << ")";
50        return false;
51    }
52    return true;
53}
54
/**
 * Reads the complete content of a file into a string.
 *
 * @param fileName Path of the file to read.
 * @return The file content, or an empty string when the file cannot be opened.
 */
std::string ChangeToFile(std::string_view fileName)
{
    // std::ifstream offers no std::string_view overload; passing the view directly only
    // compiles on standard libraries that accept an implicit conversion to filesystem::path.
    // Building an owning std::string is portable and guarantees null termination.
    std::ifstream file{std::string(fileName)};
    if (!file.is_open()) {
        return {};
    }

    std::stringstream content;
    content << file.rdbuf();
    // The stream is closed by its destructor when the function returns.
    return content.str();
}
65
66
/**
 * Strips trailing whitespace (spaces, tabs, CR, LF) from an address string.
 *
 * The previous implementation always removed exactly one character, which cut
 * off the last digit when the text had no trailing newline and left a stray
 * '\r' behind for CRLF line endings.
 *
 * @param ip Raw text, e.g. the content of the address file.
 * @return The text without trailing whitespace; empty when the input is empty
 *         or consists of whitespace only.
 */
std::string GetIp(std::string ip)
{
    const auto lastKept = ip.find_last_not_of(" \t\r\n");
    // npos means there is nothing but whitespace: erase everything.
    ip.erase(lastKept == std::string::npos ? 0 : lastKept + 1);
    return ip;
}
71
72} // namespace
// Test fixture used by every HWTEST_F case in this file.
// All hooks are intentionally empty: each test constructs its own TLSSocketServer.
class TlsSocketServerTest : public testing::Test {
public:
    // Suite-level set-up hook; nothing is shared between test cases.
    static void SetUpTestCase() {}

    // Suite-level tear-down hook; nothing to release.
    static void TearDownTestCase() {}

    // Per-test set-up hook; no per-test state is prepared.
    virtual void SetUp() {}

    // Per-test tear-down hook; no per-test state is released.
    virtual void TearDown() {}
};
83
84HWTEST_F(TlsSocketServerTest, ListenInterface, testing::ext::TestSize.Level2)
85{
86    if (!CheckCaFileExistence("ListenInterface")) {
87        return;
88    }
89    TLSSocketServer server;
90    TlsSocket::TLSConnectOptions tlsListenOptions;
91
92    server.Listen(tlsListenOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
93}
94
95HWTEST_F(TlsSocketServerTest, sendInterface, testing::ext::TestSize.Level2)
96{
97    if (!CheckCaFileExistence("sendInterface")) {
98        return;
99    }
100
101    TLSSocketServer server;
102
103    TLSServerSendOptions tlsServerSendOptions;
104
105    const std::string data = "how do you do? this is sendInterface";
106    tlsServerSendOptions.SetSendData(data);
107    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
108}
109
110HWTEST_F(TlsSocketServerTest, closeInterface, testing::ext::TestSize.Level2)
111{
112    if (!CheckCaFileExistence("closeInterface")) {
113        return;
114    }
115
116    TLSSocketServer server;
117
118    const std::string data = "how do you do? this is closeInterface";
119    TLSServerSendOptions tlsServerSendOptions;
120    tlsServerSendOptions.SetSendData(data);
121    int socketFd =  tlsServerSendOptions.GetSocket();
122
123    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
124    sleep(2);
125
126    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
127}
128
129HWTEST_F(TlsSocketServerTest, stopInterface, testing::ext::TestSize.Level2)
130{
131    if (!CheckCaFileExistence("stopInterface")) {
132        return;
133    }
134
135    TLSSocketServer server;
136
137    TLSServerSendOptions tlsServerSendOptions;
138    int socketFd =  tlsServerSendOptions.GetSocket();
139
140
141    const std::string data = "how do you do? this is stopInterface";
142    tlsServerSendOptions.SetSendData(data);
143    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
144    sleep(2);
145
146
147    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
148    sleep(2);
149
150
151    server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
152}
153
154HWTEST_F(TlsSocketServerTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
155{
156    if (!CheckCaFileExistence("getRemoteAddressInterface")) {
157        return;
158    }
159
160    TLSSocketServer server;
161
162    TLSServerSendOptions tlsServerSendOptions;
163    int socketFd = tlsServerSendOptions.GetSocket();
164    Socket::NetAddress address;
165
166    address.SetAddress(GetIp(ChangeToFile(IP_ADDRESS)));
167    address.SetPort(std::atoi(ChangeToFile(PORT).c_str()));
168    address.SetFamilyBySaFamily(AF_INET);
169
170    Socket::NetAddress netAddress;
171    server.GetRemoteAddress(socketFd, [&netAddress](int32_t errCode,
172        const Socket::NetAddress &address) {
173    EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS);
174    netAddress.SetAddress(address.GetAddress());
175    netAddress.SetPort(address.GetPort());
176    netAddress.SetFamilyBySaFamily(address.GetSaFamily());
177    });
178
179    const std::string data = "how do you do? this is getRemoteAddressInterface";
180    tlsServerSendOptions.SetSendData(data);
181    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
182    sleep(2);
183
184    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
185    sleep(2);
186
187    server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
188}
189
190HWTEST_F(TlsSocketServerTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
191{
192    if (!CheckCaFileExistence("getRemoteCertificateInterface")) {
193        return;
194    }
195
196    TLSSocketServer server;
197
198    TLSServerSendOptions tlsServerSendOptions;
199    int socketFd = tlsServerSendOptions.GetSocket();
200
201
202    const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
203    tlsServerSendOptions.SetSendData(data);
204    server.Send(tlsServerSendOptions, [](int32_t errCode) {
205        EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
206    sleep(2);
207
208    server.GetRemoteCertificate(socketFd, [](int32_t errCode, const TlsSocket::X509CertRawData &cert) {
209        EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
210
211    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
212    sleep(2);
213
214    server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); });
215}
216
217HWTEST_F(TlsSocketServerTest, getCertificateInterface, testing::ext::TestSize.Level2)
218{
219    if (!CheckCaFileExistence("getCertificateInterface")) {
220        return;
221    }
222    TLSSocketServer server;
223
224    const std::string data = "how do you do? This is UT test getCertificateInterface";
225    TLSServerSendOptions tlsServerSendOptions;
226    tlsServerSendOptions.SetSendData(data);
227    int socketFd = tlsServerSendOptions.GetSocket();
228    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
229
230    server.GetCertificate(
231        [](int32_t errCode, const TlsSocket::X509CertRawData &cert) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
232
233    sleep(2);
234    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
235}
236
237HWTEST_F(TlsSocketServerTest, protocolInterface, testing::ext::TestSize.Level2)
238{
239    if (!CheckCaFileExistence("protocolInterface")) {
240        return;
241    }
242    TLSSocketServer server;
243
244    const std::string data = "how do you do? this is protocolInterface";
245    TLSServerSendOptions tlsServerSendOptions;
246    tlsServerSendOptions.SetSendData(data);
247
248    int socketFd = tlsServerSendOptions.GetSocket();
249    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
250    std::string getProtocolVal;
251    server.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
252        EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS);
253        getProtocolVal = protocol;
254    });
255    EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3");
256
257    Socket::SocketStateBase stateBase;
258    server.GetState([&stateBase](int32_t errCode, Socket::SocketStateBase state) {
259        if (TlsSocket::TLSSOCKET_SUCCESS) {
260            EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS);
261            stateBase.SetIsBound(state.IsBound());
262            stateBase.SetIsClose(state.IsClose());
263            stateBase.SetIsConnected(state.IsConnected());
264        }
265    });
266    EXPECT_TRUE(stateBase.IsConnected());
267    sleep(2);
268
269    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
270}
271
272HWTEST_F(TlsSocketServerTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2)
273{
274    if (!CheckCaFileExistence("getSignatureAlgorithmsInterface")) {
275        return;
276    }
277
278    TLSSocketServer server;
279    TlsSocket::TLSSecureOptions secureOption;
280
281    const std::string data = "how do you do? this is getSigntureAlgorithmsInterface";
282    TLSServerSendOptions tlsServerSendOptions;
283    tlsServerSendOptions.SetSendData(data);
284
285    int socketFd = tlsServerSendOptions.GetSocket();
286    server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
287    sleep(2);
288
289    bool testFlag = false;
290    std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"};
291    secureOption.SetSignatureAlgorithms(signatureAlgorithmVec);
292    std::vector<std::string> testSignatureAlgorithms;
293    server.GetSignatureAlgorithms(socketFd, [&testSignatureAlgorithms](int32_t errCode,
294        const std::vector<std::string> &algorithms) {
295        if (errCode == TlsSocket::TLSSOCKET_SUCCESS) {
296            testSignatureAlgorithms = algorithms;
297        }
298    });
299    for (auto const &iter : testSignatureAlgorithms) {
300        if (iter == "ECDSA+SHA256") {
301            testFlag = true;
302        }
303    }
304    EXPECT_TRUE(testFlag);
305    sleep(2);
306
307
308    (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); });
309}
310
311
312} //TlsSocketServer
313} //NetStack
314} //OHOS
315