1/* 2 * Copyright (c) 2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include <gtest/gtest.h> 17#include <iostream> 18 19#include <openssl/ssl.h> 20 21#define private public 22#include "accesstoken_kit.h" 23#include "tls_socket.h" 24#include "socket_remote_info.h" 25#include "token_setproc.h" 26#include "tls.h" 27#include "TlsTest.h" 28 29namespace OHOS { 30namespace NetStack { 31namespace TlsSocket { 32namespace { 33using namespace testing::ext; 34using namespace Security::AccessToken; 35using Security::AccessToken::AccessTokenID; 36static constexpr const char *KEY_PASS = ""; 37static constexpr const char *PROTOCOL12 = "TLSv1.2"; 38static constexpr const char *PROTOCOL13 = "TLSv1.3"; 39static constexpr const char *IP_ADDRESS = "127.0.0.1"; 40static constexpr const char *ALPN_PROTOCOL = "http/1.1"; 41static constexpr const char *SIGNATURE_ALGORITHM = "rsa_pss_rsae_sha256:ECDSA+SHA256"; 42static constexpr const char *CIPHER_SUITE = "AES256-SHA256"; 43static constexpr const char *SEND_DATA = "How do you do"; 44static constexpr const char *SEND_DATA_EMPTY = ""; 45static constexpr const size_t MAX_BUFFER_SIZE = 8192; 46const int PORT = 7838; 47const int SOCKET_FD = 5; 48const int SSL_ERROR_RETURN = -1; 49 50TLSConnectOptions BaseOption() 51{ 52 TLSSecureOptions secureOption; 53 SecureData structureData(PRI_KEY_FILE); 54 secureOption.SetKey(structureData); 55 std::vector<std::string> caChain; 56 caChain.push_back(CA_CRT_FILE); 57 secureOption.SetCaChain(caChain); 58 secureOption.SetCert(CLIENT_FILE); 59 secureOption.SetCipherSuite(CIPHER_SUITE); 60 secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM); 61 std::vector<std::string> protocol; 62 protocol.push_back(PROTOCOL13); 63 secureOption.SetProtocolChain(protocol); 64 65 TLSConnectOptions connectOptions; 66 connectOptions.SetTlsSecureOptions(secureOption); 67 Socket::NetAddress netAddress; 68 netAddress.SetAddress(IP_ADDRESS); 69 netAddress.SetPort(0); 70 netAddress.SetFamilyBySaFamily(AF_INET); 71 connectOptions.SetNetAddress(netAddress); 72 std::vector<std::string> alpnProtocols; 73 alpnProtocols.push_back(ALPN_PROTOCOL); 74 connectOptions.SetAlpnProtocols(alpnProtocols); 75 return connectOptions; 76} 77 78HapInfoParams testInfoParms = {.bundleName = "TlsSocketBranchTest", 79 .userID = 1, 80 .instIndex = 0, 81 .appIDDesc = "test", 82 .isSystemApp = true}; 83 84PermissionDef testPermDef = { 85 .permissionName = "ohos.permission.INTERNET", 86 .bundleName = "TlsSocketBranchTest", 87 .grantMode = 1, 88 .label = "label", 89 .labelId = 1, 90 .description = "Test Tls Socket Branch", 91 .descriptionId = 1, 92 .availableLevel = APL_SYSTEM_BASIC, 93}; 94 95PermissionStateFull testState = { 96 .grantFlags = {2}, 97 .grantStatus = {PermissionState::PERMISSION_GRANTED}, 98 .isGeneral = true, 99 .permissionName = "ohos.permission.INTERNET", 100 .resDeviceID = {"local"}, 101}; 102 103HapPolicyParams testPolicyPrams = { 104 .apl = APL_SYSTEM_BASIC, 105 .domain = "test.domain", 106 .permList = {testPermDef}, 107 .permStateList = {testState}, 108}; 109} // namespace 110 111class AccessToken { 112public: 113 AccessToken() : currentID_(GetSelfTokenID()) 114 { 115 AccessTokenIDEx tokenIdEx = AccessTokenKit::AllocHapToken(testInfoParms, testPolicyPrams); 116 accessID_ = tokenIdEx.tokenIdExStruct.tokenID; 117 SetSelfTokenID(tokenIdEx.tokenIDEx); 118 } 119 ~AccessToken() 120 { 121 AccessTokenKit::DeleteToken(accessID_); 122 SetSelfTokenID(currentID_); 123 } 124 125private: 126 AccessTokenID currentID_; 127 AccessTokenID accessID_ = 0; 128}; 129 130class TlsSocketBranchTest : public testing::Test { 131public: 132 static void SetUpTestCase() {} 133 134 static void TearDownTestCase() {} 135 136 virtual void SetUp() {} 137 138 virtual void TearDown() {} 139}; 140 141HWTEST_F(TlsSocketBranchTest, BranchTest1, TestSize.Level2) 142{ 143 TLSSecureOptions secureOption; 144 SecureData structureData(PRI_KEY_FILE); 145 secureOption.SetKey(structureData); 146 147 SecureData keyPass(KEY_PASS); 148 secureOption.SetKeyPass(keyPass); 149 SecureData secureData = secureOption.GetKey(); 150 EXPECT_EQ(structureData.Length(), strlen(PRI_KEY_FILE)); 151 std::vector<std::string> caChain; 152 caChain.push_back(CA_CRT_FILE); 153 secureOption.SetCaChain(caChain); 154 std::vector<std::string> getCaChain = secureOption.GetCaChain(); 155 EXPECT_NE(getCaChain.data(), nullptr); 156 157 secureOption.SetCert(CLIENT_FILE); 158 std::string getCert = secureOption.GetCert(); 159 EXPECT_NE(getCert.data(), nullptr); 160 161 std::vector<std::string> protocolVec = {PROTOCOL12, PROTOCOL13}; 162 secureOption.SetProtocolChain(protocolVec); 163 std::vector<std::string> getProtocol; 164 getProtocol = secureOption.GetProtocolChain(); 165 166 TLSSecureOptions copyOption = TLSSecureOptions(secureOption); 167 TLSSecureOptions equalOption = secureOption; 168} 169 170HWTEST_F(TlsSocketBranchTest, BranchTest2, TestSize.Level2) 171{ 172 TLSSecureOptions secureOption; 173 secureOption.SetUseRemoteCipherPrefer(false); 174 bool isUseRemoteCipher = secureOption.UseRemoteCipherPrefer(); 175 EXPECT_FALSE(isUseRemoteCipher); 176 177 secureOption.SetSignatureAlgorithms(SIGNATURE_ALGORITHM); 178 std::string getSignatureAlgorithm = secureOption.GetSignatureAlgorithms(); 179 EXPECT_STREQ(getSignatureAlgorithm.data(), SIGNATURE_ALGORITHM); 180 181 secureOption.SetCipherSuite(CIPHER_SUITE); 182 std::string getCipherSuite = secureOption.GetCipherSuite(); 183 EXPECT_STREQ(getCipherSuite.data(), CIPHER_SUITE); 184 185 TLSSecureOptions copyOption = TLSSecureOptions(secureOption); 186 TLSSecureOptions equalOption = secureOption; 187 188 TLSConnectOptions connectOptions; 189 connectOptions.SetTlsSecureOptions(secureOption); 190} 191 192HWTEST_F(TlsSocketBranchTest, BranchTest3, TestSize.Level2) 193{ 194 TLSSecureOptions secureOption; 195 TLSConnectOptions connectOptions; 196 connectOptions.SetTlsSecureOptions(secureOption); 197 198 Socket::NetAddress netAddress; 199 netAddress.SetAddress(IP_ADDRESS); 200 netAddress.SetPort(PORT); 201 connectOptions.SetNetAddress(netAddress); 202 Socket::NetAddress getNetAddress = connectOptions.GetNetAddress(); 203 std::string address = getNetAddress.GetAddress(); 204 EXPECT_STREQ(IP_ADDRESS, address.data()); 205 int port = getNetAddress.GetPort(); 206 EXPECT_EQ(port, PORT); 207 netAddress.SetFamilyBySaFamily(AF_INET6); 208 sa_family_t getFamily = netAddress.GetSaFamily(); 209 EXPECT_EQ(getFamily, AF_INET6); 210 211 std::vector<std::string> alpnProtocols; 212 alpnProtocols.push_back(ALPN_PROTOCOL); 213 connectOptions.SetAlpnProtocols(alpnProtocols); 214 std::vector<std::string> getAlpnProtocols; 215 getAlpnProtocols = connectOptions.GetAlpnProtocols(); 216 EXPECT_STREQ(getAlpnProtocols[0].data(), alpnProtocols[0].data()); 217} 218 219HWTEST_F(TlsSocketBranchTest, BranchTest4, TestSize.Level2) 220{ 221 TLSSecureOptions secureOption; 222 SecureData structureData(PRI_KEY_FILE); 223 secureOption.SetKey(structureData); 224 std::vector<std::string> caChain; 225 caChain.push_back(CA_CRT_FILE); 226 secureOption.SetCaChain(caChain); 227 secureOption.SetCert(CLIENT_FILE); 228 229 TLSConnectOptions connectOptions; 230 connectOptions.SetTlsSecureOptions(secureOption); 231 232 Socket::NetAddress netAddress; 233 netAddress.SetAddress(IP_ADDRESS); 234 netAddress.SetPort(0); 235 netAddress.SetFamilyBySaFamily(AF_INET); 236 EXPECT_EQ(netAddress.GetSaFamily(), AF_INET); 237} 238 239HWTEST_F(TlsSocketBranchTest, BranchTest5, TestSize.Level2) 240{ 241 TLSConnectOptions tlsConnectOptions = BaseOption(); 242 243 AccessToken token; 244 TLSSocket tlsSocket; 245 tlsSocket.OnError( 246 [](int32_t errorNumber, const std::string &errorString) { EXPECT_NE(TLSSOCKET_SUCCESS, errorNumber); }); 247 tlsSocket.Connect(tlsConnectOptions, [](int32_t errCode) { EXPECT_NE(TLSSOCKET_SUCCESS, errCode); }); 248 std::string getData; 249 tlsSocket.OnMessage([&getData](const std::string &data, const Socket::SocketRemoteInfo &remoteInfo) { 250 EXPECT_STREQ(getData.data(), nullptr); 251 }); 252 const std::string data = "how do you do?"; 253 Socket::TCPSendOptions tcpSendOptions; 254 tcpSendOptions.SetData(data); 255 tlsSocket.Send(tcpSendOptions, [](int32_t errCode) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); }); 256 tlsSocket.GetSignatureAlgorithms( 257 [](int32_t errCode, const std::vector<std::string> &algorithms) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); }); 258 tlsSocket.GetCertificate( 259 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_NE(errCode, TLSSOCKET_SUCCESS); }); 260 tlsSocket.GetCipherSuite( 261 [](int32_t errCode, const std::vector<std::string> &suite) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); }); 262 tlsSocket.GetProtocol([](int32_t errCode, const std::string &protocol) { EXPECT_EQ(errCode, TLSSOCKET_SUCCESS); }); 263 tlsSocket.GetRemoteCertificate( 264 [](int32_t errCode, const X509CertRawData &cert) { EXPECT_EQ(errCode, TLS_ERR_SSL_NULL); }); 265 (void)tlsSocket.Close([](int32_t errCode) { EXPECT_FALSE(errCode == TLSSOCKET_SUCCESS); }); 266} 267 268HWTEST_F(TlsSocketBranchTest, BranchTest6, TestSize.Level2) 269{ 270 TLSConnectOptions connectOptions = BaseOption(); 271 272 TLSSocket tlsSocket; 273 TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal(); 274 bool isConnectToHost = tlsSocketInternal->TlsConnectToHost(SOCKET_FD, connectOptions, false); 275 EXPECT_FALSE(isConnectToHost); 276 tlsSocketInternal->SetTlsConfiguration(connectOptions); 277 278 bool sendSslNull = tlsSocketInternal->Send(SEND_DATA); 279 EXPECT_FALSE(sendSslNull); 280 char buffer[MAX_BUFFER_SIZE]; 281 bzero(buffer, MAX_BUFFER_SIZE); 282 int recvSslNull = tlsSocketInternal->Recv(buffer, MAX_BUFFER_SIZE); 283 EXPECT_EQ(recvSslNull, SSL_ERROR_RETURN); 284 bool closeSslNull = tlsSocketInternal->Close(); 285 EXPECT_FALSE(closeSslNull); 286 tlsSocketInternal->ssl_ = SSL_new(SSL_CTX_new(TLS_client_method())); 287 bool sendEmpty = tlsSocketInternal->Send(SEND_DATA_EMPTY); 288 EXPECT_FALSE(sendEmpty); 289 int recv = tlsSocketInternal->Recv(buffer, MAX_BUFFER_SIZE); 290 EXPECT_EQ(recv, SSL_ERROR_RETURN); 291 bool close = tlsSocketInternal->Close(); 292 EXPECT_FALSE(close); 293 delete tlsSocketInternal; 294} 295 296HWTEST_F(TlsSocketBranchTest, BranchTest7, TestSize.Level2) 297{ 298 TLSSocket tlsSocket; 299 TLSSocket::TLSSocketInternal *tlsSocketInternal = new TLSSocket::TLSSocketInternal(); 300 301 std::vector<std::string> alpnProtocols; 302 alpnProtocols.push_back(ALPN_PROTOCOL); 303 bool alpnProSslNull = tlsSocketInternal->SetAlpnProtocols(alpnProtocols); 304 EXPECT_FALSE(alpnProSslNull); 305 std::vector<std::string> getCipherSuite = tlsSocketInternal->GetCipherSuite(); 306 EXPECT_EQ(getCipherSuite.size(), 0); 307 bool setSharedSigals = tlsSocketInternal->SetSharedSigals(); 308 EXPECT_FALSE(setSharedSigals); 309 tlsSocketInternal->ssl_ = SSL_new(SSL_CTX_new(TLS_client_method())); 310 getCipherSuite = tlsSocketInternal->GetCipherSuite(); 311 EXPECT_NE(getCipherSuite.size(), 0); 312 setSharedSigals = tlsSocketInternal->SetSharedSigals(); 313 EXPECT_FALSE(setSharedSigals); 314 TLSConnectOptions connectOptions = BaseOption(); 315 bool alpnPro = tlsSocketInternal->SetAlpnProtocols(alpnProtocols); 316 EXPECT_TRUE(alpnPro); 317 318 Socket::SocketRemoteInfo remoteInfo; 319 tlsSocketInternal->hostName_ = IP_ADDRESS; 320 tlsSocketInternal->port_ = PORT; 321 tlsSocketInternal->family_ = AF_INET; 322 tlsSocketInternal->MakeRemoteInfo(remoteInfo); 323 getCipherSuite = tlsSocketInternal->GetCipherSuite(); 324 EXPECT_NE(getCipherSuite.size(), 0); 325 326 std::string getRemoteCert = tlsSocketInternal->GetRemoteCertificate(); 327 EXPECT_EQ(getRemoteCert, ""); 328 329 std::vector<std::string> getSignatureAlgorithms = tlsSocketInternal->GetSignatureAlgorithms(); 330 EXPECT_EQ(getSignatureAlgorithms.size(), 0); 331 332 std::string getProtocol = tlsSocketInternal->GetProtocol(); 333 EXPECT_NE(getProtocol, ""); 334 335 setSharedSigals = tlsSocketInternal->SetSharedSigals(); 336 EXPECT_FALSE(setSharedSigals); 337 338 ssl_st *ssl = tlsSocketInternal->GetSSL(); 339 EXPECT_NE(ssl, nullptr); 340 delete tlsSocketInternal; 341} 342} // namespace TlsSocket 343} // namespace NetStack 344} // namespace OHOS 345