1// Copyright (c) 2023 Huawei Device Co., Ltd.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#[cfg(feature = "async")]
15mod async_utils;
16
17#[cfg(feature = "sync")]
18mod sync_utils;
19
20use tokio::runtime::Runtime;
21
/// Declares the handle type that `start_http_server!` returns for a given protocol.
///
/// * `HTTP;` brings tokio's mpsc `Receiver`/`Sender` into the calling scope and declares
///   `HttpHandle`: the bound port plus the start-up / shutdown signalling channels.
/// * `HTTPS;` declares `TlsHandle`, which only records the bound port.
macro_rules! define_service_handle {
    (HTTP;) => {
        use tokio::sync::mpsc::{Receiver, Sender};

        /// Handle to a plain-HTTP test server.
        pub struct HttpHandle {
            /// Local port the server is listening on.
            pub port: u16,
            /// Receives a message from the server once it is up and running.
            pub server_start: Receiver<()>,
            /// Lets the client tell the server that it is ready for it to shut down.
            pub client_shutdown: Sender<()>,
            /// Receives a message from the server after it has shut down.
            pub server_shutdown: Receiver<()>,
        }
    };
    (HTTPS;) => {
        /// Handle to a TLS test server.
        pub struct TlsHandle {
            /// Local port the server is listening on.
            pub port: u16,
        }
    };
}
49
/// Starts `ServerNum` test servers on `Runtime` and pushes each server's handle onto `Handles`.
///
/// Every server is created by `start_http_server!` inside a task spawned on the runtime, because
/// the server setup calls `tokio::spawn` and therefore needs a runtime context. The task returns
/// the handle as its output, and the caller thread collects it with `block_on` on the task's
/// `JoinHandle`. No side channel is needed for that.
///
/// For `HTTP` servers, the task also waits until the server reports that it is running before
/// handing the handle back.
///
/// # Panics
///
/// Panics with "Runtime start server coroutine failed" if the start-up task panics or is
/// cancelled. The `HTTP` form also panics if the server's start channel closes first.
#[macro_export]
macro_rules! start_server {
    (
        HTTPS;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            let start_task = $runtime.spawn(async move {
                let handle = start_http_server!(
                    HTTPS;
                    $service_fn
                );
                handle
            });
            // The task's output *is* the handle; a panic inside it surfaces as a JoinError.
            let handle = $runtime
                .block_on(start_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
    (
        HTTP;
        ServerNum: $server_num: expr,
        Runtime: $runtime: expr,
        Handles: $handle_vec: expr,
        ServeFnName: $service_fn: ident,
    ) => {{
        for _ in 0..$server_num {
            let start_task = $runtime.spawn(async move {
                let mut handle = start_http_server!(
                    HTTP;
                    $service_fn
                );
                // Do not return the handle until the server has signalled that it is serving.
                handle
                    .server_start
                    .recv()
                    .await
                    .expect("Start channel (Server-Half) be closed unexpectedly");
                handle
            });
            // The task's output *is* the handle; a panic inside it surfaces as a JoinError.
            let handle = $runtime
                .block_on(start_task)
                .expect("Runtime start server coroutine failed");
            $handle_vec.push(handle);
        }
    }};
}
110
/// Binds a test server to `127.0.0.1` on an OS-assigned port and spawns it with `tokio::spawn`.
/// It returns a handle that carries the chosen port.
///
/// * `HTTP; service_fn` serves connections with hyper using `service_fn` until the client
///   signals shutdown through the returned `HttpHandle`.
///   - The handle's `server_start` channel fires when hyper first polls the graceful-shutdown
///     future.
///   - `server_shutdown` fires after the server has stopped.
/// * `HTTPS; service_fn` accepts a single TCP connection and returns a `TlsHandle`.
///   - It performs the TLS handshake using `tests/file/key.pem` and `tests/file/cert.pem`,
///     with ALPN `http/1.1`.
///   - It then serves that one connection over HTTP/1.1 with `service_fn`.
///
/// Requirements:
/// - A tokio runtime context, because the macro calls `tokio::spawn`.
/// - The `HTTPS` form also `.await`s, so it must be expanded inside an async block.
/// - `HttpHandle` / `TlsHandle` must be in scope (see `define_service_handle!`).
///
/// # Panics
///
/// Panics if the listening socket cannot be bound or its local address cannot be read. The
/// `HTTP` form also panics if the hyper server cannot be built from the listener.
#[macro_export]
macro_rules! start_http_server {
    (
        HTTP;
        $service_fn: ident
    ) => {{
        use hyper::service::{make_service_fn, service_fn};
        use std::convert::Infallible;
        use tokio::sync::mpsc::channel;

        // Server -> client: "I am serving".
        let (start_tx, start_rx) = channel::<()>(1);
        // Client -> server: "you may shut down".
        let (client_tx, mut client_rx) = channel::<()>(1);
        // Server -> client: "I have shut down".
        let (server_tx, server_rx) = channel::<()>(1);

        // Port 0 lets the OS pick any free port; the real port is read back below.
        let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !");
        let addr = tcp_listener.local_addr().expect("get server local address failed!");
        let port = addr.port();
        let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !");

        tokio::spawn(async move {
            let make_svc =
                make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($service_fn)) });
            server
                .serve(make_svc)
                .with_graceful_shutdown(async {
                    // The shutdown future starts being polled once the server runs, so it
                    // doubles as the start-up notification.
                    start_tx
                        .send(())
                        .await
                        .expect("Start channel (Client-Half) be closed unexpectedly");
                    client_rx
                        .recv()
                        .await
                        .expect("Client channel (Client-Half) be closed unexpectedly");
                })
                .await
                .expect("Start server failed");
            server_tx
                .send(())
                .await
                .expect("Server channel (Client-Half) be closed unexpectedly");
        });

        HttpHandle {
            port,
            server_start: start_rx,
            client_shutdown: client_tx,
            server_shutdown: server_rx,
        }
    }};
    (
        HTTPS;
        $service_fn: ident
    ) => {{
        // Let the OS choose a free port, as the HTTP form does. The previous upward probe from
        // port 10000 silently retried on *every* bind error. If binding could never succeed
        // (no free port, 127.0.0.1 unavailable), it spun forever instead of reporting the
        // failure, and it never tried port 65535.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("server bind port failed !");
        let port = listener
            .local_addr()
            .expect("get server local address failed!")
            .port();

        tokio::spawn(async move {
            let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls())
                .expect("SslAcceptorBuilder error");
            acceptor
                .set_session_id_context(b"test")
                .expect("Set session id error");
            acceptor
                .set_private_key_file("tests/file/key.pem", openssl::ssl::SslFiletype::PEM)
                .expect("Set private key error");
            acceptor
                .set_certificate_chain_file("tests/file/cert.pem")
                .expect("Set cert error");
            acceptor
                .set_alpn_protos(b"\x08http/1.1")
                .expect("Set ALPN protocols error");
            // Select `http/1.1` if the client offers it; otherwise continue the handshake
            // without selecting an ALPN protocol (`NOACK`).
            acceptor.set_alpn_select_callback(|_, client| {
                openssl::ssl::select_next_proto(b"\x08http/1.1", client).ok_or(openssl::ssl::AlpnError::NOACK)
            });

            let acceptor = acceptor.build();

            // Serve exactly one connection: accept it, finish the TLS handshake, then hand the
            // encrypted stream to hyper.
            let (stream, _) = listener.accept().await.expect("TCP listener accept error");
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error");
            let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error");
            core::pin::Pin::new(&mut stream)
                .accept()
                .await
                .expect("TLS handshake failed");

            hyper::server::conn::Http::new()
                .http1_only(true)
                .http1_keep_alive(true)
                .serve_connection(stream, hyper::service::service_fn($service_fn))
                .await
        });

        TlsHandle {
            port,
        }
    }};
}
216
/// Creates a `Request`.
///
/// The expansion is a `ylong_http::request::RequestBuilder` chain that:
/// - sets `Method`;
/// - points the URL at `"{Host}:{Port}"`;
/// - adds every `Header` name/value pair in the order given;
/// - finishes with a `TextBody` built from `Body.as_bytes()`.
///
/// # Panics
///
/// Panics with "Request build failed" if building the request fails.
#[macro_export]
#[cfg(feature = "sync")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(
                Header: $header_name: expr, $header_value: expr,
            )*
            Body: $body: expr,
        },
    ) => {
        ylong_http::request::RequestBuilder::new()
            .method($method)
            .url(format!("{}:{}", $host, $port).as_str())
            $(
                .header($header_name, $header_value)
            )*
            .body(ylong_http::body::TextBody::from_bytes($body.as_bytes()))
            .expect("Request build failed")
    };
}
240
/// Creates a `Request`.
///
/// The expansion is a `ylong_http_client::async_impl::RequestBuilder` chain that:
/// - sets `Method`;
/// - points the URL at `"{Host}:{Port}"`;
/// - adds every `Header` name/value pair in the order given;
/// - finishes with an `async_impl::Body` slice over `Body.as_bytes()`.
///
/// # Panics
///
/// Panics with "Request build failed" if building the request fails.
#[macro_export]
#[cfg(feature = "async")]
macro_rules! ylong_request {
    (
        Request: {
            Method: $method: expr,
            Host: $host: expr,
            Port: $port: expr,
            $(
                Header: $header_name: expr, $header_value: expr,
            )*
            Body: $body: expr,
        },
    ) => {
        ylong_http_client::async_impl::RequestBuilder::new()
            .method($method)
            .url(format!("{}:{}", $host, $port).as_str())
            $(
                .header($header_name, $header_value)
            )*
            .body(ylong_http_client::async_impl::Body::slice($body.as_bytes()))
            .expect("Request build failed")
    };
}
264
/// Sets server async function.
///
/// Defines `async fn $server_fn_name(request: hyper::Request<hyper::Body>)`, suitable for
/// `hyper::service::service_fn`. For every `Request`/`Response` pair, a request whose method
/// equals `Method` is checked against these expectations:
/// - the URI is `/`;
/// - the request version, formatted with `{:?}`, equals `Version`;
/// - every listed request header is present with the given value;
/// - the body equals `Body`.
///
/// The request is then answered with an HTTP/1.1 response carrying `Status`, the listed response
/// headers and `Body`.
///
/// The `SYNC;` and `ASYNC;` forms generate the same function; `SYNC;` forwards to `ASYNC;`.
///
/// # Panics
///
/// The generated function panics (failing the test) when any expectation does not hold, or when
/// the request method matches none of the listed `Method`s.
#[macro_export]
macro_rules! set_server_fn {
    (
        ASYNC;
        $server_fn_name: ident,
        $(Request: {
            Method: $method: expr,
            $(
                Header: $req_n: expr, $req_v: expr,
            )*
            Body: $req_body: expr,
        },
        Response: {
            Status: $status: expr,
            Version: $version: expr,
            $(
                Header: $resp_n: expr, $resp_v: expr,
            )*
            Body: $resp_body: expr,
        },)*
    ) => {
        async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> {
            match request.method().as_str() {
                // TODO: two requests sharing the same Method expand to duplicate match arms;
                // rustc flags the later arm as unreachable, so only the first can ever match.
                $(
                    $method => {
                        assert_eq!($method, request.method().as_str(), "Assert request method failed");
                        assert_eq!(
                            "/",
                            request.uri().to_string(),
                            "Assert request uri failed",
                        );
                        assert_eq!(
                            $version,
                            format!("{:?}", request.version()),
                            "Assert request version failed",
                        );
                        // Failure messages are only formatted when the lookup actually fails.
                        $(assert_eq!(
                            $req_v,
                            request
                                .headers()
                                .get($req_n)
                                .unwrap_or_else(|| panic!("Get request header \"{}\" failed", $req_n))
                                .to_str()
                                .unwrap_or_else(|_| panic!("Convert request header \"{}\" into string failed", $req_n)),
                            "Assert request header {} failed", $req_n,
                        );)*
                        let body = hyper::body::to_bytes(request.into_body()).await
                            .expect("Get request body failed");
                        assert_eq!($req_body.as_bytes(), body, "Assert request body failed");
                        Ok(
                            hyper::Response::builder()
                                .version(hyper::Version::HTTP_11)
                                .status($status)
                                $(.header($resp_n, $resp_v))*
                                .body($resp_body.into())
                                .expect("Build response failed")
                        )
                    },
                )*
                _ => {panic!("Unrecognized METHOD !");},
            }
        }
    };
    (
        SYNC;
        $($rest: tt)*
    ) => {
        // The server side is identical for sync and async clients; reuse the ASYNC expansion.
        $crate::set_server_fn!(ASYNC; $($rest)*);
    };
}
395
/// Asks a test server to stop and waits until it confirms that it has stopped.
///
/// Sends `()` on the handle's `client_shutdown` channel, then awaits a message on its
/// `server_shutdown` channel. Must be expanded in an async context. `$handle` is evaluated once
/// per step (twice in total) and must be mutable, because receiving needs `&mut`.
///
/// # Panics
///
/// Panics if either channel has been closed by the server side.
#[macro_export]
macro_rules! ensure_server_shutdown {
    (ServerHandle: $handle:expr) => {
        // Step 1: tell the server the client is finished.
        $handle.client_shutdown.send(()).await
            .expect("Client channel (Server-Half) be closed unexpectedly");
        // Step 2: block until the server reports that it has shut down.
        $handle.server_shutdown.recv().await
            .expect("Server channel (Server-Half) be closed unexpectedly");
    };
}
411
412pub fn init_test_work_runtime(thread_num: usize) -> Runtime {
413    tokio::runtime::Builder::new_multi_thread()
414        .worker_threads(thread_num)
415        .enable_all()
416        .build()
417        .expect("Build runtime failed.")
418}
419