1/*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "net_address.h"
17#include "secure_data.h"
18#include "socket_error.h"
19#include "socket_state_base.h"
20#include "tls.h"
21#include "tls_certificate.h"
22#include "tls_configuration.h"
23#include "tls_key.h"
24#include "tls_socket.h"
25#include "tls_utils_test.h"
26
27namespace OHOS {
28namespace NetStack {
29namespace TlsSocket {
30void MockCertChainOneWayNetAddress(Socket::NetAddress &address)
31{
32    address.SetAddress(TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)));
33    address.SetPort(std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
34    address.SetFamilyBySaFamily(AF_INET);
35}
36
37void MockCertChainOneWayParamOptions(
38    Socket::NetAddress &address, TLSSecureOptions &secureOption, TLSConnectOptions &options)
39{
40    secureOption.SetKey(SecureData(TlsUtilsTest::ChangeToFile(PRIVATE_KEY_PEM_CHAIN)));
41    secureOption.SetCert(TlsUtilsTest::ChangeToFile(CLIENT_CRT_CHAIN));
42
43    MockCertChainOneWayNetAddress(address);
44    options.SetTlsSecureOptions(secureOption);
45    options.SetNetAddress(address);
46}
47
48void SetCertChainOneWayHwTestShortParam(TLSSocket &server)
49{
50    TLSConnectOptions options;
51    TLSSecureOptions secureOption;
52    Socket::NetAddress address;
53
54    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
55        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
56    secureOption.SetCaChain(caVec);
57    MockCertChainOneWayParamOptions(address, secureOption, options);
58    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
59    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
60}
61
62void SetCertChainOneWayHwTestLongParam(TLSSocket &server)
63{
64    TLSConnectOptions options;
65    TLSSecureOptions secureOption;
66    Socket::NetAddress address;
67
68    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
69        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
70    std::string protocolV13 = "TLSv1.3";
71    std::vector<std::string> protocolVec = { protocolV13 };
72    secureOption.SetCaChain(caVec);
73    secureOption.SetCipherSuite("AES256-SHA256");
74    secureOption.SetProtocolChain(protocolVec);
75    MockCertChainOneWayParamOptions(address, secureOption, options);
76
77    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
78    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
79}
80
81HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
82{
83    if (!TlsUtilsTest::CheckCaPathChainExistence("bindInterface")) {
84        return;
85    }
86
87    TLSSocket srv;
88    Socket::NetAddress address;
89    MockCertChainOneWayNetAddress(address);
90    srv.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
91}
92
93HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
94{
95    if (!TlsUtilsTest::CheckCaPathChainExistence("connectInterface")) {
96        return;
97    }
98    TLSSocket certChainOneWayService;
99    SetCertChainOneWayHwTestShortParam(certChainOneWayService);
100
101    const std::string data = "how do you do? this is connectInterface";
102    Socket::TCPSendOptions tcpSendOptions;
103    tcpSendOptions.SetData(data);
104    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
105    sleep(2);
106
107    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
108    sleep(2);
109}
110
111HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
112{
113    if (!TlsUtilsTest::CheckCaPathChainExistence("closeInterface")) {
114        return;
115    }
116    TLSSocket certChainOneWayService;
117    SetCertChainOneWayHwTestShortParam(certChainOneWayService);
118
119    const std::string data = "how do you do? this is closeInterface";
120    Socket::TCPSendOptions tcpSendOptions;
121    tcpSendOptions.SetData(data);
122
123    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
124    sleep(2);
125
126    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
127}
128
129HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
130{
131    if (!TlsUtilsTest::CheckCaPathChainExistence("sendInterface")) {
132        return;
133    }
134    TLSSocket certChainOneWayService;
135    SetCertChainOneWayHwTestShortParam(certChainOneWayService);
136
137    const std::string data = "how do you do? this is sendInterface";
138    Socket::TCPSendOptions tcpSendOptions;
139    tcpSendOptions.SetData(data);
140
141    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
142    sleep(2);
143
144    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
145}
146
147HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
148{
149    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
150        return;
151    }
152    TLSSocket certChainOneWayService;
153    TLSConnectOptions options;
154    TLSSecureOptions secureOption;
155    Socket::NetAddress address;
156
157    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(CA_PATH_CHAIN),
158        TlsUtilsTest::ChangeToFile(MID_CA_PATH_CHAIN) };
159    secureOption.SetCaChain(caVec);
160    MockCertChainOneWayParamOptions(address, secureOption, options);
161
162    certChainOneWayService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
163    certChainOneWayService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
164
165    Socket::NetAddress netAddress;
166    certChainOneWayService.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
167        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
168        netAddress.SetPort(address.GetPort());
169        netAddress.SetAddress(address.GetAddress());
170        netAddress.SetFamilyBySaFamily(address.GetSaFamily());
171    });
172    EXPECT_STREQ(netAddress.GetAddress().c_str(),
173        TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)).c_str());
174    EXPECT_EQ(address.GetPort(), std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
175    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
176
177    const std::string data = "how do you do? this is getRemoteAddressInterface";
178    Socket::TCPSendOptions tcpSendOptions;
179    tcpSendOptions.SetData(data);
180
181    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
182    sleep(2);
183
184    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
185}
186
187HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
188{
189    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
190        return;
191    }
192    TLSSocket certChainOneWayService;
193    SetCertChainOneWayHwTestShortParam(certChainOneWayService);
194
195    Socket::SocketStateBase TlsSocketstate;
196    certChainOneWayService.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
197        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
198        TlsSocketstate = state;
199    });
200    std::cout << "TlsSocketOneWayTest TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
201    EXPECT_TRUE(TlsSocketstate.IsBound());
202    EXPECT_TRUE(!TlsSocketstate.IsClose());
203    EXPECT_TRUE(TlsSocketstate.IsConnected());
204
205    const std::string tlsSocketOneWayTestData = "how do you do? this is getStateInterface";
206    Socket::TCPSendOptions tcpSendOptions;
207    tcpSendOptions.SetData(tlsSocketOneWayTestData);
208    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
209
210    sleep(2);
211
212    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
213}
214
215HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
216{
217    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteCertificateInterface")) {
218        return;
219    }
220    TLSSocket certChainOneWayService;
221    SetCertChainOneWayHwTestShortParam(certChainOneWayService);
222    Socket::TCPSendOptions tcpSendOptions;
223
224    const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
225    tcpSendOptions.SetData(data);
226    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
227
228    certChainOneWayService.GetRemoteCertificate(
229        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
230
231    sleep(2);
232    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
233}
234
235HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
236{
237    if (!TlsUtilsTest::CheckCaPathChainExistence("protocolInterface")) {
238        return;
239    }
240    TLSSocket certChainOneWayService;
241    SetCertChainOneWayHwTestLongParam(certChainOneWayService);
242
243    const std::string testData = "how do you do? this is protocolInterface";
244    Socket::TCPSendOptions tcpSendOptions;
245    tcpSendOptions.SetData(testData);
246
247    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
248    std::string getProtocolResult;
249    certChainOneWayService.GetProtocol([&getProtocolResult](int32_t errCode, const std::string &protocol) {
250        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
251        getProtocolResult = protocol;
252    });
253    EXPECT_STREQ(getProtocolResult.c_str(), "TLSv1.3");
254    sleep(2);
255
256    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
257}
258
259HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
260{
261    if (!TlsUtilsTest::CheckCaPathChainExistence("getCipherSuiteInterface")) {
262        return;
263    }
264    TLSSocket certChainOneWayService;
265    SetCertChainOneWayHwTestLongParam(certChainOneWayService);
266
267    bool oneWayTestFlag = false;
268    const std::string data = "how do you do? This is getCipherSuiteInterface";
269    Socket::TCPSendOptions tcpSendOptions;
270    tcpSendOptions.SetData(data);
271    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
272
273    std::vector<std::string> oneWayTestSuite;
274    certChainOneWayService.GetCipherSuite([&oneWayTestSuite](int32_t errCode, const std::vector<std::string> &suite) {
275        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
276        oneWayTestSuite = suite;
277    });
278
279    for (auto const &iter : oneWayTestSuite) {
280        if (iter == "AES256-SHA256") {
281            oneWayTestFlag = true;
282        }
283    }
284
285    EXPECT_TRUE(oneWayTestFlag);
286    sleep(2);
287
288    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
289}
290
291HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
292{
293    if (!TlsUtilsTest::CheckCaPathChainExistence("tlsSocketOnMessageData")) {
294        return;
295    }
296    std::string getData = "server->client";
297    TLSSocket certChainOneWayService;
298    SetCertChainOneWayHwTestLongParam(certChainOneWayService);
299
300    certChainOneWayService.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
301        if (data == getData) {
302            EXPECT_TRUE(true);
303        } else {
304            EXPECT_TRUE(false);
305        }
306    });
307
308    const std::string data = "how do you do? this is tlsSocketOnMessageData";
309    Socket::TCPSendOptions tcpSendOptions;
310    tcpSendOptions.SetData(data);
311    certChainOneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
312
313    sleep(2);
314    (void)certChainOneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
315}
316} // namespace TlsSocket
317} // namespace NetStack
318} // namespace OHOS
319