1// Copyright (c) 2023 Huawei Device Co., Ltd. 2// Licensed under the Apache License, Version 2.0 (the "License"); 3// you may not use this file except in compliance with the License. 4// You may obtain a copy of the License at 5// 6// http://www.apache.org/licenses/LICENSE-2.0 7// 8// Unless required by applicable law or agreed to in writing, software 9// distributed under the License is distributed on an "AS IS" BASIS, 10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11// See the License for the specific language governing permissions and 12// limitations under the License. 13 14#[cfg(feature = "async")] 15mod async_utils; 16 17#[cfg(feature = "sync")] 18mod sync_utils; 19 20use tokio::runtime::Runtime; 21 22macro_rules! define_service_handle { 23 ( 24 HTTP; 25 ) => { 26 use tokio::sync::mpsc::{Receiver, Sender}; 27 28 pub struct HttpHandle { 29 pub port: u16, 30 31 // This channel allows the server to notify the client when it is up and running. 32 pub server_start: Receiver<()>, 33 34 // This channel allows the client to notify the server when it is ready to shut down. 35 pub client_shutdown: Sender<()>, 36 37 // This channel allows the server to notify the client when it has shut down. 38 pub server_shutdown: Receiver<()>, 39 } 40 }; 41 ( 42 HTTPS; 43 ) => { 44 pub struct TlsHandle { 45 pub port: u16, 46 } 47 }; 48} 49 50#[macro_export] 51macro_rules! start_server { 52 ( 53 HTTPS; 54 ServerNum: $server_num: expr, 55 Runtime: $runtime: expr, 56 Handles: $handle_vec: expr, 57 ServeFnName: $service_fn: ident, 58 ) => {{ 59 for _i in 0..$server_num { 60 let (tx, rx) = std::sync::mpsc::channel(); 61 let server_handle = $runtime.spawn(async move { 62 let handle = start_http_server!( 63 HTTPS; 64 $service_fn 65 ); 66 tx.send(handle) 67 .expect("Failed to send the handle to the test thread."); 68 }); 69 $runtime 70 .block_on(server_handle) 71 .expect("Runtime start server coroutine failed"); 72 let handle = rx 73 .recv() 74 .expect("Handle send channel (Server-Half) be closed unexpectedly"); 75 $handle_vec.push(handle); 76 } 77 }}; 78 ( 79 HTTP; 80 ServerNum: $server_num: expr, 81 Runtime: $runtime: expr, 82 Handles: $handle_vec: expr, 83 ServeFnName: $service_fn: ident, 84 ) => {{ 85 for _i in 0..$server_num { 86 let (tx, rx) = std::sync::mpsc::channel(); 87 let server_handle = $runtime.spawn(async move { 88 let mut handle = start_http_server!( 89 HTTP; 90 $service_fn 91 ); 92 handle 93 .server_start 94 .recv() 95 .await 96 .expect("Start channel (Server-Half) be closed unexpectedly"); 97 tx.send(handle) 98 .expect("Failed to send the handle to the test thread."); 99 }); 100 $runtime 101 .block_on(server_handle) 102 .expect("Runtime start server coroutine failed"); 103 let handle = rx 104 .recv() 105 .expect("Handle send channel (Server-Half) be closed unexpectedly"); 106 $handle_vec.push(handle); 107 } 108 }}; 109} 110 111#[macro_export] 112macro_rules! start_http_server { 113 ( 114 HTTP; 115 $server_fn: ident 116 ) => {{ 117 use hyper::service::{make_service_fn, service_fn}; 118 use std::convert::Infallible; 119 use tokio::sync::mpsc::channel; 120 121 let (start_tx, start_rx) = channel::<()>(1); 122 let (client_tx, mut client_rx) = channel::<()>(1); 123 let (server_tx, server_rx) = channel::<()>(1); 124 125 let tcp_listener = std::net::TcpListener::bind("127.0.0.1:0").expect("server bind port failed !"); 126 let addr = tcp_listener.local_addr().expect("get server local address failed!"); 127 let port = addr.port(); 128 let server = hyper::Server::from_tcp(tcp_listener).expect("build hyper server from tcp listener failed !"); 129 130 tokio::spawn(async move { 131 let make_svc = 132 make_service_fn(|_conn| async { Ok::<_, Infallible>(service_fn($server_fn)) }); 133 server 134 .serve(make_svc) 135 .with_graceful_shutdown(async { 136 start_tx 137 .send(()) 138 .await 139 .expect("Start channel (Client-Half) be closed unexpectedly"); 140 client_rx 141 .recv() 142 .await 143 .expect("Client channel (Client-Half) be closed unexpectedly"); 144 }) 145 .await 146 .expect("Start server failed"); 147 server_tx 148 .send(()) 149 .await 150 .expect("Server channel (Client-Half) be closed unexpectedly"); 151 }); 152 153 HttpHandle { 154 port, 155 server_start: start_rx, 156 client_shutdown: client_tx, 157 server_shutdown: server_rx, 158 } 159 }}; 160 ( 161 HTTPS; 162 $service_fn: ident 163 ) => {{ 164 let mut port = 10000; 165 let listener = loop { 166 let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port)); 167 match tokio::net::TcpListener::bind(addr).await { 168 Ok(listener) => break listener, 169 Err(_) => { 170 port += 1; 171 if port == u16::MAX { 172 port = 10000; 173 } 174 continue; 175 } 176 } 177 }; 178 let port = listener.local_addr().unwrap().port(); 179 180 tokio::spawn(async move { 181 let mut acceptor = openssl::ssl::SslAcceptor::mozilla_intermediate(openssl::ssl::SslMethod::tls()) 182 .expect("SslAcceptorBuilder error"); 183 acceptor 184 .set_session_id_context(b"test") 185 .expect("Set session id error"); 186 acceptor 187 .set_private_key_file("tests/file/key.pem", openssl::ssl::SslFiletype::PEM) 188 .expect("Set private key error"); 189 acceptor 190 .set_certificate_chain_file("tests/file/cert.pem") 191 .expect("Set cert error"); 192 acceptor.set_alpn_protos(b"\x08http/1.1").unwrap(); 193 acceptor.set_alpn_select_callback(|_, client| { 194 openssl::ssl::select_next_proto(b"\x08http/1.1", client).ok_or(openssl::ssl::AlpnError::NOACK) 195 }); 196 197 let acceptor = acceptor.build(); 198 199 let (stream, _) = listener.accept().await.expect("TCP listener accept error"); 200 let ssl = openssl::ssl::Ssl::new(acceptor.context()).expect("Ssl Error"); 201 let mut stream = tokio_openssl::SslStream::new(ssl, stream).expect("SslStream Error"); 202 core::pin::Pin::new(&mut stream).accept().await.unwrap(); // SSL negotiation finished successfully 203 204 hyper::server::conn::Http::new() 205 .http1_only(true) 206 .http1_keep_alive(true) 207 .serve_connection(stream, hyper::service::service_fn($service_fn)) 208 .await 209 }); 210 211 TlsHandle { 212 port, 213 } 214 }}; 215} 216 217/// Creates a `Request`. 218#[macro_export] 219#[cfg(feature = "sync")] 220macro_rules! ylong_request { 221 ( 222 Request: { 223 Method: $method: expr, 224 Host: $host: expr, 225 Port: $port: expr, 226 $( 227 Header: $req_n: expr, $req_v: expr, 228 )* 229 Body: $req_body: expr, 230 }, 231 ) => { 232 ylong_http::request::RequestBuilder::new() 233 .method($method) 234 .url(format!("{}:{}", $host, $port).as_str()) 235 $(.header($req_n, $req_v))* 236 .body(ylong_http::body::TextBody::from_bytes($req_body.as_bytes())) 237 .expect("Request build failed") 238 }; 239} 240 241/// Creates a `Request`. 242#[macro_export] 243#[cfg(feature = "async")] 244macro_rules! ylong_request { 245 ( 246 Request: { 247 Method: $method: expr, 248 Host: $host: expr, 249 Port: $port: expr, 250 $( 251 Header: $req_n: expr, $req_v: expr, 252 )* 253 Body: $req_body: expr, 254 }, 255 ) => { 256 ylong_http_client::async_impl::RequestBuilder::new() 257 .method($method) 258 .url(format!("{}:{}", $host, $port).as_str()) 259 $(.header($req_n, $req_v))* 260 .body(ylong_http_client::async_impl::Body::slice($req_body.as_bytes())) 261 .expect("Request build failed") 262 }; 263} 264 265/// Sets server async function. 266#[macro_export] 267macro_rules! set_server_fn { 268 ( 269 ASYNC; 270 $server_fn_name: ident, 271 $(Request: { 272 Method: $method: expr, 273 $( 274 Header: $req_n: expr, $req_v: expr, 275 )* 276 Body: $req_body: expr, 277 }, 278 Response: { 279 Status: $status: expr, 280 Version: $version: expr, 281 $( 282 Header: $resp_n: expr, $resp_v: expr, 283 )* 284 Body: $resp_body: expr, 285 },)* 286 ) => { 287 async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> { 288 match request.method().as_str() { 289 // TODO If there are requests with the same Method, an error will be reported for creating two identical match branches. 290 $( 291 $method => { 292 assert_eq!($method, request.method().as_str(), "Assert request method failed"); 293 assert_eq!( 294 "/", 295 request.uri().to_string(), 296 "Assert request host failed", 297 ); 298 assert_eq!( 299 $version, 300 format!("{:?}", request.version()), 301 "Assert request version failed", 302 ); 303 $(assert_eq!( 304 $req_v, 305 request 306 .headers() 307 .get($req_n) 308 .expect(format!("Get request header \"{}\" failed", $req_n).as_str()) 309 .to_str() 310 .expect(format!("Convert request header \"{}\" into string failed", $req_n).as_str()), 311 "Assert request header {} failed", $req_n, 312 );)* 313 let body = hyper::body::to_bytes(request.into_body()).await 314 .expect("Get request body failed"); 315 assert_eq!($req_body.as_bytes(), body, "Assert request body failed"); 316 Ok( 317 hyper::Response::builder() 318 .version(hyper::Version::HTTP_11) 319 .status($status) 320 $(.header($resp_n, $resp_v))* 321 .body($resp_body.into()) 322 .expect("Build response failed") 323 ) 324 }, 325 )* 326 _ => {panic!("Unrecognized METHOD !");}, 327 } 328 } 329 330 }; 331 ( 332 SYNC; 333 $server_fn_name: ident, 334 $(Request: { 335 Method: $method: expr, 336 $( 337 Header: $req_n: expr, $req_v: expr, 338 )* 339 Body: $req_body: expr, 340 }, 341 Response: { 342 Status: $status: expr, 343 Version: $version: expr, 344 $( 345 Header: $resp_n: expr, $resp_v: expr, 346 )* 347 Body: $resp_body: expr, 348 },)* 349 ) => { 350 async fn $server_fn_name(request: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, std::convert::Infallible> { 351 match request.method().as_str() { 352 // TODO If there are requests with the same Method, an error will be reported for creating two identical match branches. 353 $( 354 $method => { 355 assert_eq!($method, request.method().as_str(), "Assert request method failed"); 356 assert_eq!( 357 "/", 358 request.uri().to_string(), 359 "Assert request uri failed", 360 ); 361 assert_eq!( 362 $version, 363 format!("{:?}", request.version()), 364 "Assert request version failed", 365 ); 366 $(assert_eq!( 367 $req_v, 368 request 369 .headers() 370 .get($req_n) 371 .expect(format!("Get request header \"{}\" failed", $req_n).as_str()) 372 .to_str() 373 .expect(format!("Convert request header \"{}\" into string failed", $req_n).as_str()), 374 "Assert request header {} failed", $req_n, 375 );)* 376 let body = hyper::body::to_bytes(request.into_body()).await 377 .expect("Get request body failed"); 378 assert_eq!($req_body.as_bytes(), body, "Assert request body failed"); 379 Ok( 380 hyper::Response::builder() 381 .version(hyper::Version::HTTP_11) 382 .status($status) 383 $(.header($resp_n, $resp_v))* 384 .body($resp_body.into()) 385 .expect("Build response failed") 386 ) 387 }, 388 )* 389 _ => {panic!("Unrecognized METHOD !");}, 390 } 391 } 392 393 }; 394} 395 396#[macro_export] 397macro_rules! ensure_server_shutdown { 398 (ServerHandle: $handle:expr) => { 399 $handle 400 .client_shutdown 401 .send(()) 402 .await 403 .expect("Client channel (Server-Half) be closed unexpectedly"); 404 $handle 405 .server_shutdown 406 .recv() 407 .await 408 .expect("Server channel (Server-Half) be closed unexpectedly"); 409 }; 410} 411 412pub fn init_test_work_runtime(thread_num: usize) -> Runtime { 413 tokio::runtime::Builder::new_multi_thread() 414 .worker_threads(thread_num) 415 .enable_all() 416 .build() 417 .expect("Build runtime failed.") 418} 419