1/*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <algorithm>

#include "net_address.h"
#include "secure_data.h"
#include "socket_error.h"
#include "socket_state_base.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_key.h"
#include "tls_socket.h"
#include "tls_utils_test.h"
26
27namespace OHOS {
28namespace NetStack {
29namespace TlsSocket {
30void MockNetAddress(Socket::NetAddress &address)
31{
32    address.SetAddress(TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)));
33    address.SetPort(std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
34    address.SetFamilyBySaFamily(AF_INET);
35}
36
37void MockTlsSocketParamOptions(Socket::NetAddress &address, TLSSecureOptions &secureOption, TLSConnectOptions &options)
38{
39    secureOption.SetKey(SecureData(TlsUtilsTest::ChangeToFile(PRIVATE_KEY_PEM)));
40    secureOption.SetCert(TlsUtilsTest::ChangeToFile(CLIENT_CRT));
41
42    MockNetAddress(address);
43    options.SetTlsSecureOptions(secureOption);
44    options.SetNetAddress(address);
45}
46
47void SetSocketHwTestShortParam(TLSSocket &server)
48{
49    TLSConnectOptions options;
50    Socket::NetAddress address;
51    TLSSecureOptions secureOption;
52    std::vector<std::string> caVec1 = {TlsUtilsTest::ChangeToFile(CA_DER)};
53    secureOption.SetCaChain(caVec1);
54    MockTlsSocketParamOptions(address, secureOption, options);
55
56    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
57    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
58}
59
60void SetSocketHwTestLongParam(TLSSocket &server)
61{
62    TLSConnectOptions options;
63    TLSSecureOptions secureOption;
64    secureOption.SetCipherSuite("AES256-SHA256");
65    std::string protocolV13 = "TLSv1.3";
66    std::vector<std::string> protocolVec = {protocolV13};
67    secureOption.SetProtocolChain(protocolVec);
68    std::vector<std::string> caVect = {TlsUtilsTest::ChangeToFile(CA_DER)};
69    secureOption.SetCaChain(caVect);
70    Socket::NetAddress address;
71    MockTlsSocketParamOptions(address, secureOption, options);
72
73    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
74    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
75}
76
77HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
78{
79    if (!TlsUtilsTest::CheckCaFileExistence("bindInterface")) {
80        return;
81    }
82
83    Socket::NetAddress address;
84    TLSSocket bindTestServer;
85    MockNetAddress(address);
86    bindTestServer.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
87}
88
89HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
90{
91    if (!TlsUtilsTest::CheckCaFileExistence("connectInterface")) {
92        return;
93    }
94    TLSSocket testService;
95    SetSocketHwTestShortParam(testService);
96
97    const std::string data = "how do you do? this is connectInterface";
98    Socket::TCPSendOptions tcpSendOptions;
99    tcpSendOptions.SetData(data);
100    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
101    sleep(2);
102
103    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
104    sleep(2);
105}
106
107HWTEST_F(TlsSocketTest, startReadMessageInterface, testing::ext::TestSize.Level2)
108{
109    if (!TlsUtilsTest::CheckCaFileExistence("startReadMessageInterface")) {
110        return;
111    }
112    TLSSocket testService;
113    SetSocketHwTestShortParam(testService);
114
115    const std::string data = "how do you do? this is startReadMessageInterface";
116    Socket::TCPSendOptions tcpSendOptions;
117    tcpSendOptions.SetData(data);
118    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
119    sleep(2);
120
121    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
122}
123
124HWTEST_F(TlsSocketTest, readMessageInterface, testing::ext::TestSize.Level2)
125{
126    if (!TlsUtilsTest::CheckCaFileExistence("readMessageInterface")) {
127        return;
128    }
129    TLSSocket testService;
130    SetSocketHwTestShortParam(testService);
131
132    const std::string data = "how do you do? this is readMessageInterface";
133    Socket::TCPSendOptions tcpSendOptions;
134    tcpSendOptions.SetData(data);
135    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
136    sleep(2);
137
138    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
139}
140
141HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
142{
143    if (!TlsUtilsTest::CheckCaFileExistence("closeInterface")) {
144        return;
145    }
146
147    TLSSocket testService;
148    SetSocketHwTestShortParam(testService);
149
150    const std::string data = "how do you do? this is closeInterface";
151    Socket::TCPSendOptions tcpSendOptions;
152    tcpSendOptions.SetData(data);
153
154    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
155    sleep(2);
156
157    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
158}
159
160HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
161{
162    if (!TlsUtilsTest::CheckCaFileExistence("sendInterface")) {
163        return;
164    }
165    TLSSocket testService;
166    SetSocketHwTestShortParam(testService);
167
168    const std::string data = "how do you do? this is sendInterface";
169    Socket::TCPSendOptions tcpSendOptions;
170    tcpSendOptions.SetData(data);
171
172    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
173    sleep(2);
174
175    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
176}
177
178HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
179{
180    if (!TlsUtilsTest::CheckCaFileExistence("getRemoteAddressInterface")) {
181        return;
182    }
183
184    TLSConnectOptions options;
185    TLSSocket testService;
186    TLSSecureOptions secureOption;
187    Socket::NetAddress address;
188    std::vector<std::string> caVec = {TlsUtilsTest::ChangeToFile(CA_DER)};
189    secureOption.SetCaChain(caVec);
190    MockTlsSocketParamOptions(address, secureOption, options);
191
192    testService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
193    testService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
194
195    Socket::NetAddress netAddress;
196    testService.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
197        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
198        netAddress.SetAddress(address.GetAddress());
199        netAddress.SetFamilyBySaFamily(address.GetSaFamily());
200        netAddress.SetPort(address.GetPort());
201    });
202    EXPECT_STREQ(netAddress.GetAddress().c_str(), TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)).c_str());
203    EXPECT_EQ(address.GetPort(), std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
204    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
205
206    const std::string data = "how do you do? this is getRemoteAddressInterface";
207    Socket::TCPSendOptions tcpSendOptions;
208    tcpSendOptions.SetData(data);
209
210    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
211
212    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
213}
214
215HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
216{
217    if (!TlsUtilsTest::CheckCaFileExistence("getRemoteAddressInterface")) {
218        return;
219    }
220
221    TLSSocket testService;
222    SetSocketHwTestShortParam(testService);
223
224    Socket::SocketStateBase TlsSocketstate;
225    testService.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
226        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
227        TlsSocketstate = state;
228    });
229    std::cout << "TlsSocketTest TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
230    EXPECT_TRUE(TlsSocketstate.IsBound());
231    EXPECT_TRUE(!TlsSocketstate.IsClose());
232    EXPECT_TRUE(TlsSocketstate.IsConnected());
233
234    const std::string tlsSocketTestData = "how do you do? this is getStateInterface";
235    Socket::TCPSendOptions tcpSendOptions;
236    tcpSendOptions.SetData(tlsSocketTestData);
237    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
238
239    sleep(2);
240
241    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
242}
243
244HWTEST_F(TlsSocketTest, getCertificateInterface, testing::ext::TestSize.Level2)
245{
246    if (!TlsUtilsTest::CheckCaFileExistence("getCertificateInterface")) {
247        return;
248    }
249    TLSSocket testService;
250    SetSocketHwTestShortParam(testService);
251
252    const std::string data = "how do you do? This is UT test getCertificateInterface";
253    Socket::TCPSendOptions tcpSendOptions;
254    tcpSendOptions.SetData(data);
255    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
256
257    testService.GetCertificate(
258        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
259
260    sleep(2);
261    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
262}
263
264HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
265{
266    if (!TlsUtilsTest::CheckCaFileExistence("getRemoteCertificateInterface")) {
267        return;
268    }
269    TLSSocket testService;
270    SetSocketHwTestShortParam(testService);
271
272    Socket::TCPSendOptions tcpSendOptions;
273    const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
274    tcpSendOptions.SetData(data);
275
276    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
277
278    testService.GetRemoteCertificate(
279        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
280
281    sleep(2);
282    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
283}
284
285HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
286{
287    if (!TlsUtilsTest::CheckCaFileExistence("protocolInterface")) {
288        return;
289    }
290    TLSSocket testService;
291    SetSocketHwTestLongParam(testService);
292
293    const std::string data = "how do you do? this is protocolInterface";
294    Socket::TCPSendOptions tcpSendOptions;
295    tcpSendOptions.SetData(data);
296
297    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
298    std::string protocolVal;
299    testService.GetProtocol([&protocolVal](int32_t errCode, const std::string &protocol) {
300        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
301        protocolVal = protocol;
302    });
303    EXPECT_STREQ(protocolVal.c_str(), "TLSv1.3");
304
305    Socket::SocketStateBase socketStateBase;
306    testService.GetState([&socketStateBase](int32_t errCode, Socket::SocketStateBase state) {
307        if (errCode == TLSSOCKET_SUCCESS) {
308            EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
309            socketStateBase.SetIsBound(state.IsBound());
310            socketStateBase.SetIsClose(state.IsClose());
311            socketStateBase.SetIsConnected(state.IsConnected());
312        }
313    });
314    EXPECT_TRUE(socketStateBase.IsConnected());
315    sleep(2);
316
317    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
318}
319
320HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
321{
322    if (!TlsUtilsTest::CheckCaFileExistence("getCipherSuiteInterface")) {
323        return;
324    }
325    TLSSocket testService;
326    SetSocketHwTestLongParam(testService);
327
328    bool flag = false;
329    const std::string data = "how do you do? This is getCipherSuiteInterface";
330    Socket::TCPSendOptions tcpSendOptions;
331    tcpSendOptions.SetData(data);
332    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
333
334    std::vector<std::string> cipherSuite;
335    testService.GetCipherSuite([&cipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
336        if (errCode == TLSSOCKET_SUCCESS) {
337            cipherSuite = suite;
338        }
339    });
340
341    for (auto const &iter : cipherSuite) {
342        if (iter == "AES256-SHA256") {
343            flag = true;
344        }
345    }
346
347    EXPECT_TRUE(flag);
348    sleep(2);
349
350    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
351}
352
353HWTEST_F(TlsSocketTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2)
354{
355    if (!TlsUtilsTest::CheckCaFileExistence("getSignatureAlgorithmsInterface")) {
356        return;
357    }
358    TLSSocket testService;
359    TLSConnectOptions options;
360    TLSSecureOptions secureOption;
361    Socket::NetAddress address;
362    std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"};
363    secureOption.SetSignatureAlgorithms(signatureAlgorithmVec);
364    std::string protocolV13 = "TLSv1.3";
365    std::vector<std::string> protocolVec = {protocolV13};
366    secureOption.SetProtocolChain(protocolVec);
367    std::vector<std::string> caVec = {TlsUtilsTest::ChangeToFile(CA_DER)};
368    secureOption.SetCaChain(caVec);
369    MockTlsSocketParamOptions(address, secureOption, options);
370
371    testService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
372    testService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
373
374    bool flag = false;
375    const std::string data = "how do you do? this is getSignatureAlgorithmsInterface";
376    Socket::TCPSendOptions tcpSendOptions;
377    tcpSendOptions.SetData(data);
378    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
379
380    std::vector<std::string> signatureAlgorithms;
381    testService.GetSignatureAlgorithms(
382        [&signatureAlgorithms](int32_t errCode, const std::vector<std::string> &algorithms) {
383            if (errCode == TLSSOCKET_SUCCESS) {
384                signatureAlgorithms = algorithms;
385            }
386        });
387
388    for (auto const &iter : signatureAlgorithms) {
389        if (iter == "ECDSA+SHA256") {
390            flag = true;
391        }
392    }
393    EXPECT_TRUE(flag);
394    sleep(2);
395    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
396}
397
398HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
399{
400    if (!TlsUtilsTest::CheckCaFileExistence("tlsSocketOnMessageData")) {
401        return;
402    }
403    std::string getData = "server->client";
404    TLSSocket testService;
405    SetSocketHwTestLongParam(testService);
406
407    testService.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
408        if (data == getData) {
409            EXPECT_TRUE(true);
410        } else {
411            EXPECT_TRUE(false);
412        }
413    });
414
415    const std::string data = "how do you do? this is tlsSocketOnMessageData";
416    Socket::TCPSendOptions tcpSendOptions;
417    tcpSendOptions.SetData(data);
418    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
419
420    sleep(2);
421    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
422}
423
424HWTEST_F(TlsSocketTest, upgradeInterface, testing::ext::TestSize.Level2)
425{
426    if (!TlsUtilsTest::CheckCaFileExistence("upgradeInterface")) {
427        return;
428    }
429
430    int sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
431    EXPECT_TRUE(sock > 0);
432
433    sockaddr_in addr4 = {0};
434    Socket::NetAddress address;
435    MockNetAddress(address);
436    addr4.sin_family = AF_INET;
437    addr4.sin_port = htons(address.GetPort());
438    addr4.sin_addr.s_addr = inet_addr(address.GetAddress().c_str());
439
440    int ret = connect(sock, reinterpret_cast<sockaddr *>(&addr4), sizeof(sockaddr_in));
441    EXPECT_TRUE(ret >= 0);
442
443    TLSSocket testService(sock);
444    SetSocketHwTestShortParam(testService);
445
446    const std::string data = "how do you do? this is upgradeInterface";
447    Socket::TCPSendOptions tcpSendOptions;
448    tcpSendOptions.SetData(data);
449    testService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
450    sleep(2);
451
452    (void)testService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
453    sleep(2);
454}
455} // namespace TlsSocket
456} // namespace NetStack
457} // namespace OHOS
458