1/*
2 * Copyright (c) 2022-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

#include "accesstoken_kit.h"
#include "net_address.h"
#include "secure_data.h"
#include "socket_error.h"
#include "socket_state_base.h"
#include "tls.h"
#include "tls_certificate.h"
#include "tls_configuration.h"
#include "tls_key.h"
#include "tls_socket.h"
#include "tls_utils_test.h"
#include "token_setproc.h"
28
29namespace OHOS {
30namespace NetStack {
31namespace TlsSocket {
32namespace {
33using namespace testing::ext;
34using namespace Security::AccessToken;
35} // namespace
36
37void MockConnectionNetAddress(Socket::NetAddress &address)
38{
39    address.SetAddress(TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)));
40    address.SetPort(std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
41    address.SetFamilyBySaFamily(AF_INET);
42}
43
44void MockConnectionParamOptions(Socket::NetAddress &address, TLSSecureOptions &secureOption, TLSConnectOptions &options)
45{
46    secureOption.SetKey(SecureData(TlsUtilsTest::ChangeToFile(PRIVATE_KEY_PEM_CHAIN)));
47    secureOption.SetCert(TlsUtilsTest::ChangeToFile(CLIENT_CRT_CHAIN));
48
49    MockConnectionNetAddress(address);
50    options.SetTlsSecureOptions(secureOption);
51    options.SetNetAddress(address);
52}
53
54void SetUnilateralHwTestShortParam(TLSSocket &server)
55{
56    TLSConnectOptions options;
57    TLSSecureOptions secureOption;
58    Socket::NetAddress address;
59    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(ROOT_CA_PATH_CHAIN),
60        TlsUtilsTest::ChangeToFile(MID_CA_CHAIN) };
61    secureOption.SetCaChain(caVec);
62    MockConnectionParamOptions(address, secureOption, options);
63
64    server.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
65    server.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
66}
67
68HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
69                               .userID = 1,
70                               .instIndex = 0,
71                               .appIDDesc = "test",
72                               .isSystemApp = true};
73
74PermissionDef testPermDef = {
75    .permissionName = "ohos.permission.INTERNET",
76    .bundleName = "TlsSocketBranchTest",
77    .grantMode = 1,
78    .label = "label",
79    .labelId = 1,
80    .description = "Test Tls Socket Branch",
81    .descriptionId = 1,
82    .availableLevel = APL_SYSTEM_BASIC,
83};
84
85PermissionStateFull testState = {
86    .grantFlags = {2},
87    .grantStatus = {PermissionState::PERMISSION_GRANTED},
88    .isGeneral = true,
89    .permissionName = "ohos.permission.INTERNET",
90    .resDeviceID = {"local"},
91};
92
93HapPolicyParams testPolicyPrams = {
94    .apl = APL_SYSTEM_BASIC,
95    .domain = "test.domain",
96    .permList = {testPermDef},
97    .permStateList = {testState},
98};
99
100class AccessToken {
101public:
102    AccessToken() : currentID_(GetSelfTokenID())
103    {
104        AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
105        accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
106        SetSelfTokenID(tokenIdEx.tokenIDEx);
107    }
108    ~AccessToken()
109    {
110        AccessTokenKit::DeleteToken(accessID_);
111        SetSelfTokenID(currentID_);
112    }
113
114private:
115    AccessTokenID currentID_;
116    AccessTokenID accessID_ = 0;
117};
118
119class TlsSocketBranchTest : public testing::Test {
120public:
121    static void SetUpTestCase() {}
122
123    static void TearDownTestCase() {}
124
125    virtual void SetUp() {}
126
127    virtual void TearDown() {}
128};
129
130HWTEST_F(TlsSocketTest, bindInterface, testing::ext::TestSize.Level2)
131{
132    if (!TlsUtilsTest::CheckCaPathChainExistence("bindInterface")) {
133        return;
134    }
135
136    TLSSocket tlsService;
137    Socket::NetAddress address;
138    MockConnectionNetAddress(address);
139
140    AccessToken token;
141    tlsService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
142}
143
144HWTEST_F(TlsSocketTest, connectInterface, testing::ext::TestSize.Level2)
145{
146    if (!TlsUtilsTest::CheckCaPathChainExistence("connectInterface")) {
147        return;
148    }
149    TLSSocket tlsService;
150    SetUnilateralHwTestShortParam(tlsService);
151
152    AccessToken token;
153    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
154    Socket::TCPSendOptions tcpSendOptions;
155    tcpSendOptions.SetData(data);
156    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
157
158    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
159}
160
161HWTEST_F(TlsSocketTest, closeInterface, testing::ext::TestSize.Level2)
162{
163    if (!TlsUtilsTest::CheckCaPathChainExistence("closeInterface")) {
164        return;
165    }
166    TLSSocket tlsService;
167    SetUnilateralHwTestShortParam(tlsService);
168
169    AccessToken token;
170    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
171    ;
172    Socket::TCPSendOptions tcpSendOptions;
173    tcpSendOptions.SetData(data);
174
175    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
176
177    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
178}
179
180HWTEST_F(TlsSocketTest, sendInterface, testing::ext::TestSize.Level2)
181{
182    if (!TlsUtilsTest::CheckCaPathChainExistence("sendInterface")) {
183        return;
184    }
185    TLSSocket tlsService;
186    SetUnilateralHwTestShortParam(tlsService);
187
188    AccessToken token;
189    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
190    Socket::TCPSendOptions tcpSendOptions;
191    tcpSendOptions.SetData(data);
192
193    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
194
195    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
196}
197
198HWTEST_F(TlsSocketTest, getRemoteAddressInterface, testing::ext::TestSize.Level2)
199{
200    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
201        return;
202    }
203    TLSSocket tlsService;
204    TLSConnectOptions options;
205    TLSSecureOptions secureOption;
206    Socket::NetAddress address;
207    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(ROOT_CA_PATH_CHAIN),
208        TlsUtilsTest::ChangeToFile(MID_CA_CHAIN) };
209    secureOption.SetCaChain(caVec);
210    MockConnectionParamOptions(address, secureOption, options);
211
212    tlsService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
213    tlsService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
214
215    AccessToken token;
216    Socket::NetAddress netAddress;
217    tlsService.GetRemoteAddress([&netAddress](int32_t errCode, const Socket::NetAddress &address) {
218        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
219        netAddress.SetFamilyBySaFamily(address.GetSaFamily());
220        netAddress.SetAddress(address.GetAddress());
221        netAddress.SetPort(address.GetPort());
222    });
223    EXPECT_STREQ(netAddress.GetAddress().c_str(), TlsUtilsTest::GetIp(TlsUtilsTest::ChangeToFile(IP_ADDRESS)).c_str());
224    EXPECT_EQ(address.GetPort(), std::atoi(TlsUtilsTest::ChangeToFile(PORT).c_str()));
225    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
226
227    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
228    Socket::TCPSendOptions tcpSendOptions;
229    tcpSendOptions.SetData(data);
230
231    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
232
233    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
234}
235
236HWTEST_F(TlsSocketTest, getStateInterface, testing::ext::TestSize.Level2)
237{
238    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteAddressInterface")) {
239        return;
240    }
241
242    TLSSocket tlsService;
243    SetUnilateralHwTestShortParam(tlsService);
244
245    AccessToken token;
246    Socket::SocketStateBase TlsSocketstate;
247    tlsService.GetState([&TlsSocketstate](int32_t errCode, const Socket::SocketStateBase &state) {
248        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
249        TlsSocketstate = state;
250    });
251    std::cout << "TlsSocketUnilateralConnection TlsSocketstate.IsClose(): " << TlsSocketstate.IsClose() << std::endl;
252    EXPECT_TRUE(TlsSocketstate.IsBound());
253    EXPECT_TRUE(!TlsSocketstate.IsClose());
254    EXPECT_TRUE(TlsSocketstate.IsConnected());
255
256    const std::string connectionData = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
257    Socket::TCPSendOptions tcpSendOptions;
258    tcpSendOptions.SetData(connectionData);
259    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
260
261    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
262}
263
264HWTEST_F(TlsSocketTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2)
265{
266    if (!TlsUtilsTest::CheckCaPathChainExistence("getRemoteCertificateInterface")) {
267        return;
268    }
269    TLSSocket tlsService;
270    SetUnilateralHwTestShortParam(tlsService);
271    Socket::TCPSendOptions tcpSendOptions;
272    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
273
274    AccessToken token;
275    tcpSendOptions.SetData(data);
276
277    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
278
279    tlsService.GetRemoteCertificate(
280        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
281
282    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
283}
284
285HWTEST_F(TlsSocketTest, protocolInterface, testing::ext::TestSize.Level2)
286{
287    if (!TlsUtilsTest::CheckCaPathChainExistence("protocolInterface")) {
288        return;
289    }
290
291    TLSSocket tlsService;
292    TLSConnectOptions options;
293    TLSSecureOptions secureOption;
294    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(ROOT_CA_PATH_CHAIN),
295        TlsUtilsTest::ChangeToFile(MID_CA_CHAIN) };
296    secureOption.SetCaChain(caVec);
297    std::string protocolV13 = "TLSv1.2";
298    std::vector<std::string> protocolVec = { protocolV13 };
299    secureOption.SetProtocolChain(protocolVec);
300    Socket::NetAddress address;
301    MockConnectionParamOptions(address, secureOption, options);
302
303    AccessToken token;
304    tlsService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
305
306    tlsService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
307
308    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
309    Socket::TCPSendOptions tcpSendOptions;
310    tcpSendOptions.SetData(data);
311
312    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
313    std::string getProtocolVal;
314    tlsService.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) {
315        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
316        getProtocolVal = protocol;
317    });
318    EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.2");
319
320    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
321}
322
323HWTEST_F(TlsSocketTest, getCipherSuiteInterface, testing::ext::TestSize.Level2)
324{
325    if (!TlsUtilsTest::CheckCaPathChainExistence("getCipherSuiteInterface")) {
326        return;
327    }
328
329    TLSConnectOptions options;
330    TLSSocket tlsService;
331    TLSSecureOptions secureOption;
332    std::vector<std::string> caVec = { TlsUtilsTest::ChangeToFile(ROOT_CA_PATH_CHAIN),
333        TlsUtilsTest::ChangeToFile(MID_CA_CHAIN) };
334    secureOption.SetCaChain(caVec);
335    secureOption.SetCipherSuite("ECDHE-RSA-AES128-GCM-SHA256");
336    Socket::NetAddress address;
337    MockConnectionParamOptions(address, secureOption, options);
338
339    bool flag = false;
340    AccessToken token;
341    tlsService.Bind(address, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
342    tlsService.Connect(options, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
343
344    const std::string data = "GET / HTTP/1.1\r\nHost: www.baidu.com\r\nConnection: keep-alive\r\n\r\n";
345    Socket::TCPSendOptions tcpSendOptions;
346    tcpSendOptions.SetData(data);
347    tlsService.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
348
349    std::vector<std::string> cipherSuite;
350    tlsService.GetCipherSuite([&cipherSuite](int32_t errCode, const std::vector<std::string> &suite) {
351        EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS);
352        cipherSuite = suite;
353    });
354
355    for (auto const &iter : cipherSuite) {
356        if (iter == "ECDHE-RSA-AES128-GCM-SHA256") {
357            flag = true;
358        }
359    }
360
361    EXPECT_TRUE(flag);
362
363    (void)tlsService.Close([](int32_t errCode) { EXPECT_TRUE(errCode == TLSSOCKET_SUCCESS); });
364}
365} // namespace TlsSocket
366} // namespace NetStack
367} // namespace OHOS
368
369