1/* 2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include <fstream> 17#include <gtest/gtest.h> 18#include <iostream> 19#include <openssl/rsa.h> 20#include <openssl/ssl.h> 21#include <sstream> 22#include <string> 23#include <string_view> 24#include <unistd.h> 25#include <vector> 26 27#include "net_address.h" 28#include "secure_data.h" 29#include "socket_error.h" 30#include "socket_state_base.h" 31#include "tls.h" 32#include "tls_certificate.h" 33#include "tls_configuration.h" 34#include "tls_key.h" 35#include "tls_socket_server.h" 36#include "tls_socket.h" 37 38namespace OHOS { 39namespace NetStack { 40namespace TlsSocketServer { 41namespace { 42const std::string_view CA_DER = "/data/ClientCert/ca.crt"; 43const std::string_view IP_ADDRESS = "/data/Ip/address.txt"; 44const std::string_view PORT = "/data/Ip/port.txt"; 45 46inline bool CheckCaFileExistence(const char *function) 47{ 48 if (access(CA_DER.data(), 0)) { 49 std::cout << "CA file does not exist! (" << function << ")"; 50 return false; 51 } 52 return true; 53} 54 55std::string ChangeToFile(std::string_view fileName) 56{ 57 std::ifstream file; 58 file.open(fileName); 59 std::stringstream ss; 60 ss << file.rdbuf(); 61 std::string infos = ss.str(); 62 file.close(); 63 return infos; 64} 65 66 67std::string GetIp(std::string ip) 68{ 69 return ip.substr(0, ip.length() - 1); 70} 71 72} // namespace 73class TlsSocketServerTest : public testing::Test { 74public: 75 static void SetUpTestCase() {} 76 77 static void TearDownTestCase() {} 78 79 virtual void SetUp() {} 80 81 virtual void TearDown() {} 82}; 83 84HWTEST_F(TlsSocketServerTest, ListenInterface, testing::ext::TestSize.Level2) 85{ 86 if (!CheckCaFileExistence("ListenInterface")) { 87 return; 88 } 89 TLSSocketServer server; 90 TlsSocket::TLSConnectOptions tlsListenOptions; 91 92 server.Listen(tlsListenOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 93} 94 95HWTEST_F(TlsSocketServerTest, sendInterface, testing::ext::TestSize.Level2) 96{ 97 if (!CheckCaFileExistence("sendInterface")) { 98 return; 99 } 100 101 TLSSocketServer server; 102 103 TLSServerSendOptions tlsServerSendOptions; 104 105 const std::string data = "how do you do? this is sendInterface"; 106 tlsServerSendOptions.SetSendData(data); 107 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 108} 109 110HWTEST_F(TlsSocketServerTest, closeInterface, testing::ext::TestSize.Level2) 111{ 112 if (!CheckCaFileExistence("closeInterface")) { 113 return; 114 } 115 116 TLSSocketServer server; 117 118 const std::string data = "how do you do? this is closeInterface"; 119 TLSServerSendOptions tlsServerSendOptions; 120 tlsServerSendOptions.SetSendData(data); 121 int socketFd = tlsServerSendOptions.GetSocket(); 122 123 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 124 sleep(2); 125 126 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 127} 128 129HWTEST_F(TlsSocketServerTest, stopInterface, testing::ext::TestSize.Level2) 130{ 131 if (!CheckCaFileExistence("stopInterface")) { 132 return; 133 } 134 135 TLSSocketServer server; 136 137 TLSServerSendOptions tlsServerSendOptions; 138 int socketFd = tlsServerSendOptions.GetSocket(); 139 140 141 const std::string data = "how do you do? this is stopInterface"; 142 tlsServerSendOptions.SetSendData(data); 143 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 144 sleep(2); 145 146 147 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 148 sleep(2); 149 150 151 server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 152} 153 154HWTEST_F(TlsSocketServerTest, getRemoteAddressInterface, testing::ext::TestSize.Level2) 155{ 156 if (!CheckCaFileExistence("getRemoteAddressInterface")) { 157 return; 158 } 159 160 TLSSocketServer server; 161 162 TLSServerSendOptions tlsServerSendOptions; 163 int socketFd = tlsServerSendOptions.GetSocket(); 164 Socket::NetAddress address; 165 166 address.SetAddress(GetIp(ChangeToFile(IP_ADDRESS))); 167 address.SetPort(std::atoi(ChangeToFile(PORT).c_str())); 168 address.SetFamilyBySaFamily(AF_INET); 169 170 Socket::NetAddress netAddress; 171 server.GetRemoteAddress(socketFd, [&netAddress](int32_t errCode, 172 const Socket::NetAddress &address) { 173 EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); 174 netAddress.SetAddress(address.GetAddress()); 175 netAddress.SetPort(address.GetPort()); 176 netAddress.SetFamilyBySaFamily(address.GetSaFamily()); 177 }); 178 179 const std::string data = "how do you do? this is getRemoteAddressInterface"; 180 tlsServerSendOptions.SetSendData(data); 181 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 182 sleep(2); 183 184 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 185 sleep(2); 186 187 server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 188} 189 190HWTEST_F(TlsSocketServerTest, getRemoteCertificateInterface, testing::ext::TestSize.Level2) 191{ 192 if (!CheckCaFileExistence("getRemoteCertificateInterface")) { 193 return; 194 } 195 196 TLSSocketServer server; 197 198 TLSServerSendOptions tlsServerSendOptions; 199 int socketFd = tlsServerSendOptions.GetSocket(); 200 201 202 const std::string data = "how do you do? This is UT test getRemoteCertificateInterface"; 203 tlsServerSendOptions.SetSendData(data); 204 server.Send(tlsServerSendOptions, [](int32_t errCode) { 205 EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 206 sleep(2); 207 208 server.GetRemoteCertificate(socketFd, [](int32_t errCode, const TlsSocket::X509CertRawData &cert) { 209 EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 210 211 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 212 sleep(2); 213 214 server.Stop([](int32_t errCode) { EXPECT_TRUE(errCode == TlsSocket::TLSSOCKET_SUCCESS); }); 215} 216 217HWTEST_F(TlsSocketServerTest, getCertificateInterface, testing::ext::TestSize.Level2) 218{ 219 if (!CheckCaFileExistence("getCertificateInterface")) { 220 return; 221 } 222 TLSSocketServer server; 223 224 const std::string data = "how do you do? This is UT test getCertificateInterface"; 225 TLSServerSendOptions tlsServerSendOptions; 226 tlsServerSendOptions.SetSendData(data); 227 int socketFd = tlsServerSendOptions.GetSocket(); 228 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 229 230 server.GetCertificate( 231 [](int32_t errCode, const TlsSocket::X509CertRawData &cert) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 232 233 sleep(2); 234 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 235} 236 237HWTEST_F(TlsSocketServerTest, protocolInterface, testing::ext::TestSize.Level2) 238{ 239 if (!CheckCaFileExistence("protocolInterface")) { 240 return; 241 } 242 TLSSocketServer server; 243 244 const std::string data = "how do you do? this is protocolInterface"; 245 TLSServerSendOptions tlsServerSendOptions; 246 tlsServerSendOptions.SetSendData(data); 247 248 int socketFd = tlsServerSendOptions.GetSocket(); 249 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 250 std::string getProtocolVal; 251 server.GetProtocol([&getProtocolVal](int32_t errCode, const std::string &protocol) { 252 EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); 253 getProtocolVal = protocol; 254 }); 255 EXPECT_STREQ(getProtocolVal.c_str(), "TLSv1.3"); 256 257 Socket::SocketStateBase stateBase; 258 server.GetState([&stateBase](int32_t errCode, Socket::SocketStateBase state) { 259 if (TlsSocket::TLSSOCKET_SUCCESS) { 260 EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); 261 stateBase.SetIsBound(state.IsBound()); 262 stateBase.SetIsClose(state.IsClose()); 263 stateBase.SetIsConnected(state.IsConnected()); 264 } 265 }); 266 EXPECT_TRUE(stateBase.IsConnected()); 267 sleep(2); 268 269 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 270} 271 272HWTEST_F(TlsSocketServerTest, getSignatureAlgorithmsInterface, testing::ext::TestSize.Level2) 273{ 274 if (!CheckCaFileExistence("getSignatureAlgorithmsInterface")) { 275 return; 276 } 277 278 TLSSocketServer server; 279 TlsSocket::TLSSecureOptions secureOption; 280 281 const std::string data = "how do you do? this is getSigntureAlgorithmsInterface"; 282 TLSServerSendOptions tlsServerSendOptions; 283 tlsServerSendOptions.SetSendData(data); 284 285 int socketFd = tlsServerSendOptions.GetSocket(); 286 server.Send(tlsServerSendOptions, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 287 sleep(2); 288 289 bool testFlag = false; 290 std::string signatureAlgorithmVec = {"rsa_pss_rsae_sha256:ECDSA+SHA256"}; 291 secureOption.SetSignatureAlgorithms(signatureAlgorithmVec); 292 std::vector<std::string> testSignatureAlgorithms; 293 server.GetSignatureAlgorithms(socketFd, [&testSignatureAlgorithms](int32_t errCode, 294 const std::vector<std::string> &algorithms) { 295 if (errCode == TlsSocket::TLSSOCKET_SUCCESS) { 296 testSignatureAlgorithms = algorithms; 297 } 298 }); 299 for (auto const &iter : testSignatureAlgorithms) { 300 if (iter == "ECDSA+SHA256") { 301 testFlag = true; 302 } 303 } 304 EXPECT_TRUE(testFlag); 305 sleep(2); 306 307 308 (void)server.Close(socketFd, [](int32_t errCode) { EXPECT_TRUE(TlsSocket::TLSSOCKET_SUCCESS); }); 309} 310 311 312} //TlsSocketServer 313} //NetStack 314} //OHOS 315