1/*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <algorithm>

#include "net_address.h"
#include "secure_data.h"
#include "socket_error.h"
#include "socket_state_base.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_key.h"
#include "tls_socket.h"
#include "tls_utils_test.h"
26
27namespace OHOS {
28namespace NetStack {
29namespace TlsSocket {
30void MockCertChainNetAddress(Socket::NetAddress &address)
31{
32    address.SetAddress(TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)));
33    address.SetPort(std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
34    address.SetFamilyBySaFamily(AF_INET);
35}
36
37void MockCertChainParamOptions(Socket::NetAddress &address, TLSSecureOptions &secureOption, TLSConnectOptions &options)
38{
39    secureOption.SetKey(SecureData(TlsUtilsTest::ChangeToFile(PRIVATE_KEY_PEM_CHAIN)));
40    secureOption.SetCert(TlsUtilsTest::ChangeToFile(CLIENT_CRT_CHAIN));
41
42    MockCertChainNetAddress(address);
43    options.SetNetAddress(address);
44    options.SetTlsSecureOptions(secureOption);
45}
46
47void SetCertChainHwTestShortParam(TLSSocket &server)
48{
49    TLSConnectOptions options;
50    TLSSecureOptions secureOption;
51    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
52        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
53    secureOption.SetCaChain(caVec);
54    Socket::NetAddress address;
55    MockCertChainParamOptions(address, secureOption, options);
56
57    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
58    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
59}
60
61void SetCertChainHwTestLongParam(TLSSocket &server)
62{
63    Socket::NetAddress address;
64    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
65        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
66    TLSSecureOptions secureOption;
67    secureOption.SetCaChain(caVec);
68    std::string protocolV13 = "TLSv1.3";
69    std::vector<std::string> protocolVec = { protocolV13 };
70    secureOption.SetProtocolChain(protocolVec);
71    secureOption.SetCipherSuite("AES256-SHA256");
72
73    TLSConnectOptions options;
74    MockCertChainParamOptions(address, secureOption, options);
75
76    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
77    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
78}
79
80HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
81{
82    if (!TlsUtilsTest::CheckCaPathChainExistence("bindInterface")) {
83        return;
84    }
85
86    TLSSocket testServer;
87    Socket::NetAddress address;
88    MockCertChainNetAddress(address);
89    testServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
90}
91
92HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
93{
94    if (!TlsUtilsTest::CheckCaPathChainExistence("connectInterface")) {
95        return;
96    }
97    TLSSocket certChainServer;
98    SetCertChainHwTestShortParam(certChainServer);
99
100    const std::string data = "how do you do? this is connectInterface";
101    Socket::TCPSendOptions tcpSendOptions;
102    tcpSendOptions.SetData(data);
103    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
104    sleep(2);
105
106    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
107    sleep(2);
108}
109
110HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
111{
112    if (!TlsUtilsTest::CheckCaPathChainExistence("closeInterface")) {
113        return;
114    }
115    TLSSocket certChainServer;
116    SetCertChainHwTestShortParam(certChainServer);
117
118    const std::string data = "how do you do? this is closeInterface";
119    Socket::TCPSendOptions tcpSendOptions;
120    tcpSendOptions.SetData(data);
121
122    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
123    sleep(2);
124
125    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
126}
127
128HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
129{
130    if (!TlsUtilsTest::CheckCaPathChainExistence("sendInterface")) {
131        return;
132    }
133    TLSSocket certChainServer;
134    SetCertChainHwTestShortParam(certChainServer);
135
136    const std::string data = "how do you do? this is sendInterface";
137    Socket::TCPSendOptions tcpSendOptions;
138    tcpSendOptions.SetData(data);
139
140    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
141    sleep(2);
142
143    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
144}
145
146HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
147{
148    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
149        return;
150    }
151    TLSSocket certChainServer;
152    TLSConnectOptions options;
153    TLSSecureOptions secureOption;
154    Socket::NetAddress address;
155    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
156        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
157    secureOption.SetCaChain(caVec);
158    MockCertChainParamOptions(address, secureOption, options);
159
160    certChainServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
161    certChainServer.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
162
163    Socket::NetAddress netAddress;
164    certChainServer.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
165        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
166        netAddress.SetPort(address.GetPort());
167        netAddress.SetFamilyBySaFamily(address.GetSaFamily());
168        netAddress.SetAddress(address.GetAddress());
169    });
170    EXPECT_STREQ(netAddress.GetAddress().c_str(), TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)).c_str());
171    EXPECT_EQ(address.GetPort(), std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
172    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
173
174    const std::string data = "how do you do? this is getRemoteAddressInterface";
175    Socket::TCPSendOptions tcpSendOptions;
176    tcpSendOptions.SetData(data);
177
178    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
179
180    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
181}
182
183HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
184{
185    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
186        return;
187    }
188    TLSSocket certChainServer;
189    SetCertChainHwTestShortParam(certChainServer);
190
191    Socket::SocketStateBase TlsSocketstate;
192    certChainServer.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
193        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
194        TlsSocketstate = state;
195    });
196    std::cout << "TlsSocketCertChainTest TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
197    EXPECT_TRUE(TlsSocketstate.IsBound());
198    EXPECT_TRUE(!TlsSocketstate.IsClose());
199    EXPECT_TRUE(TlsSocketstate.IsConnected());
200
201    const std::string tlsSocketCertChainTestData = "how do you do? this is getStateInterface";
202    Socket::TCPSendOptions tcpSendOptions;
203    tcpSendOptions.SetData(tlsSocketCertChainTestData);
204    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
205
206    sleep(2);
207
208    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
209}
210
211HWTEST_F(TlsSocketTest, getCertificateInterface, testing::ext::TestSize.Level2)
212{
213    if (!TlsUtilsTest::CheckCaPathChainExistence("getCertificateInterface")) {
214        return;
215    }
216    TLSSocket certChainServer;
217    SetCertChainHwTestShortParam(certChainServer);
218    Socket::TCPSendOptions tcpSendOptions;
219    const std::string data = "how do you do? This is UT test getCertificateInterface";
220
221    tcpSendOptions.SetData(data);
222    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
223
224    certChainServer.GetCertificate(
225        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
226
227    sleep(2);
228    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
229}
230
231HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
232{
233    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteCertificateInterface")) {
234        return;
235    }
236    TLSSocket certChainServer;
237    SetCertChainHwTestShortParam(certChainServer);
238    Socket::TCPSendOptions tcpSendOptions;
239    const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
240    tcpSendOptions.SetData(data);
241
242    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
243
244    certChainServer.GetRemoteCertificate(
245        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
246
247    sleep(2);
248    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
249}
250
251HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
252{
253    if (!TlsUtilsTest::CheckCaPathChainExistence("protocolInterface")) {
254        return;
255    }
256    TLSSocket certChainServer;
257    SetCertChainHwTestLongParam(certChainServer);
258
259    const std::string data = "how do you do? this is protocolInterface.";
260    Socket::TCPSendOptions tcpSendOptions;
261    tcpSendOptions.SetData(data);
262
263    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
264    std::string getProtocolVal = "";
265    certChainServer.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
266        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
267        getProtocolVal = protocol;
268    });
269    EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3");
270    sleep(2);
271
272    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
273}
274
275HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
276{
277    if (!TlsUtilsTest::CheckCaPathChainExistence("getCipherSuiteInterface")) {
278        return;
279    }
280    TLSSocket certChainServer;
281    SetCertChainHwTestLongParam(certChainServer);
282
283    bool successFlag = false;
284    const std::string data = "how do you do? This is getCipherSuiteInterface";
285    Socket::TCPSendOptions testTcpSendOptions;
286    testTcpSendOptions.SetData(data);
287    certChainServer.Send(testTcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
288
289    std::vector<std::string> testCipherSuite;
290    certChainServer.GetCipherSuite([&testCipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
291        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
292        testCipherSuite = suite;
293    });
294
295    for (auto const &iter : testCipherSuite) {
296        if (iter == "AES256-SHA256") {
297            successFlag = true;
298        }
299    }
300
301    EXPECT_TRUE(successFlag);
302    sleep(2);
303
304    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
305}
306
307HWTEST_F(TlsSocketTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2)
308{
309    if (!TlsUtilsTest::CheckCaPathChainExistence("getSignatureAlgorithmsInterface")) {
310        return;
311    }
312
313    TLSSocket certChainServer;
314    TLSSecureOptions secureOption;
315    std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"};
316    secureOption.SetSignatureAlgorithms(signatureAlgorithmVec);
317    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
318        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
319    secureOption.SetCaChain(caVec);
320    std::string protocolV13 = "TLSv1.3";
321    std::vector<std::string> protocolVec = {protocolV13};
322    secureOption.SetProtocolChain(protocolVec);
323    Socket::NetAddress address;
324    TLSConnectOptions options;
325    MockCertChainParamOptions(address, secureOption, options);
326
327    bool successFlag = false;
328    certChainServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
329    certChainServer.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
330
331    const std::string data = "how do you do? this is getSignatureAlgorithmsInterface";
332    Socket::TCPSendOptions testOptions;
333    testOptions.SetData(data);
334    certChainServer.Send(testOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
335
336    std::vector<std::string> testSignatureAlgorithms;
337    certChainServer.GetSignatureAlgorithms(
338        [&testSignatureAlgorithms](int32_t errCode, const std::vector<std::string> &algorithms) {
339            if (errCode == TLSSOCKET_SUCCESS) {
340                testSignatureAlgorithms = algorithms;
341            }
342        });
343
344    for (auto const &iter : testSignatureAlgorithms) {
345        if (iter == "ECDSA+SHA256") {
346            successFlag = true;
347        }
348    }
349    EXPECT_TRUE(successFlag);
350    sleep(2);
351    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
352}
353
354HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
355{
356    if (!TlsUtilsTest::CheckCaPathChainExistence("tlsSocketOnMessageData")) {
357        return;
358    }
359    std::string getData = "server->client";
360    TLSSocket certChainServer;
361    SetCertChainHwTestLongParam(certChainServer);
362    certChainServer.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
363        if (data == getData) {
364            EXPECT_TRUE(true);
365        } else {
366            EXPECT_TRUE(false);
367        }
368    });
369
370    const std::string data = "how do you do? this is tlsSocketOnMessageData";
371    Socket::TCPSendOptions tcpSendOptions;
372    tcpSendOptions.SetData(data);
373    certChainServer.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
374
375    sleep(2);
376    (void)certChainServer.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
377}
378} // namespace TlsSocket
379} // namespace NetStack
380} // namespace OHOS
381