1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include <arpa/inet.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cerrno>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>

#include <securec.h>

#include "gtest/gtest.h"
#include "client/websocket_client.h"
#include "server/websocket_server.h"
24
25using namespace OHOS::ArkCompiler::Toolchain;
26
27namespace panda::test {
28class WebSocketTest : public testing::Test {
29public:
30    static void SetUpTestCase()
31    {
32        GTEST_LOG_(INFO) << "SetUpTestCase";
33        if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
34            GTEST_LOG_(ERROR) << "Reset SIGPIPE failed.";
35        }
36    }
37
38    static void TearDownTestCase()
39    {
40        GTEST_LOG_(INFO) << "TearDownCase";
41    }
42
43    void SetUp() override
44    {
45    }
46
47    void TearDown() override
48    {
49    }
50
51#if defined(OHOS_PLATFORM)
52    static constexpr char UNIX_DOMAIN_PATH_1[] = "server.sock_1";
53    static constexpr char UNIX_DOMAIN_PATH_2[] = "server.sock_2";
54    static constexpr char UNIX_DOMAIN_PATH_3[] = "server.sock_3";
55#endif
56    static constexpr char HELLO_SERVER[]    = "hello server";
57    static constexpr char HELLO_CLIENT[]    = "hello client";
58    static constexpr char SERVER_OK[]       = "server ok";
59    static constexpr char CLIENT_OK[]       = "client ok";
60    static constexpr char QUIT[]            = "quit";
61    static constexpr char PING[]            = "ping";
62    static constexpr int TCP_PORT           = 9230;
63    static const std::string LONG_MSG;
64    static const std::string LONG_LONG_MSG;
65};
66
67const std::string WebSocketTest::LONG_MSG       = std::string(1000, 'f');
68const std::string WebSocketTest::LONG_LONG_MSG  = std::string(0xfffff, 'f');
69
70HWTEST_F(WebSocketTest, ConnectWebSocketTest, testing::ext::TestSize.Level0)
71{
72    WebSocketServer serverSocket;
73    bool ret = false;
74#if defined(OHOS_PLATFORM)
75    int appPid = getpid();
76    ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH_1 + std::to_string(appPid), 5);
77#else
78    ret = serverSocket.InitTcpWebSocket(TCP_PORT, 5);
79#endif
80    ASSERT_TRUE(ret);
81    pid_t pid = fork();
82    if (pid == 0) {
83        // subprocess, handle client connect and recv/send message
84        // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess,
85        //       so testcase still success finally.
86        WebSocketClient clientSocket;
87        bool retClient = false;
88#if defined(OHOS_PLATFORM)
89        retClient = clientSocket.InitToolchainWebSocketForSockName(UNIX_DOMAIN_PATH_1 + std::to_string(appPid), 5);
90#else
91        retClient = clientSocket.InitToolchainWebSocketForPort(TCP_PORT, 5);
92#endif
93        ASSERT_TRUE(retClient);
94        retClient = clientSocket.ClientSendWSUpgradeReq();
95        ASSERT_TRUE(retClient);
96        retClient = clientSocket.ClientRecvWSUpgradeRsp();
97        ASSERT_TRUE(retClient);
98        retClient = clientSocket.SendReply(HELLO_SERVER);
99        EXPECT_TRUE(retClient);
100        std::string recv = clientSocket.Decode();
101        EXPECT_EQ(strcmp(recv.c_str(), HELLO_CLIENT), 0);
102        if (strcmp(recv.c_str(), HELLO_CLIENT) == 0) {
103            retClient = clientSocket.SendReply(CLIENT_OK);
104            EXPECT_TRUE(retClient);
105        }
106        retClient = clientSocket.SendReply(LONG_MSG);
107        EXPECT_TRUE(retClient);
108        recv = clientSocket.Decode();
109        EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0);
110        if (strcmp(recv.c_str(), SERVER_OK) == 0) {
111            retClient = clientSocket.SendReply(CLIENT_OK);
112            EXPECT_TRUE(retClient);
113        }
114        retClient = clientSocket.SendReply(LONG_LONG_MSG);
115        EXPECT_TRUE(retClient);
116        recv = clientSocket.Decode();
117        EXPECT_EQ(strcmp(recv.c_str(), SERVER_OK), 0);
118        if (strcmp(recv.c_str(), SERVER_OK) == 0) {
119            retClient = clientSocket.SendReply(CLIENT_OK);
120            EXPECT_TRUE(retClient);
121        }
122        retClient = clientSocket.SendReply(PING, FrameType::PING); // send a ping frame and wait for pong frame
123        EXPECT_TRUE(retClient);
124        recv = clientSocket.Decode(); // get the pong frame
125        EXPECT_EQ(strcmp(recv.c_str(), ""), 0); // pong frame has no data
126        retClient = clientSocket.SendReply(QUIT);
127        EXPECT_TRUE(retClient);
128        clientSocket.Close();
129        exit(0);
130    } else if (pid > 0) {
131        // mainprocess, handle server connect and recv/send message
132        auto openCallBack = []() -> void {
133            GTEST_LOG_(INFO) << "ConnectWebSocketTest connection is open.";
134        };
135        serverSocket.SetOpenConnectionCallback(openCallBack);
136
137        ret = serverSocket.AcceptNewConnection();
138        ASSERT_TRUE(ret);
139        std::string recv = serverSocket.Decode();
140        EXPECT_EQ(strcmp(recv.c_str(), HELLO_SERVER), 0);
141        serverSocket.SendReply(HELLO_CLIENT);
142        recv = serverSocket.Decode();
143        EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0);
144        recv = serverSocket.Decode();
145        EXPECT_EQ(strcmp(recv.c_str(), LONG_MSG.c_str()), 0);
146        serverSocket.SendReply(SERVER_OK);
147        recv = serverSocket.Decode();
148        EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0);
149        recv = serverSocket.Decode();
150        EXPECT_EQ(strcmp(recv.c_str(), LONG_LONG_MSG.c_str()), 0);
151        serverSocket.SendReply(SERVER_OK);
152        recv = serverSocket.Decode();
153        EXPECT_EQ(strcmp(recv.c_str(), CLIENT_OK), 0);
154        recv = serverSocket.Decode();
155        EXPECT_EQ(strcmp(recv.c_str(), PING), 0); // the ping frame has "PING" and send a pong frame
156        recv = serverSocket.Decode();
157        EXPECT_EQ(strcmp(recv.c_str(), QUIT), 0);
158        serverSocket.Close();
159    } else {
160        std::cerr << "ConnectWebSocketTest::fork failed, error = "
161                  << errno << ", desc = " << strerror(errno) << std::endl;
162    }
163}
164
165HWTEST_F(WebSocketTest, ReConnectWebSocketTest, testing::ext::TestSize.Level0)
166{
167    WebSocketServer serverSocket;
168    bool ret = false;
169#if defined(OHOS_PLATFORM)
170    int appPid = getpid();
171    ret = serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH_2 + std::to_string(appPid), 5);
172#else
173    ret = serverSocket.InitTcpWebSocket(TCP_PORT, 5);
174#endif
175    ASSERT_TRUE(ret);
176    for (int i = 0; i < 5; i++) {
177        pid_t pid = fork();
178        if (pid == 0) {
179            // subprocess, handle client connect and recv/send message
180            // note: EXPECT/ASSERT produce errors in subprocess that can not lead to failure of testcase in mainprocess,
181            //       so testcase still success finally.
182            WebSocketClient clientSocket;
183            bool retClient = false;
184#if defined(OHOS_PLATFORM)
185            retClient = clientSocket.InitToolchainWebSocketForSockName(UNIX_DOMAIN_PATH_2 + std::to_string(appPid), 5);
186#else
187            retClient = clientSocket.InitToolchainWebSocketForPort(TCP_PORT, 5);
188#endif
189            ASSERT_TRUE(retClient);
190            retClient = clientSocket.ClientSendWSUpgradeReq();
191            ASSERT_TRUE(retClient);
192            retClient = clientSocket.ClientRecvWSUpgradeRsp();
193            ASSERT_TRUE(retClient);
194            retClient = clientSocket.SendReply(HELLO_SERVER + std::to_string(i));
195            EXPECT_TRUE(retClient);
196            std::string recv = clientSocket.Decode();
197            EXPECT_EQ(strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()), 0);
198            if (strcmp(recv.c_str(), (HELLO_CLIENT + std::to_string(i)).c_str()) == 0) {
199                retClient = clientSocket.SendReply(CLIENT_OK + std::to_string(i));
200                EXPECT_TRUE(retClient);
201            }
202            clientSocket.Close();
203            exit(0);
204        } else if (pid > 0) {
205            // mainprocess, handle server connect and recv/send message
206            ret = serverSocket.AcceptNewConnection();
207            ASSERT_TRUE(ret);
208            std::string recv = serverSocket.Decode();
209            EXPECT_EQ(strcmp(recv.c_str(), (HELLO_SERVER + std::to_string(i)).c_str()), 0);
210            serverSocket.SendReply(HELLO_CLIENT + std::to_string(i));
211            recv = serverSocket.Decode();
212            EXPECT_EQ(strcmp(recv.c_str(), (CLIENT_OK + std::to_string(i)).c_str()), 0);
213            while (serverSocket.IsConnected()) {
214                serverSocket.Decode();
215            }
216        } else {
217            std::cerr << "ReConnectWebSocketTest::fork failed, error = "
218                      << errno << ", desc = " << strerror(errno) << std::endl;
219        }
220    }
221    serverSocket.Close();
222}
223
224HWTEST_F(WebSocketTest, ClientAbnormalTest, testing::ext::TestSize.Level0)
225{
226    WebSocketClient clientSocket;
227    ASSERT_STREQ(clientSocket.GetSocketStateString().c_str(), "closed");
228    ASSERT_FALSE(clientSocket.ClientSendWSUpgradeReq());
229    ASSERT_FALSE(clientSocket.ClientRecvWSUpgradeRsp());
230    ASSERT_FALSE(clientSocket.SendReply(HELLO_SERVER));
231}
232
233HWTEST_F(WebSocketTest, ServerAbnormalTest, testing::ext::TestSize.Level0)
234{
235    WebSocketServer serverSocket;
236    // No connection established, the function returns directly.
237    serverSocket.Close();
238    ASSERT_FALSE(serverSocket.AcceptNewConnection());
239
240#if defined(OHOS_PLATFORM)
241    int appPid = getpid();
242    ASSERT_TRUE(serverSocket.InitUnixWebSocket(UNIX_DOMAIN_PATH_3 + std::to_string(appPid), 5));
243#else
244    ASSERT_TRUE(serverSocket.InitTcpWebSocket(TCP_PORT, 5));
245#endif
246    pid_t pid = fork();
247    if (pid == 0) {
248        WebSocketClient clientSocket;
249        auto closeCallBack = []() -> void {
250            GTEST_LOG_(INFO) << "ServerAbnormalTest client connection is closed.";
251        };
252        clientSocket.SetCloseConnectionCallback(closeCallBack);
253
254#if defined(OHOS_PLATFORM)
255        ASSERT_TRUE(clientSocket.InitToolchainWebSocketForSockName(UNIX_DOMAIN_PATH_3 + std::to_string(appPid), 5));
256        // state is not UNITED, the function returns directly.
257        ASSERT_TRUE(clientSocket.InitToolchainWebSocketForSockName(UNIX_DOMAIN_PATH_3 + std::to_string(appPid), 5));
258#else
259        ASSERT_TRUE(clientSocket.InitToolchainWebSocketForPort(TCP_PORT, 5));
260        // state is not UNITED, the function returns directly.
261        ASSERT_TRUE(clientSocket.InitToolchainWebSocketForPort(TCP_PORT, 5));
262#endif
263        ASSERT_TRUE(clientSocket.ClientSendWSUpgradeReq());
264        ASSERT_FALSE(clientSocket.ClientRecvWSUpgradeRsp());
265        exit(0);
266    } else if (pid > 0) {
267        auto failCallBack = []() -> void {
268            GTEST_LOG_(INFO) << "ServerAbnormalTest server connection is failed.";
269        };
270        serverSocket.SetFailConnectionCallback(failCallBack);
271        auto notValidCallBack = [](const HttpRequest&) -> bool {
272            GTEST_LOG_(INFO) << "ServerAbnormalTest server connection request is not valid.";
273            return false;
274        };
275        serverSocket.SetValidateConnectionCallback(notValidCallBack);
276        ASSERT_FALSE(serverSocket.AcceptNewConnection());
277    } else {
278        std::cerr << "ServerAbnormalTest::fork failed, error = "
279                  << errno << ", desc = " << strerror(errno) << std::endl;
280    }
281}
282}  // namespace panda::test