1/*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <algorithm>

#include "net_address.h"
#include "secure_data.h"
#include "socket_error.h"
#include "socket_state_base.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_key.h"
#include "tls_socket.h"
#include "tls_utils_test.h"
26
27namespace OHOS {
28namespace NetStack {
29namespace TlsSocket {
30void MockOneWayNetAddress(Socket::NetAddress &address)
31{
32    address.SetAddress(TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)));
33    address.SetPort(std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
34    address.SetFamilyBySaFamily(AF_INET);
35}
36
37void MockOneWayParamOptions(Socket::NetAddress &address, TLSSecureOptions &secureOption, TLSConnectOptions &options)
38{
39    secureOption.SetKey(SecureData(TlsUtilsTest::ChangeToFile(PRIVATE_KEY_PEM)));
40    secureOption.SetCert(TlsUtilsTest::ChangeToFile(CLIENT_CRT));
41
42    MockOneWayNetAddress(address);
43    options.SetTlsSecureOptions(secureOption);
44    options.SetNetAddress(address);
45}
46
47void SetOneWayHwTestShortParam(TLSSocket &server)
48{
49    TLSSecureOptions secureOption;
50    Socket::NetAddress address;
51    TLSConnectOptions options;
52    std::vector<std::string> caVec = {TlsUtilsTest::ChangeToFile(CA_DER)};
53    secureOption.SetCaChain(caVec);
54    MockOneWayParamOptions(address, secureOption, options);
55
56    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
57    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
58}
59
60void SetOneWayHwTestLongParam(TLSSocket &server)
61{
62    Socket::NetAddress address;
63    TLSSecureOptions secureOption;
64
65    std::vector<std::string> caVec = {TlsUtilsTest::ChangeToFile(CA_DER)};
66    secureOption.SetCaChain(caVec);
67    secureOption.SetCipherSuite("AES256-SHA256");
68    std::string protocolV13 = "TLSv1.3";
69    std::vector<std::string> protocolVec = {protocolV13};
70    secureOption.SetProtocolChain(protocolVec);
71    TLSConnectOptions options;
72    MockOneWayParamOptions(address, secureOption, options);
73    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
74    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
75}
76
77HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
78{
79    if (!TlsUtilsTest::CheckCaFileExistence("bindInterface")) {
80        return;
81    }
82
83    TLSSocket oneWayService;
84    Socket::NetAddress address;
85    MockOneWayNetAddress(address);
86
87    oneWayService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
88}
89
90HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
91{
92    if (!TlsUtilsTest::CheckCaFileExistence("connectInterface")) {
93        return;
94    }
95
96    TLSSocket oneWayService;
97    SetOneWayHwTestShortParam(oneWayService);
98
99    const std::string data = "how do you do? this is connectInterface";
100    Socket::TCPSendOptions tcpSendOptions;
101    tcpSendOptions.SetData(data);
102    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
103    sleep(2);
104
105    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
106    sleep(2);
107}
108
109HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
110{
111    if (!TlsUtilsTest::CheckCaFileExistence("closeInterface")) {
112        return;
113    }
114
115    TLSSocket oneWayService;
116    SetOneWayHwTestShortParam(oneWayService);
117
118    const std::string data = "how do you do? this is closeInterface";
119    Socket::TCPSendOptions tcpSendOptions;
120    tcpSendOptions.SetData(data);
121
122    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
123    sleep(2);
124
125    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
126}
127
128HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
129{
130    if (!TlsUtilsTest::CheckCaFileExistence("sendInterface")) {
131        return;
132    }
133
134    TLSSocket oneWayService;
135    SetOneWayHwTestShortParam(oneWayService);
136
137    const std::string data = "how do you do? this is sendInterface";
138    Socket::TCPSendOptions tcpSendOptions;
139    tcpSendOptions.SetData(data);
140
141    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
142    sleep(2);
143
144    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
145}
146
147HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
148{
149    if (!TlsUtilsTest::CheckCaFileExistence("getRemoteAddressInterface")) {
150        return;
151    }
152
153    TLSSocket oneWayService;
154    TLSConnectOptions options;
155    TLSSecureOptions secureOption;
156    Socket::NetAddress address;
157
158    std::vector<std::string> caVec = {TlsUtilsTest::ChangeToFile(CA_DER)};
159    secureOption.SetCaChain(caVec);
160    MockOneWayParamOptions(address, secureOption, options);
161
162    oneWayService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
163    oneWayService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
164
165    Socket::NetAddress netAddress;
166    oneWayService.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
167        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
168        netAddress.SetAddress(address.GetAddress());
169        netAddress.SetPort(address.GetPort());
170        netAddress.SetFamilyBySaFamily(address.GetSaFamily());
171    });
172    EXPECT_STREQ(netAddress.GetAddress().c_str(), TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)).c_str());
173    EXPECT_EQ(address.GetPort(), std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
174    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
175
176    const std::string data = "how do you do? this is getRemoteAddressInterface";
177    Socket::TCPSendOptions tcpSendOptions;
178    tcpSendOptions.SetData(data);
179
180    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
181
182    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
183}
184
185HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
186{
187    if (!TlsUtilsTest::CheckCaFileExistence("getRemoteAddressInterface")) {
188        return;
189    }
190
191    TLSSocket oneWayService;
192    SetOneWayHwTestShortParam(oneWayService);
193
194    Socket::SocketStateBase TlsSocketstate;
195    oneWayService.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
196        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
197        TlsSocketstate = state;
198    });
199    std::cout << "TlsSocketTest TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
200    EXPECT_TRUE(TlsSocketstate.IsBound());
201    EXPECT_TRUE(!TlsSocketstate.IsClose());
202    EXPECT_TRUE(TlsSocketstate.IsConnected());
203
204    const std::string tlsSocketTestData = "how do you do? this is getStateInterface";
205    Socket::TCPSendOptions tcpSendOptions;
206    tcpSendOptions.SetData(tlsSocketTestData);
207    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
208
209    sleep(2);
210
211    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
212}
213
214HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
215{
216    if (!TlsUtilsTest::CheckCaFileExistence("getRemoteCertificateInterface")) {
217        return;
218    }
219    TLSSocket oneWayService;
220    SetOneWayHwTestShortParam(oneWayService);
221
222    const std::string data = "how do you do? This is UT test getRemoteCertificateInterface";
223    Socket::TCPSendOptions tcpSendOptions;
224    tcpSendOptions.SetData(data);
225
226    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
227
228    oneWayService.GetRemoteCertificate(
229        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
230
231    sleep(2);
232    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
233}
234
235HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
236{
237    if (!TlsUtilsTest::CheckCaFileExistence("protocolInterface")) {
238        return;
239    }
240    TLSSocket oneWayService;
241    SetOneWayHwTestLongParam(oneWayService);
242
243    const std::string data = "how do you do? this is protocolInterface";
244    Socket::TCPSendOptions tcpSendOptions;
245    tcpSendOptions.SetData(data);
246
247    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
248    std::string protocolVal;
249    oneWayService.GetProtocol([&protocolVal](int32_t errCode, const std::string &protocol) {
250        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
251        protocolVal = protocol;
252    });
253    EXPECT_STREQ(protocolVal.c_str(), "TLSv1.3");
254
255    Socket::SocketStateBase stateBase;
256    oneWayService.GetState([&stateBase](int32_t errCode, Socket::SocketStateBase state) {
257        if (errCode == TLSSOCKET_SUCCESS) {
258            EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
259            stateBase.SetIsBound(state.IsBound());
260            stateBase.SetIsClose(state.IsClose());
261            stateBase.SetIsConnected(state.IsConnected());
262        }
263    });
264    EXPECT_TRUE(stateBase.IsConnected());
265    sleep(2);
266
267    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
268}
269
270HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
271{
272    if (!TlsUtilsTest::CheckCaFileExistence("getCipherSuiteInterface")) {
273        return;
274    }
275    TLSSocket oneWayService;
276    SetOneWayHwTestLongParam(oneWayService);
277
278    bool flag = false;
279    const std::string data = "how do you do? This is getCipherSuiteInterface";
280    Socket::TCPSendOptions tcpSendOptions;
281    tcpSendOptions.SetData(data);
282    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
283
284    std::vector<std::string> testSuite;
285    oneWayService.GetCipherSuite([&testSuite](int32_t errCode, const std::vector<std::string> &suite) {
286        if (errCode == TLSSOCKET_SUCCESS) {
287            testSuite = suite;
288        }
289    });
290
291    for (auto const &iter : testSuite) {
292        if (iter == "AES256-SHA256") {
293            flag = true;
294        }
295    }
296
297    EXPECT_TRUE(flag);
298    sleep(2);
299
300    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
301}
302
303HWTEST_F(TlsSocketTest, onMessageDataInterface, testing::ext::TestSize.Level2)
304{
305    if (!TlsUtilsTest::CheckCaFileExistence("tlsSocketOnMessageData")) {
306        return;
307    }
308    std::string getData = "server->client";
309    TLSSocket oneWayService;
310    SetOneWayHwTestLongParam(oneWayService);
311
312    oneWayService.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
313        EXPECT_TRUE(data == getData);
314    });
315
316    const std::string data = "how do you do? this is tlsSocketOnMessageData";
317    Socket::TCPSendOptions tcpSendOptions;
318    tcpSendOptions.SetData(data);
319    oneWayService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
320
321    sleep(2);
322    (void)oneWayService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
323}
324} // namespace TlsSocket
325} // namespace NetStack
326} // namespace OHOS
327