1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>

#include <openssl/ssl.h>

// Expose private members of the TLS socket classes (ssl_, hostName_, ...) to these tests.
#define private public
#include "accesstoken_kit.h"
#include "tls_socket.h"
#include "socket_remote_info.h"
#include "token_setproc.h"
#include "tls.h"
#include "TlsTest.h"
28
29 namespace OHOS {
30 namespace NetStack {
31 namespace TlsSocket {
32 namespace {
33 using namespace testing::ext;
34 using namespace Security::AccessToken;
35 using Security::AccessToken::AccessTokenID;
// Inputs shared by the branch tests below.
// Strings handed to the TLS option setters.
constexpr const char *KEY_PASS = "";
constexpr const char *PROTOCOL12 = "TLSv1.2";
constexpr const char *PROTOCOL13 = "TLSv1.3";
constexpr const char *IP_ADDRESS = "127.0.0.1";
constexpr const char *ALPN_PROTOCOL = "http/1.1";
constexpr const char *SIGNATURE_ALGORITHM = "rsa_pss_rsae_sha256:ECDSA+SHA256";
constexpr const char *CIPHER_SUITE = "AES256-SHA256";
// Payloads for the internal Send() calls.
constexpr const char *SEND_DATA = "How do you do";
constexpr const char *SEND_DATA_EMPTY = "";
// Numeric inputs and the Recv() value the tests expect on failure.
constexpr size_t MAX_BUFFER_SIZE = 8192;
constexpr int PORT = 7838;
constexpr int SOCKET_FD = 5;
constexpr int SSL_ERROR_RETURN = -1;
49
BaseOption()50 TLSConnectOptions BaseOption()
51 {
52 TLSSecureOptions secureOption;
53 SecureData structureData(PRI_KEY_FILE);
54 secureOption.SetKey(structureData);
55 std::vector<std::string> caChain;
56 caChain.push_back(CA_CRT_FILE);
57 secureOption.SetCaChain(caChain);
58 secureOption.SetCert(CLIENT_FILE);
59 secureOption.SetCipherSuite(CIPHER_SUITE);
60 secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
61 std::vector<std::string> protocol;
62 protocol.push_back(PROTOCOL13);
63 secureOption.SetProtocolChain(protocol);
64
65 TLSConnectOptions connectOptions;
66 connectOptions.SetTlsSecureOptions(secureOption);
67 Socket::NetAddress netAddress;
68 netAddress.SetAddress(IP_ADDRESS);
69 netAddress.SetPort(0);
70 netAddress.SetFamilyBySaFamily(AF_INET);
71 connectOptions.SetNetAddress(netAddress);
72 std::vector<std::string> alpnProtocols;
73 alpnProtocols.push_back(ALPN_PROTOCOL);
74 connectOptions.SetAlpnProtocols(alpnProtocols);
75 return connectOptions;
76 }
77
// HAP identity that AccessToken passes to AccessTokenKit::AllocHapToken() to
// create the temporary token the socket tests run under.
HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest",
                               .userID = 1,
                               .instIndex = 0,
                               .appIDDesc = "test",
                               .isSystemApp = true};

// Declares the ohos.permission.INTERNET permission for the test bundle.
PermissionDef testPermDef = {
    .permissionName = "ohos.permission.INTERNET",
    .bundleName = "TlsSocketBranchTest",
    .grantMode = 1,
    .label = "label",
    .labelId = 1,
    .description = "Test Tls Socket Branch",
    .descriptionId = 1,
    .availableLevel = APL_SYSTEM_BASIC,
};

// Grant state for ohos.permission.INTERNET: PERMISSION_GRANTED on the "local" device.
// NOTE(review): the meaning of grantFlags value 2 is defined by the access-token
// SDK — confirm against its headers.
PermissionStateFull testState = {
    .grantFlags = {2},
    .grantStatus = {PermissionState::PERMISSION_GRANTED},
    .isGeneral = true,
    .permissionName = "ohos.permission.INTERNET",
    .resDeviceID = {"local"},
};

// Policy passed alongside testInfoParms to AllocHapToken(): the INTERNET
// permission definition and its granted state, at APL_SYSTEM_BASIC.
HapPolicyParams testPolicyPrams = {
    .apl = APL_SYSTEM_BASIC,
    .domain = "test.domain",
    .permList = {testPermDef},
    .permStateList = {testState},
};
109 } // namespace
110
111 class AccessToken {
112 public:
AccessToken()113 AccessToken() : currentID_(GetSelfTokenID())
114 {
115 AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams);
116 accessID_ = tokenIdEx.tokenIdExStruct.tokenID;
117 SetSelfTokenID(tokenIdEx.tokenIDEx);
118 }
~AccessToken()119 ~AccessToken()
120 {
121 AccessTokenKit::DeleteToken(accessID_);
122 SetSelfTokenID(currentID_);
123 }
124
125 private:
126 AccessTokenID currentID_;
127 AccessTokenID accessID_ = 0;
128 };
129
130 class TlsSocketBranchTest : public testing::Test {
131 public:
SetUpTestCase()132 static void SetUpTestCase() {}
133
TearDownTestCase()134 static void TearDownTestCase() {}
135
SetUp()136 virtual void SetUp() {}
137
TearDown()138 virtual void TearDown() {}
139 };
140
/**
 * Round-trip the key, key password, CA chain, certificate and protocol chain
 * through TLSSecureOptions, then exercise its copy constructor and copy assignment.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest1, TestSize.Level2)
{
    TLSSecureOptions secureOption;
    SecureData structureData(PRI_KEY_FILE);
    secureOption.SetKey(structureData);

    SecureData keyPass(KEY_PASS);
    secureOption.SetKeyPass(keyPass);
    // Check the value read back through the getter, not the local passed in.
    SecureData secureData = secureOption.GetKey();
    EXPECT_EQ(secureData.Length(), strlen(PRI_KEY_FILE));

    std::vector<std::string> caChain;
    caChain.push_back(CA_CRT_FILE);
    secureOption.SetCaChain(caChain);
    EXPECT_EQ(secureOption.GetCaChain(), caChain);

    secureOption.SetCert(CLIENT_FILE);
    EXPECT_EQ(secureOption.GetCert(), CLIENT_FILE);

    std::vector<std::string> protocolVec = {PROTOCOL12, PROTOCOL13};
    secureOption.SetProtocolChain(protocolVec);
    EXPECT_EQ(secureOption.GetProtocolChain(), protocolVec);

    TLSSecureOptions copyOption(secureOption);
    EXPECT_EQ(copyOption.GetCert(), secureOption.GetCert());
    EXPECT_EQ(copyOption.GetCaChain(), secureOption.GetCaChain());

    // Default-construct then assign so operator= is really exercised;
    // `T x = y;` would call the copy constructor instead.
    TLSSecureOptions equalOption;
    equalOption = secureOption;
    EXPECT_EQ(equalOption.GetCert(), secureOption.GetCert());
    EXPECT_EQ(equalOption.GetProtocolChain(), secureOption.GetProtocolChain());
}
169
/**
 * Round-trip the remote-cipher preference, signature algorithms and cipher suite
 * through TLSSecureOptions, and check that copies (via copy constructor and copy
 * assignment) carry them over.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest2, TestSize.Level2)
{
    TLSSecureOptions secureOption;
    secureOption.SetUseRemoteCipherPrefer(false);
    EXPECT_FALSE(secureOption.UseRemoteCipherPrefer());

    secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM);
    EXPECT_EQ(secureOption.GetSignatureAlgorithms(), SIGNATURE_ALGORITHM);

    secureOption.SetCipherSuite(CIPHER_SUITE);
    EXPECT_EQ(secureOption.GetCipherSuite(), CIPHER_SUITE);

    TLSSecureOptions copyOption(secureOption);
    EXPECT_EQ(copyOption.GetSignatureAlgorithms(), SIGNATURE_ALGORITHM);
    EXPECT_EQ(copyOption.GetCipherSuite(), CIPHER_SUITE);

    // Default-construct then assign so operator= is really exercised.
    TLSSecureOptions equalOption;
    equalOption = secureOption;
    EXPECT_EQ(equalOption.GetSignatureAlgorithms(), SIGNATURE_ALGORITHM);
    EXPECT_EQ(equalOption.GetCipherSuite(), CIPHER_SUITE);
    EXPECT_FALSE(equalOption.UseRemoteCipherPrefer());

    TLSConnectOptions connectOptions;
    connectOptions.SetTlsSecureOptions(secureOption);
}
191
/**
 * Round-trip the remote address and ALPN protocol list through TLSConnectOptions,
 * and check NetAddress::SetFamilyBySaFamily() / GetSaFamily().
 */
HWTEST_F(TlsSocketBranchTest, BranchTest3, TestSize.Level2)
{
    TLSSecureOptions secureOption;
    TLSConnectOptions connectOptions;
    connectOptions.SetTlsSecureOptions(secureOption);

    Socket::NetAddress netAddress;
    netAddress.SetAddress(IP_ADDRESS);
    netAddress.SetPort(PORT);
    connectOptions.SetNetAddress(netAddress);
    Socket::NetAddress getNetAddress = connectOptions.GetNetAddress();
    EXPECT_EQ(getNetAddress.GetAddress(), IP_ADDRESS);
    int port = getNetAddress.GetPort();
    EXPECT_EQ(port, PORT);

    // Exercises the local NetAddress only; connectOptions holds its own copy.
    netAddress.SetFamilyBySaFamily(AF_INET6);
    EXPECT_EQ(netAddress.GetSaFamily(), AF_INET6);

    std::vector<std::string> alpnProtocols;
    alpnProtocols.push_back(ALPN_PROTOCOL);
    connectOptions.SetAlpnProtocols(alpnProtocols);
    // Compare whole vectors: indexing [0] would be UB if the getter returned empty.
    std::vector<std::string> getAlpnProtocols = connectOptions.GetAlpnProtocols();
    EXPECT_EQ(getAlpnProtocols, alpnProtocols);
}
218
/**
 * Build secure and connect options from the test key, CA chain and certificate,
 * and check that an AF_INET family set on a NetAddress reads back unchanged.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest4, TestSize.Level2)
{
    SecureData privateKey(PRI_KEY_FILE);
    std::vector<std::string> caList{CA_CRT_FILE};

    TLSSecureOptions options;
    options.SetKey(privateKey);
    options.SetCaChain(caList);
    options.SetCert(CLIENT_FILE);

    TLSConnectOptions connectOptions;
    connectOptions.SetTlsSecureOptions(options);

    Socket::NetAddress address;
    address.SetAddress(IP_ADDRESS);
    address.SetPort(0);
    address.SetFamilyBySaFamily(AF_INET);
    EXPECT_EQ(address.GetSaFamily(), AF_INET);
}
238
/**
 * Drive the public TLSSocket API against a connection the test expects to fail.
 * Each callback asserts the error code expected in that state.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest5, TestSize.Level2)
{
    TLSConnectOptions tlsConnectOptions = BaseOption();

    // Run under a temporary HAP token; the previous token is restored when `token` is destroyed.
    AccessToken token;
    TLSSocket tlsSocket;
    tlsSocket.OnError(
        [](int32_t errorNumber, const std::string &errorString) { EXPECT_NE(TLSSOCKET_SUCCESS, errorNumber); });
    // BaseOption() targets port 0 and this callback expects the connect to fail.
    tlsSocket.Connect(tlsConnectOptions, [](int32_t errCode) { EXPECT_NE(TLSSOCKET_SUCCESS, errCode); });
    // No message is expected, so any delivery fails the test. The lambda captures
    // nothing, so it cannot dangle even if invoked while tlsSocket is torn down.
    tlsSocket.OnMessage([](const std::string &data, const Socket::SocketRemoteInfo &) {
        ADD_FAILURE() << "unexpected message of " << data.size() << " bytes";
    });
    const std::string data = "how do you do?";
    Socket::TCPSendOptions tcpSendOptions;
    tcpSendOptions.SetData(data);
    tlsSocket.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    tlsSocket.GetSignatureAlgorithms(
        [](int32_t errCode, const std::vector<std::string> &algorithms) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    tlsSocket.GetCertificate(
        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_NE(errCode, TLSSOCKET_SUCCESS); });
    tlsSocket.GetCipherSuite(
        [](int32_t errCode, const std::vector<std::string> &suite) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    tlsSocket.GetProtocol([](int32_t errCode, const std::string &protocol) { EXPECT_EQ(errCode, TLSSOCKET_SUCCESS); });
    tlsSocket.GetRemoteCertificate(
        [](int32_t errCode, const X509CertRawData &cert) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); });
    (void)tlsSocket.Close([](int32_t errCode) { EXPECT_FALSE(errCode == TLSSOCKET_SUCCESS); });
}
267
/**
 * Exercise TLSSocketInternal in two states:
 *   - before any SSL object is attached, Send/Recv/Close fail;
 *   - after attaching an unconnected SSL object, an empty Send succeeds while
 *     Recv and Close still fail.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest6, TestSize.Level2)
{
    TLSConnectOptions connectOptions = BaseOption();

    // Automatic storage: destroyed even if an ASSERT_* below ends the test early.
    TLSSocket::TLSSocketInternal tlsSocketInternal;
    // NOTE(review): SOCKET_FD is a hard-coded descriptor number that may refer to
    // another open file in the test process — confirm the internal socket code
    // never shuts down or closes it.
    bool isConnectToHost = tlsSocketInternal.TlsConnectToHost(SOCKET_FD, connectOptions, false);
    EXPECT_FALSE(isConnectToHost);
    tlsSocketInternal.SetTlsConfiguration(connectOptions);

    bool sendSslNull = tlsSocketInternal.Send(SEND_DATA);
    EXPECT_FALSE(sendSslNull);
    char buffer[MAX_BUFFER_SIZE] = {0};
    int recvSslNull = tlsSocketInternal.Recv(buffer, MAX_BUFFER_SIZE);
    EXPECT_EQ(recvSslNull, SSL_ERROR_RETURN);
    bool closeSslNull = tlsSocketInternal.Close();
    EXPECT_FALSE(closeSslNull);

    // SSL_new() takes its own reference on the context, so release ours
    // immediately; otherwise the SSL_CTX outlives every SSL using it and leaks.
    SSL_CTX *ctx = SSL_CTX_new(TLS_client_method());
    ASSERT_NE(ctx, nullptr);
    tlsSocketInternal.ssl_ = SSL_new(ctx);
    SSL_CTX_free(ctx);

    bool sendEmpty = tlsSocketInternal.Send(SEND_DATA_EMPTY);
    EXPECT_TRUE(sendEmpty);
    int recv = tlsSocketInternal.Recv(buffer, MAX_BUFFER_SIZE);
    EXPECT_EQ(recv, SSL_ERROR_RETURN);
    bool close = tlsSocketInternal.Close();
    EXPECT_FALSE(close);
}
295
/**
 * Exercise the TLSSocketInternal accessors in two states:
 *   - without an SSL object: ALPN/sigalg setters fail and the cipher list is empty;
 *   - with an unconnected SSL object: the same calls are checked again, plus the
 *     remote-info, certificate, signature-algorithm and protocol getters.
 */
HWTEST_F(TlsSocketBranchTest, BranchTest7, TestSize.Level2)
{
    // Automatic storage: destroyed even if an ASSERT_* below ends the test early.
    TLSSocket::TLSSocketInternal tlsSocketInternal;

    std::vector<std::string> alpnProtocols;
    alpnProtocols.push_back(ALPN_PROTOCOL);
    bool alpnProSslNull = tlsSocketInternal.SetAlpnProtocols(alpnProtocols);
    EXPECT_FALSE(alpnProSslNull);
    std::vector<std::string> getCipherSuite = tlsSocketInternal.GetCipherSuite();
    EXPECT_TRUE(getCipherSuite.empty());
    bool setSharedSigals = tlsSocketInternal.SetSharedSigals();
    EXPECT_FALSE(setSharedSigals);

    // SSL_new() takes its own reference on the context, so release ours
    // immediately; otherwise the SSL_CTX leaks.
    SSL_CTX *ctx = SSL_CTX_new(TLS_client_method());
    ASSERT_NE(ctx, nullptr);
    tlsSocketInternal.ssl_ = SSL_new(ctx);
    SSL_CTX_free(ctx);
    ASSERT_NE(tlsSocketInternal.ssl_, nullptr);

    getCipherSuite = tlsSocketInternal.GetCipherSuite();
    EXPECT_FALSE(getCipherSuite.empty());
    setSharedSigals = tlsSocketInternal.SetSharedSigals();
    EXPECT_FALSE(setSharedSigals);
    bool alpnPro = tlsSocketInternal.SetAlpnProtocols(alpnProtocols);
    EXPECT_TRUE(alpnPro);

    Socket::SocketRemoteInfo remoteInfo;
    tlsSocketInternal.hostName_ = IP_ADDRESS;
    tlsSocketInternal.port_ = PORT;
    tlsSocketInternal.family_ = AF_INET;
    tlsSocketInternal.MakeRemoteInfo(remoteInfo);
    getCipherSuite = tlsSocketInternal.GetCipherSuite();
    EXPECT_FALSE(getCipherSuite.empty());

    std::string getRemoteCert = tlsSocketInternal.GetRemoteCertificate();
    EXPECT_TRUE(getRemoteCert.empty());

    std::vector<std::string> getSignatureAlgorithms = tlsSocketInternal.GetSignatureAlgorithms();
    EXPECT_TRUE(getSignatureAlgorithms.empty());

    std::string getProtocol = tlsSocketInternal.GetProtocol();
    EXPECT_FALSE(getProtocol.empty());

    setSharedSigals = tlsSocketInternal.SetSharedSigals();
    EXPECT_FALSE(setSharedSigals);

    ssl_st *ssl = tlsSocketInternal.GetSSL();
    EXPECT_NE(ssl, nullptr);
}
342 } // namespace TlsSocket
343 } // namespace NetStack
344 } // namespace OHOS
345