1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <gtest/gtest.h>
17 #include <iostream>
18 
19 #include <openssl/ssl.h>
20 
21 #define private public
22 #include "accesstoken_kit.h"
23 #include "tls_socket.h"
24 #include "socket_remote_info.h"
25 #include "token_setproc.h"
26 #include "tls.h"
27 #include "TlsTest.h"
28 
29 namespace OHOS {
30 namespace NetStack {
31 namespace TlsSocket {
32 namespace {
33 using namespace testing::ext;
34 using namespace Security::AccessToken;
35 using Security::AccessToken::AccessTokenID;
// Option strings fed to TLSSecureOptions / TLSConnectOptions by the tests below.
constexpr const char *KEY_PASS = "";
constexpr const char *PROTOCOL12 = "TLSv1.2";
constexpr const char *PROTOCOL13 = "TLSv1.3";
constexpr const char *IP_ADDRESS = "127.0.0.1";
constexpr const char *ALPN_PROTOCOL = "http/1.1";
constexpr const char *SIGNATURE_ALGORITHM = "rsa_pss_rsae_sha256:ECDSA+SHA256";
constexpr const char *CIPHER_SUITE = "AES256-SHA256";

// Payloads handed to Send() and the size of the buffer handed to Recv().
constexpr const char *SEND_DATA = "How do you do";
constexpr const char *SEND_DATA_EMPTY = "";
constexpr size_t MAX_BUFFER_SIZE = 8192;

// Numeric fixtures: remote port, the descriptor passed to TlsConnectToHost(),
// and the value Recv() is expected to return on failure.
constexpr int PORT = 7838;
constexpr int SOCKET_FD = 5;
constexpr int SSL_ERROR_RETURN = -1;
49 
BaseOption()50 TLSConnectOptions BaseOption()
51 {
52     TLSSecureOptions secureOption;
53     SecureData structureData(PRI_KEY_FILE);
54     secureOption.SetKey(structureData);
55     std::vector<std::string> caChain;
56     caChain.push_back(CA_CRT_FILE);
57     secureOption.SetCaChain(caChain);
58     secureOption.SetCert(CLIENT_FILE);
59     secureOption.SetCipherSuite(CIPHER_SUITE);
60     secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
61     std::vector<std::string> protocol;
62     protocol.push_back(PROTOCOL13);
63     secureOption.SetProtocolChain(protocol);
64 
65     TLSConnectOptions connectOptions;
66     connectOptions.SetTlsSecureOptions(secureOption);
67     Socket::NetAddress netAddress;
68     netAddress.SetAddress(IP_ADDRESS);
69     netAddress.SetPort(0);
70     netAddress.SetFamilyBySaFamily(AF_INET);
71     connectOptions.SetNetAddress(netAddress);
72     std::vector<std::string> alpnProtocols;
73     alpnProtocols.push_back(ALPN_PROTOCOL);
74     connectOptions.SetAlpnProtocols(alpnProtocols);
75     return connectOptions;
76 }
77 
78 HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
79                                .userID = 1,
80                                .instIndex = 0,
81                                .appIDDesc = "test",
82                                .isSystemApp = true};
83 
84 PermissionDef testPermDef = {
85     .permissionName = "ohos.permission.INTERNET",
86     .bundleName = "TlsSocketBranchTest",
87     .grantMode = 1,
88     .label = "label",
89     .labelId = 1,
90     .description = "Test Tls Socket Branch",
91     .descriptionId = 1,
92     .availableLevel = APL_SYSTEM_BASIC,
93 };
94 
95 PermissionStateFull testState = {
96     .grantFlags = {2},
97     .grantStatus = {PermissionState::PERMISSION_GRANTED},
98     .isGeneral = true,
99     .permissionName = "ohos.permission.INTERNET",
100     .resDeviceID = {"local"},
101 };
102 
103 HapPolicyParams testPolicyPrams = {
104     .apl = APL_SYSTEM_BASIC,
105     .domain = "test.domain",
106     .permList = {testPermDef},
107     .permStateList = {testState},
108 };
109 } // namespace
110 
111 class AccessToken {
112 public:
AccessToken()113     AccessToken() : currentID_(GetSelfTokenID())
114     {
115         AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
116         accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
117         SetSelfTokenID(tokenIdEx.tokenIDEx);
118     }
~AccessToken()119     ~AccessToken()
120     {
121         AccessTokenKit::DeleteToken(accessID_);
122         SetSelfTokenID(currentID_);
123     }
124 
125 private:
126     AccessTokenID currentID_;
127     AccessTokenID accessID_ = 0;
128 };
129 
130 class TlsSocketBranchTest : public testing::Test {
131 public:
SetUpTestCase()132     static void SetUpTestCase() {}
133 
TearDownTestCase()134     static void TearDownTestCase() {}
135 
SetUp()136     virtual void SetUp() {}
137 
TearDown()138     virtual void TearDown() {}
139 };
140 
/**
 * Exercises the key, key password, CA chain, certificate and protocol chain accessors of
 * TLSSecureOptions, then its copy construction and copy assignment.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest1, TestSize.Level2)
{
    TLSSecureOptions secureOption;
    SecureData structureData(PRI_KEY_FILE);
    secureOption.SetKey(structureData);

    SecureData keyPass(KEY_PASS);
    secureOption.SetKeyPass(keyPass);
    SecureData secureData = secureOption.GetKey();
    EXPECT_EQ(structureData.Length(), strlen(PRI_KEY_FILE));
    // Also check the key that was read back; it used to be fetched and then ignored.
    EXPECT_EQ(secureData.Length(), structureData.Length());

    std::vector<std::string> caChain;
    caChain.push_back(CA_CRT_FILE);
    secureOption.SetCaChain(caChain);
    std::vector<std::string> getCaChain = secureOption.GetCaChain();
    EXPECT_NE(getCaChain.data(), nullptr);

    secureOption.SetCert(CLIENT_FILE);
    std::string getCert = secureOption.GetCert();
    EXPECT_NE(getCert.data(), nullptr);

    std::vector<std::string> protocolVec = {PROTOCOL12, PROTOCOL13};
    secureOption.SetProtocolChain(protocolVec);
    std::vector<std::string> getProtocol;
    getProtocol = secureOption.GetProtocolChain();

    // Copy construction.
    TLSSecureOptions copyOption(secureOption);
    // Copy assignment: `T x = y;` is copy construction, so declare first and assign afterwards
    // to actually go through operator=.
    TLSSecureOptions equalOption;
    equalOption = secureOption;
}
169 
/**
 * Checks that the remote-cipher preference, signature algorithms and cipher suite set on
 * TLSSecureOptions are returned unchanged, then copies the options and stores them in a
 * TLSConnectOptions.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest2, TestSize.Level2)
{
    TLSSecureOptions secureOption;
    secureOption.SetUseRemoteCipherPrefer(false);
    bool isUseRemoteCipher = secureOption.UseRemoteCipherPrefer();
    EXPECT_FALSE(isUseRemoteCipher);

    secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
    std::string getSignatureAlgorithm = secureOption.GetSignatureAlgorithms();
    EXPECT_STREQ(getSignatureAlgorithm.data(), SIGNATURE_ALGORITHM);

    secureOption.SetCipherSuite(CIPHER_SUITE);
    std::string getCipherSuite = secureOption.GetCipherSuite();
    EXPECT_STREQ(getCipherSuite.data(), CIPHER_SUITE);

    // Copy construction.
    TLSSecureOptions copyOption(secureOption);
    // Copy assignment: `T x = y;` is copy construction, so declare first and assign afterwards
    // to actually go through operator=.
    TLSSecureOptions equalOption;
    equalOption = secureOption;

    TLSConnectOptions connectOptions;
    connectOptions.SetTlsSecureOptions(secureOption);
}
191 
/**
 * Checks that the address, port and ALPN protocols stored in TLSConnectOptions are returned
 * unchanged, and that NetAddress reports the family set through SetFamilyBySaFamily().
 */
HWTEST_F(TlsSocketBranchTest, BranchTest3, TestSize.Level2)
{
    TLSSecureOptions secureOption;
    TLSConnectOptions connectOptions;
    connectOptions.SetTlsSecureOptions(secureOption);

    Socket::NetAddress netAddress;
    netAddress.SetAddress(IP_ADDRESS);
    netAddress.SetPort(PORT);
    connectOptions.SetNetAddress(netAddress);
    Socket::NetAddress getNetAddress = connectOptions.GetNetAddress();
    std::string address = getNetAddress.GetAddress();
    EXPECT_STREQ(IP_ADDRESS, address.data());
    int port = getNetAddress.GetPort();
    EXPECT_EQ(port, PORT);
    netAddress.SetFamilyBySaFamily(AF_INET6);
    sa_family_t getFamily = netAddress.GetSaFamily();
    EXPECT_EQ(getFamily, AF_INET6);

    std::vector<std::string> alpnProtocols;
    alpnProtocols.push_back(ALPN_PROTOCOL);
    connectOptions.SetAlpnProtocols(alpnProtocols);
    std::vector<std::string> getAlpnProtocols;
    getAlpnProtocols = connectOptions.GetAlpnProtocols();
    // Stop here with a test failure instead of indexing an empty vector.
    ASSERT_FALSE(getAlpnProtocols.empty());
    EXPECT_STREQ(getAlpnProtocols[0].data(), alpnProtocols[0].data());
}
218 
/**
 * Populates secure options with key, CA chain and certificate, stores them in connect options,
 * and checks that a NetAddress configured with AF_INET reports AF_INET.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest4, TestSize.Level2)
{
    TLSSecureOptions security;
    SecureData privateKey(PRI_KEY_FILE);
    security.SetKey(privateKey);
    std::vector<std::string> trustedCas{CA_CRT_FILE};
    security.SetCaChain(trustedCas);
    security.SetCert(CLIENT_FILE);

    TLSConnectOptions options;
    options.SetTlsSecureOptions(security);

    Socket::NetAddress endpoint;
    endpoint.SetAddress(IP_ADDRESS);
    endpoint.SetPort(0);
    endpoint.SetFamilyBySaFamily(AF_INET);
    EXPECT_EQ(endpoint.GetSaFamily(), AF_INET);
}
238 
/**
 * Drives the public TLSSocket API with the options from BaseOption() while running under a
 * temporary access token. Each callback asserts the error code this test expects for a socket
 * whose connection attempt is not expected to succeed.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest5, TestSize.Level2)
{
    TLSConnectOptions tlsConnectOptions = BaseOption();

    AccessToken token;
    TLSSocket tlsSocket;
    tlsSocket.OnError(
        [](int32_t errorNumber, const std::string &errorString) { EXPECT_NE(TLSSOCKET_SUCCESS, errorNumber); });
    tlsSocket.Connect(tlsConnectOptions, [](int32_t errCode) { EXPECT_NE(TLSSOCKET_SUCCESS, errCode); });
    std::string getData;
    // Capture by value: the callback is handed to tlsSocket, which is destroyed after getData,
    // so a by-reference capture could be left dangling if the callback were ever invoked late.
    tlsSocket.OnMessage([getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) {
        EXPECT_STREQ(getData.data(), nullptr);
    });
    const std::string data = "how do you do?";
    Socket::TCPSendOptions tcpSendOptions;
    tcpSendOptions.SetData(data);
    tlsSocket.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    tlsSocket.GetSignatureAlgorithms(
        [](int32_t errCode, const std::vector<std::string> &algorithms) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    tlsSocket.GetCertificate(
        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_NE(errCode, TLSSOCKET_SUCCESS); });
    tlsSocket.GetCipherSuite(
        [](int32_t errCode, const std::vector<std::string> &suite) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    tlsSocket.GetProtocol([](int32_t errCode, const std::string &protocol) { EXPECT_EQ(errCode, TLSSOCKET_SUCCESS); });
    tlsSocket.GetRemoteCertificate(
        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    (void)tlsSocket.Close([](int32_t errCode) { EXPECT_FALSE(errCode == TLSSOCKET_SUCCESS); });
}
267 
/**
 * Exercises TLSSocketInternal Send/Recv/Close twice: first with no SSL object attached, then
 * with an unconnected SSL object injected through the (test-exposed) ssl_ member.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest6, TestSize.Level2)
{
    TLSConnectOptions connectOptions = BaseOption();

    TLSSocket tlsSocket;
    // Automatic storage: the object is destroyed on every exit path, including a failed ASSERT.
    TLSSocket::TLSSocketInternal tlsSocketInternal;
    bool isConnectToHost = tlsSocketInternal.TlsConnectToHost(SOCKET_FD, connectOptions, false);
    EXPECT_FALSE(isConnectToHost);
    tlsSocketInternal.SetTlsConfiguration(connectOptions);

    // Phase 1: no SSL object attached.
    bool sendSslNull = tlsSocketInternal.Send(SEND_DATA);
    EXPECT_FALSE(sendSslNull);
    char buffer[MAX_BUFFER_SIZE] = {0};
    int recvSslNull = tlsSocketInternal.Recv(buffer, MAX_BUFFER_SIZE);
    EXPECT_EQ(recvSslNull, SSL_ERROR_RETURN);
    bool closeSslNull = tlsSocketInternal.Close();
    EXPECT_FALSE(closeSslNull);

    // Phase 2: inject an SSL object that is not connected to anything.
    SSL_CTX *sslContext = SSL_CTX_new(TLS_client_method());
    ASSERT_NE(sslContext, nullptr);
    tlsSocketInternal.ssl_ = SSL_new(sslContext);
    // The SSL object keeps its own reference to the context, so drop the local one right away
    // instead of leaking it.
    SSL_CTX_free(sslContext);

    bool sendEmpty = tlsSocketInternal.Send(SEND_DATA_EMPTY);
    EXPECT_TRUE(sendEmpty);
    int recvResult = tlsSocketInternal.Recv(buffer, MAX_BUFFER_SIZE);
    EXPECT_EQ(recvResult, SSL_ERROR_RETURN);
    bool closeResult = tlsSocketInternal.Close();
    EXPECT_FALSE(closeResult);

    // The SSL object was injected by this test, so release whatever is still attached.
    // SSL_free(nullptr) does nothing, which covers the case where Close() already released
    // and cleared it. NOTE(review): assumes ssl_ is never left pointing at a freed object.
    SSL_free(tlsSocketInternal.ssl_);
    tlsSocketInternal.ssl_ = nullptr;
}
295 
/**
 * Exercises the TLSSocketInternal query helpers (ALPN, cipher suite, shared signature
 * algorithms, remote certificate, protocol, remote info) first with no SSL object attached and
 * then with an unconnected SSL object injected through the (test-exposed) ssl_ member.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest7, TestSize.Level2)
{
    TLSSocket tlsSocket;
    // Automatic storage: the object is destroyed on every exit path, including a failed ASSERT.
    TLSSocket::TLSSocketInternal tlsSocketInternal;

    std::vector<std::string> alpnProtocols;
    alpnProtocols.push_back(ALPN_PROTOCOL);

    // Phase 1: no SSL object attached.
    bool alpnProSslNull = tlsSocketInternal.SetAlpnProtocols(alpnProtocols);
    EXPECT_FALSE(alpnProSslNull);
    std::vector<std::string> getCipherSuite = tlsSocketInternal.GetCipherSuite();
    EXPECT_TRUE(getCipherSuite.empty());
    bool setSharedSigals = tlsSocketInternal.SetSharedSigals();
    EXPECT_FALSE(setSharedSigals);

    // Phase 2: inject an SSL object that is not connected to anything.
    SSL_CTX *sslContext = SSL_CTX_new(TLS_client_method());
    ASSERT_NE(sslContext, nullptr);
    tlsSocketInternal.ssl_ = SSL_new(sslContext);
    // The SSL object keeps its own reference to the context, so drop the local one right away
    // instead of leaking it.
    SSL_CTX_free(sslContext);

    getCipherSuite = tlsSocketInternal.GetCipherSuite();
    EXPECT_FALSE(getCipherSuite.empty());
    setSharedSigals = tlsSocketInternal.SetSharedSigals();
    EXPECT_FALSE(setSharedSigals);
    TLSConnectOptions connectOptions = BaseOption();
    bool alpnPro = tlsSocketInternal.SetAlpnProtocols(alpnProtocols);
    EXPECT_TRUE(alpnPro);

    Socket::SocketRemoteInfo remoteInfo;
    tlsSocketInternal.hostName_ = IP_ADDRESS;
    tlsSocketInternal.port_ = PORT;
    tlsSocketInternal.family_ = AF_INET;
    tlsSocketInternal.MakeRemoteInfo(remoteInfo);
    getCipherSuite = tlsSocketInternal.GetCipherSuite();
    EXPECT_FALSE(getCipherSuite.empty());

    std::string getRemoteCert = tlsSocketInternal.GetRemoteCertificate();
    EXPECT_EQ(getRemoteCert, "");

    std::vector<std::string> getSignatureAlgorithms = tlsSocketInternal.GetSignatureAlgorithms();
    EXPECT_TRUE(getSignatureAlgorithms.empty());

    std::string getProtocol = tlsSocketInternal.GetProtocol();
    EXPECT_NE(getProtocol, "");

    setSharedSigals = tlsSocketInternal.SetSharedSigals();
    EXPECT_FALSE(setSharedSigals);

    ssl_st *ssl = tlsSocketInternal.GetSSL();
    EXPECT_NE(ssl, nullptr);

    // The SSL object was injected by this test, so release it before the owner goes away.
    // Clearing the member afterwards keeps the owner from touching the freed object.
    // NOTE(review): assumes nothing above has already freed ssl_ without clearing it.
    SSL_free(tlsSocketInternal.ssl_);
    tlsSocketInternal.ssl_ = nullptr;
}
342 } // namespace TlsSocket
343 } // namespace NetStack
344 } // namespace OHOS
345