1/*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <memory>

#include <openssl/ssl.h>

#define private public
#include "accesstoken_kit.h"
#include "tls_socket.h"
#include "socket_remote_info.h"
#include "token_setproc.h"
#include "tls.h"
#include "TlsTest.h"
28
29namespace OHOS {
30namespace NetStack {
31namespace TlsSocket {
32namespace {
33using namespace testing::ext;
34using namespace Security::AccessToken;
35using Security::AccessToken::AccessTokenID;
// Fixed inputs shared by the tests in this file. The enclosing anonymous namespace already gives
// these internal linkage, so no `static` is needed.
constexpr const char *KEY_PASS = "";                 // empty passphrase for the private key
constexpr const char *PROTOCOL12 = "TLSv1.2";
constexpr const char *PROTOCOL13 = "TLSv1.3";
constexpr const char *IP_ADDRESS = "127.0.0.1";      // IPv4 loopback
constexpr const char *ALPN_PROTOCOL = "http/1.1";
constexpr const char *SIGNATURE_ALGORITHM = "rsa_pss_rsae_sha256:ECDSA+SHA256";
constexpr const char *CIPHER_SUITE = "AES256-SHA256";
constexpr const char *SEND_DATA = "How do you do";
constexpr const char *SEND_DATA_EMPTY = "";
constexpr size_t MAX_BUFFER_SIZE = 8192;             // receive buffer length used by the Recv checks
constexpr int PORT = 7838;
constexpr int SOCKET_FD = 5;                         // fixed descriptor handed to TlsConnectToHost
constexpr int SSL_ERROR_RETURN = -1;                 // value Recv is expected to return on failure
49
50TLSConnectOptions BaseOption()
51{
52    TLSSecureOptions secureOption;
53    SecureData structureData(PRI_KEY_FILE);
54    secureOption.SetKey(structureData);
55    std::vector<std::string> caChain;
56    caChain.push_back(CA_CRT_FILE);
57    secureOption.SetCaChain(caChain);
58    secureOption.SetCert(CLIENT_FILE);
59    secureOption.SetCipherSuite(CIPHER_SUITE);
60    secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
61    std::vector<std::string> protocol;
62    protocol.push_back(PROTOCOL13);
63    secureOption.SetProtocolChain(protocol);
64
65    TLSConnectOptions connectOptions;
66    connectOptions.SetTlsSecureOptions(secureOption);
67    Socket::NetAddress netAddress;
68    netAddress.SetAddress(IP_ADDRESS);
69    netAddress.SetPort(0);
70    netAddress.SetFamilyBySaFamily(AF_INET);
71    connectOptions.SetNetAddress(netAddress);
72    std::vector<std::string> alpnProtocols;
73    alpnProtocols.push_back(ALPN_PROTOCOL);
74    connectOptions.SetAlpnProtocols(alpnProtocols);
75    return connectOptions;
76}
77
// HAP identity handed to AccessTokenKit::AllocHapToken by the AccessToken helper in this file.
// isSystemApp is set so the token is allocated for a system application.
HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
                               .userID = 1,
                               .instIndex = 0,
                               .appIDDesc = "test",
                               .isSystemApp = true};

// Definition of ohos.permission.INTERNET; registered through testPolicyPrams.permList below.
PermissionDef testPermDef = {
    .permissionName = "ohos.permission.INTERNET",
    .bundleName = "TlsSocketBranchTest",
    .grantMode = 1,
    .label = "label",
    .labelId = 1,
    .description = "Test Tls Socket Branch",
    .descriptionId = 1,
    .availableLevel = APL_SYSTEM_BASIC,
};

// Granted state of ohos.permission.INTERNET; registered through testPolicyPrams.permStateList below.
// NOTE(review): the meaning of grantFlags value 2 is defined by the access-token framework — confirm.
PermissionStateFull testState = {
    .grantFlags = {2},
    .grantStatus = {PermissionState::PERMISSION_GRANTED},
    .isGeneral = true,
    .permissionName = "ohos.permission.INTERNET",
    .resDeviceID = {"local"},
};

// Policy passed together with testInfoParms to AllocHapToken: APL_SYSTEM_BASIC plus the INTERNET
// permission definition and its granted state (presumably so the TLS socket under test may use the network).
HapPolicyParams testPolicyPrams = {
    .apl = APL_SYSTEM_BASIC,
    .domain = "test.domain",
    .permList = {testPermDef},
    .permStateList = {testState},
};
109} // namespace
110
111class AccessToken {
112public:
113    AccessToken() : currentID_(GetSelfTokenID())
114    {
115        AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
116        accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
117        SetSelfTokenID(tokenIdEx.tokenIDEx);
118    }
119    ~AccessToken()
120    {
121        AccessTokenKit::DeleteToken(accessID_);
122        SetSelfTokenID(currentID_);
123    }
124
125private:
126    AccessTokenID currentID_;
127    AccessTokenID accessID_ = 0;
128};
129
130class TlsSocketBranchTest : public testing::Test {
131public:
132    static void SetUpTestCase() {}
133
134    static void TearDownTestCase() {}
135
136    virtual void SetUp() {}
137
138    virtual void TearDown() {}
139};
140
141HWTEST_F(TlsSocketBranchTest, BranchTest1, TestSize.Level2)
142{
143    TLSSecureOptions secureOption;
144    SecureData structureData(PRI_KEY_FILE);
145    secureOption.SetKey(structureData);
146
147    SecureData keyPass(KEY_PASS);
148    secureOption.SetKeyPass(keyPass);
149    SecureData secureData = secureOption.GetKey();
150    EXPECT_EQ(structureData.Length(), strlen(PRI_KEY_FILE));
151    std::vector<std::string> caChain;
152    caChain.push_back(CA_CRT_FILE);
153    secureOption.SetCaChain(caChain);
154    std::vector<std::string> getCaChain = secureOption.GetCaChain();
155    EXPECT_NE(getCaChain.data(), nullptr);
156
157    secureOption.SetCert(CLIENT_FILE);
158    std::string getCert = secureOption.GetCert();
159    EXPECT_NE(getCert.data(), nullptr);
160
161    std::vector<std::string> protocolVec = {PROTOCOL12, PROTOCOL13};
162    secureOption.SetProtocolChain(protocolVec);
163    std::vector<std::string> getProtocol;
164    getProtocol = secureOption.GetProtocolChain();
165
166    TLSSecureOptions copyOption = TLSSecureOptions(secureOption);
167    TLSSecureOptions equalOption = secureOption;
168}
169
170HWTEST_F(TlsSocketBranchTest, BranchTest2, TestSize.Level2)
171{
172    TLSSecureOptions secureOption;
173    secureOption.SetUseRemoteCipherPrefer(false);
174    bool isUseRemoteCipher = secureOption.UseRemoteCipherPrefer();
175    EXPECT_FALSE(isUseRemoteCipher);
176
177    secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
178    std::string getSignatureAlgorithm = secureOption.GetSignatureAlgorithms();
179    EXPECT_STREQ(getSignatureAlgorithm.data(), SIGNATURE_ALGORITHM);
180
181    secureOption.SetCipherSuite(CIPHER_SUITE);
182    std::string getCipherSuite = secureOption.GetCipherSuite();
183    EXPECT_STREQ(getCipherSuite.data(), CIPHER_SUITE);
184
185    TLSSecureOptions copyOption = TLSSecureOptions(secureOption);
186    TLSSecureOptions equalOption = secureOption;
187
188    TLSConnectOptions connectOptions;
189    connectOptions.SetTlsSecureOptions(secureOption);
190}
191
192HWTEST_F(TlsSocketBranchTest, BranchTest3, TestSize.Level2)
193{
194    TLSSecureOptions secureOption;
195    TLSConnectOptions connectOptions;
196    connectOptions.SetTlsSecureOptions(secureOption);
197
198    Socket::NetAddress netAddress;
199    netAddress.SetAddress(IP_ADDRESS);
200    netAddress.SetPort(PORT);
201    connectOptions.SetNetAddress(netAddress);
202    Socket::NetAddress getNetAddress = connectOptions.GetNetAddress();
203    std::string address = getNetAddress.GetAddress();
204    EXPECT_STREQ(IP_ADDRESS, address.data());
205    int port = getNetAddress.GetPort();
206    EXPECT_EQ(port, PORT);
207    netAddress.SetFamilyBySaFamily(AF_INET6);
208    sa_family_t getFamily = netAddress.GetSaFamily();
209    EXPECT_EQ(getFamily, AF_INET6);
210
211    std::vector<std::string> alpnProtocols;
212    alpnProtocols.push_back(ALPN_PROTOCOL);
213    connectOptions.SetAlpnProtocols(alpnProtocols);
214    std::vector<std::string> getAlpnProtocols;
215    getAlpnProtocols = connectOptions.GetAlpnProtocols();
216    EXPECT_STREQ(getAlpnProtocols[0].data(), alpnProtocols[0].data());
217}
218
219HWTEST_F(TlsSocketBranchTest, BranchTest4, TestSize.Level2)
220{
221    TLSSecureOptions secureOption;
222    SecureData structureData(PRI_KEY_FILE);
223    secureOption.SetKey(structureData);
224    std::vector<std::string> caChain;
225    caChain.push_back(CA_CRT_FILE);
226    secureOption.SetCaChain(caChain);
227    secureOption.SetCert(CLIENT_FILE);
228
229    TLSConnectOptions connectOptions;
230    connectOptions.SetTlsSecureOptions(secureOption);
231
232    Socket::NetAddress netAddress;
233    netAddress.SetAddress(IP_ADDRESS);
234    netAddress.SetPort(0);
235    netAddress.SetFamilyBySaFamily(AF_INET);
236    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET);
237}
238
239HWTEST_F(TlsSocketBranchTest, BranchTest5, TestSize.Level2)
240{
241    TLSConnectOptions tlsConnectOptions = BaseOption();
242
243    AccessToken token;
244    TLSSocket tlsSocket;
245    tlsSocket.OnError(
246        [](int32_t errorNumber, const std::string &errorString) { EXPECT_NE(TLSSOCKET_SUCCESS, errorNumber); });
247    tlsSocket.Connect(tlsConnectOptions, [](int32_t errCode) { EXPECT_NE(TLSSOCKET_SUCCESS, errCode); });
248    std::string getData;
249    tlsSocket.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
250        EXPECT_STREQ(getData.data(), nullptr);
251    });
252    const std::string data = "how do you do?";
253    Socket::TCPSendOptions tcpSendOptions;
254    tcpSendOptions.SetData(data);
255    tlsSocket.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
256    tlsSocket.GetSignatureAlgorithms(
257        [](int32_t errCode, const std::vector<std::string> &algorithms) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
258    tlsSocket.GetCertificate(
259        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_NE(errCode, TLSSOCKET_SUCCESS); });
260    tlsSocket.GetCipherSuite(
261        [](int32_t errCode, const std::vector<std::string> &suite) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
262    tlsSocket.GetProtocol([](int32_t errCode, const std::string &protocol) { EXPECT_EQ(errCode, TLSSOCKET_SUCCESS); });
263    tlsSocket.GetRemoteCertificate(
264        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
265    (void)tlsSocket.Close([](int32_t errCode) { EXPECT_FALSE(errCode == TLSSOCKET_SUCCESS); });
266}
267
268HWTEST_F(TlsSocketBranchTest, BranchTest6, TestSize.Level2)
269{
270    TLSConnectOptions connectOptions = BaseOption();
271
272    TLSSocket tlsSocket;
273    TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal();
274    bool isConnectToHost = tlsSocketInternal->TlsConnectToHost(SOCKET_FD, connectOptions, false);
275    EXPECT_FALSE(isConnectToHost);
276    tlsSocketInternal->SetTlsConfiguration(connectOptions);
277
278    bool sendSslNull = tlsSocketInternal->Send(SEND_DATA);
279    EXPECT_FALSE(sendSslNull);
280    char buffer[MAX_BUFFER_SIZE];
281    bzero(buffer, MAX_BUFFER_SIZE);
282    int recvSslNull = tlsSocketInternal->Recv(buffer, MAX_BUFFER_SIZE);
283    EXPECT_EQ(recvSslNull, SSL_ERROR_RETURN);
284    bool closeSslNull = tlsSocketInternal->Close();
285    EXPECT_FALSE(closeSslNull);
286    tlsSocketInternal->ssl_ = SSL_new(SSL_CTX_new(TLS_client_method()));
287    bool sendEmpty = tlsSocketInternal->Send(SEND_DATA_EMPTY);
288    EXPECT_TRUE(sendEmpty);
289    int recv = tlsSocketInternal->Recv(buffer, MAX_BUFFER_SIZE);
290    EXPECT_EQ(recv, SSL_ERROR_RETURN);
291    bool close = tlsSocketInternal->Close();
292    EXPECT_FALSE(close);
293    delete tlsSocketInternal;
294}
295
296HWTEST_F(TlsSocketBranchTest, BranchTest7, TestSize.Level2)
297{
298    TLSSocket tlsSocket;
299    TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal();
300
301    std::vector<std::string> alpnProtocols;
302    alpnProtocols.push_back(ALPN_PROTOCOL);
303    bool alpnProSslNull = tlsSocketInternal->SetAlpnProtocols(alpnProtocols);
304    EXPECT_FALSE(alpnProSslNull);
305    std::vector<std::string> getCipherSuite = tlsSocketInternal->GetCipherSuite();
306    EXPECT_EQ(getCipherSuite.size(), 0);
307    bool setSharedSigals = tlsSocketInternal->SetSharedSigals();
308    EXPECT_FALSE(setSharedSigals);
309    tlsSocketInternal->ssl_ = SSL_new(SSL_CTX_new(TLS_client_method()));
310    getCipherSuite = tlsSocketInternal->GetCipherSuite();
311    EXPECT_NE(getCipherSuite.size(), 0);
312    setSharedSigals = tlsSocketInternal->SetSharedSigals();
313    EXPECT_FALSE(setSharedSigals);
314    TLSConnectOptions connectOptions = BaseOption();
315    bool alpnPro = tlsSocketInternal->SetAlpnProtocols(alpnProtocols);
316    EXPECT_TRUE(alpnPro);
317
318    Socket::SocketRemoteInfo remoteInfo;
319    tlsSocketInternal->hostName_ = IP_ADDRESS;
320    tlsSocketInternal->port_ = PORT;
321    tlsSocketInternal->family_ = AF_INET;
322    tlsSocketInternal->MakeRemoteInfo(remoteInfo);
323    getCipherSuite = tlsSocketInternal->GetCipherSuite();
324    EXPECT_NE(getCipherSuite.size(), 0);
325
326    std::string getRemoteCert = tlsSocketInternal->GetRemoteCertificate();
327    EXPECT_EQ(getRemoteCert, "");
328
329    std::vector<std::string> getSignatureAlgorithms = tlsSocketInternal->GetSignatureAlgorithms();
330    EXPECT_EQ(getSignatureAlgorithms.size(), 0);
331
332    std::string getProtocol = tlsSocketInternal->GetProtocol();
333    EXPECT_NE(getProtocol, "");
334
335    setSharedSigals = tlsSocketInternal->SetSharedSigals();
336    EXPECT_FALSE(setSharedSigals);
337
338    ssl_st *ssl = tlsSocketInternal->GetSSL();
339    EXPECT_NE(ssl, nullptr);
340    delete tlsSocketInternal;
341}
342} // namespace TlsSocket
343} // namespace NetStack
344} // namespace OHOS
345