1// Copyright (c) 2023 Huawei Device Co., Ltd.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
/// A connection manager that hands out handles to one underlying connection.
pub(crate) trait Dispatcher {
    /// The handle type produced by [`Dispatcher::dispatch`].
    type Handle;

    /// Tries to obtain a handle to the connection.
    ///
    /// Returns `None` when no handle can be handed out right now.
    fn dispatch(&self) -> Option<Self::Handle>;

    /// Returns `true` once the underlying connection has been marked for
    /// shutdown.
    fn is_shutdown(&self) -> bool;

    /// Returns `true` if the connection is going away.
    // NOTE(review): `dead_code` is allowed presumably because no caller uses
    // this method under some feature combinations — confirm.
    #[allow(dead_code)]
    fn is_goaway(&self) -> bool;
}
24
/// A protocol-specific connection dispatcher. Each variant is compiled in only
/// when its protocol feature (`http1_1`, `http2`, `http3`) is enabled.
pub(crate) enum ConnDispatcher<S> {
    /// Dispatcher for an HTTP/1.1 connection.
    #[cfg(feature = "http1_1")]
    Http1(http1::Http1Dispatcher<S>),

    /// Dispatcher for an HTTP/2 connection.
    #[cfg(feature = "http2")]
    Http2(http2::Http2Dispatcher<S>),

    /// Dispatcher for an HTTP/3 connection.
    #[cfg(feature = "http3")]
    Http3(http3::Http3Dispatcher<S>),
}
35
36impl<S> Dispatcher for ConnDispatcher<S> {
37    type Handle = Conn<S>;
38
39    fn dispatch(&self) -> Option<Self::Handle> {
40        match self {
41            #[cfg(feature = "http1_1")]
42            Self::Http1(h1) => h1.dispatch().map(Conn::Http1),
43
44            #[cfg(feature = "http2")]
45            Self::Http2(h2) => h2.dispatch().map(Conn::Http2),
46
47            #[cfg(feature = "http3")]
48            Self::Http3(h3) => h3.dispatch().map(Conn::Http3),
49        }
50    }
51
52    fn is_shutdown(&self) -> bool {
53        match self {
54            #[cfg(feature = "http1_1")]
55            Self::Http1(h1) => h1.is_shutdown(),
56
57            #[cfg(feature = "http2")]
58            Self::Http2(h2) => h2.is_shutdown(),
59
60            #[cfg(feature = "http3")]
61            Self::Http3(h3) => h3.is_shutdown(),
62        }
63    }
64
65    fn is_goaway(&self) -> bool {
66        match self {
67            #[cfg(feature = "http1_1")]
68            Self::Http1(h1) => h1.is_goaway(),
69
70            #[cfg(feature = "http2")]
71            Self::Http2(h2) => h2.is_goaway(),
72
73            #[cfg(feature = "http3")]
74            Self::Http3(h3) => h3.is_goaway(),
75        }
76    }
77}
78
/// A handle to a connection, produced by dispatching a [`ConnDispatcher`];
/// the variant matches the dispatcher's protocol.
pub(crate) enum Conn<S> {
    /// Handle to an HTTP/1.1 connection.
    #[cfg(feature = "http1_1")]
    Http1(http1::Http1Conn<S>),

    /// Handle to an HTTP/2 connection.
    #[cfg(feature = "http2")]
    Http2(http2::Http2Conn<S>),

    /// Handle to an HTTP/3 connection.
    #[cfg(feature = "http3")]
    Http3(http3::Http3Conn<S>),
}
89
90#[cfg(feature = "http1_1")]
91pub(crate) mod http1 {
92    use std::cell::UnsafeCell;
93    use std::sync::atomic::{AtomicBool, Ordering};
94    use std::sync::Arc;
95
96    use super::{ConnDispatcher, Dispatcher};
97
98    impl<S> ConnDispatcher<S> {
99        pub(crate) fn http1(io: S) -> Self {
100            Self::Http1(Http1Dispatcher::new(io))
101        }
102    }
103
    /// HTTP1-based connection manager, which can dispatch connections to other
    /// threads according to HTTP1 syntax.
    ///
    /// At most one [`Http1Conn`] handle is handed out at a time; the shared
    /// `occupied` flag in [`Inner`] enforces this.
    pub(crate) struct Http1Dispatcher<S> {
        // State shared with every handle this dispatcher hands out.
        inner: Arc<Inner<S>>,
    }
109
110    pub(crate) struct Inner<S> {
111        pub(crate) io: UnsafeCell<S>,
112        // `occupied` indicates that the connection is occupied. Only one coroutine
113        // can get the handle at the same time. Once the handle is fetched, the flag
114        // position is true.
115        pub(crate) occupied: AtomicBool,
116        // `shutdown` indicates that the connection need to be shut down.
117        pub(crate) shutdown: AtomicBool,
118    }
119
120    unsafe impl<S> Sync for Inner<S> {}
121
122    impl<S> Http1Dispatcher<S> {
123        pub(crate) fn new(io: S) -> Self {
124            Self {
125                inner: Arc::new(Inner {
126                    io: UnsafeCell::new(io),
127                    occupied: AtomicBool::new(false),
128                    shutdown: AtomicBool::new(false),
129                }),
130            }
131        }
132    }
133
134    impl<S> Dispatcher for Http1Dispatcher<S> {
135        type Handle = Http1Conn<S>;
136
137        fn dispatch(&self) -> Option<Self::Handle> {
138            self.inner
139                .occupied
140                .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
141                .ok()
142                .map(|_| Http1Conn {
143                    inner: self.inner.clone(),
144                })
145        }
146
147        fn is_shutdown(&self) -> bool {
148            self.inner.shutdown.load(Ordering::Relaxed)
149        }
150
151        fn is_goaway(&self) -> bool {
152            false
153        }
154    }
155
    /// Handle returned to other threads for I/O operations.
    ///
    /// While a handle is alive, the owning dispatcher's `dispatch` returns
    /// `None`; dropping the handle releases the connection.
    pub(crate) struct Http1Conn<S> {
        // State shared with the `Http1Dispatcher` that created this handle.
        pub(crate) inner: Arc<Inner<S>>,
    }
160
161    impl<S> Http1Conn<S> {
162        pub(crate) fn raw_mut(&mut self) -> &mut S {
163            // SAFETY: In the case of `HTTP1`, only one coroutine gets the handle
164            // at the same time.
165            unsafe { &mut *self.inner.io.get() }
166        }
167
168        pub(crate) fn shutdown(&self) {
169            self.inner.shutdown.store(true, Ordering::Release);
170        }
171    }
172
173    impl<S> Drop for Http1Conn<S> {
174        fn drop(&mut self) {
175            self.inner.occupied.store(false, Ordering::Release)
176        }
177    }
178}
179
180#[cfg(feature = "http2")]
181pub(crate) mod http2 {
182    use std::collections::HashMap;
183    use std::future::Future;
184    use std::marker::PhantomData;
185    use std::pin::Pin;
186    use std::sync::atomic::{AtomicBool, Ordering};
187    use std::sync::{Arc, Mutex};
188    use std::task::{Context, Poll};
189
190    use ylong_http::error::HttpError;
191    use ylong_http::h2::{
192        ErrorCode, Frame, FrameDecoder, FrameEncoder, FrameFlags, Goaway, H2Error, Payload,
193        RstStream, Settings, SettingsBuilder, StreamId,
194    };
195
196    use crate::runtime::{
197        bounded_channel, unbounded_channel, AsyncRead, AsyncWrite, AsyncWriteExt, BoundedReceiver,
198        BoundedSender, SendError, UnboundedReceiver, UnboundedSender, WriteHalf,
199    };
200    use crate::util::config::H2Config;
201    use crate::util::dispatcher::{ConnDispatcher, Dispatcher};
202    use crate::util::h2::{
203        ConnManager, FlowControl, H2StreamState, RecvData, RequestWrapper, SendData,
204        StreamEndState, Streams,
205    };
206    use crate::ErrorKind::Request;
207    use crate::{ErrorKind, HttpClientError};
208    const DEFAULT_MAX_FRAME_SIZE: usize = 2 << 13;
209    const DEFAULT_WINDOW_SIZE: u32 = 65535;
210
    /// A boxed, pinned future that forwards one [`RespMessage`] to a stream's
    /// response channel; parked in `StreamController::curr_message` while it
    /// cannot complete.
    pub(crate) type ManagerSendFut =
        Pin<Box<dyn Future<Output = Result<(), SendError<RespMessage>>> + Send + Sync>>;
213
    /// Message delivered on a request's response channel.
    pub(crate) enum RespMessage {
        /// A frame for the request's stream.
        Output(Frame),
        /// An error for the request; `RespReceiver` converts it into an `Err`
        /// via `dispatch_client_error`.
        OutputExit(DispatchErrorKind),
    }

    // NOTE(review): `OutputMessage` is not referenced in this file and mirrors
    // `RespMessage` — confirm whether it is still used elsewhere.
    pub(crate) enum OutputMessage {
        Output(Frame),
        OutputExit(DispatchErrorKind),
    }

    /// A request submitted by an [`Http2Conn`] to the connection manager.
    pub(crate) struct ReqMessage {
        /// Where the response messages for this request are delivered.
        pub(crate) sender: BoundedSender<RespMessage>,
        /// The request itself.
        pub(crate) request: RequestWrapper,
    }

    /// Internal error kinds raised while driving an HTTP/2 connection; see
    /// `dispatch_client_error` for how each maps to an [`HttpClientError`].
    #[derive(Debug, Eq, PartialEq, Copy, Clone)]
    pub(crate) enum DispatchErrorKind {
        /// An HTTP/2 protocol error.
        H2(H2Error),
        /// An I/O error, reduced to its kind.
        Io(std::io::ErrorKind),
        /// An internal channel between coroutines was closed.
        ChannelClosed,
        /// The remote peer closed the connection.
        Disconnect,
    }
236
    /// HTTP2-based connection manager, which can dispatch connections to other
    /// threads according to HTTP2 syntax.
    ///
    /// All handles it hands out share one request channel. The connection is
    /// driven by background tasks whose join handles are kept in `handles`.
    pub(crate) struct Http2Dispatcher<S> {
        // Capacity of the response channel created for each request.
        pub(crate) allowed_cache: usize,
        // Sends requests to the connection-manager task.
        pub(crate) sender: UnboundedSender<ReqMessage>,
        // Shutdown flag shared with the `StreamController` and every handle.
        pub(crate) io_shutdown: Arc<AtomicBool>,
        // Join handles of the spawned tasks; they are cancelled on drop.
        pub(crate) handles: Vec<crate::runtime::JoinHandle<()>>,
        pub(crate) _mark: PhantomData<S>,
    }

    /// Handle to an HTTP/2 connection, used to submit requests and receive
    /// their responses.
    pub(crate) struct Http2Conn<S> {
        // Capacity of the response channel created for each request.
        pub(crate) allow_cached_frames: usize,
        // Sends requests to the connection manager, which owns the
        // `StreamController`.
        pub(crate) sender: UnboundedSender<ReqMessage>,
        // Receiving end of the current request's response channel.
        pub(crate) receiver: RespReceiver,
        pub(crate) io_shutdown: Arc<AtomicBool>,
        pub(crate) _mark: PhantomData<S>,
    }

    /// Per-connection stream bookkeeping, owned by the connection-manager task.
    pub(crate) struct StreamController {
        // Connection shutdown flag, shared with the dispatcher. Once set, the
        // dispatcher's `is_shutdown` reports `true`, presumably so that no new
        // streams are committed to this connection.
        pub(crate) io_shutdown: Arc<AtomicBool>,
        // The senders of all connected stream channels of response.
        pub(crate) senders: HashMap<StreamId, BoundedSender<RespMessage>>,
        // Sends that could not complete immediately, parked per stream and
        // retried by `poll_blocked_message`.
        pub(crate) curr_message: HashMap<StreamId, ManagerSendFut>,
        // Stream information on the connection.
        pub(crate) streams: Streams,
        // Stream id carried by a received GO_AWAY frame, if any.
        pub(crate) recved_go_away: Option<StreamId>,
        // The last GO_AWAY frame sent by the client.
        pub(crate) go_away_sync: GoAwaySync,
    }

    /// Holds the GO_AWAY frame most recently sent by the client, if any.
    #[derive(Default)]
    pub(crate) struct GoAwaySync {
        pub(crate) going_away: Option<Goaway>,
    }

    /// SETTINGS synchronization state, shared behind `Arc<Mutex<_>>` by the
    /// writer, reader and manager tasks of one connection.
    #[derive(Default)]
    pub(crate) struct SettingsSync {
        pub(crate) settings: SettingsState,
    }

    /// State of the local SETTINGS exchange; `Synced` is the default.
    // NOTE(review): presumably `Acknowledging` holds SETTINGS that were sent and
    // still await an ACK — confirm.
    #[derive(Default, Clone)]
    pub(crate) enum SettingsState {
        Acknowledging(Settings),
        #[default]
        Synced,
    }

    /// Receiving side of a request's response channel; empty until
    /// `set_receiver` is called.
    #[derive(Default)]
    pub(crate) struct RespReceiver {
        receiver: Option<BoundedReceiver<RespMessage>>,
    }
292
293    impl<S> ConnDispatcher<S>
294    where
295        S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
296    {
297        pub(crate) fn http2(config: H2Config, io: S) -> Self {
298            Self::Http2(Http2Dispatcher::new(config, io))
299        }
300    }
301
302    impl<S> Http2Dispatcher<S>
303    where
304        S: AsyncRead + AsyncWrite + Sync + Send + Unpin + 'static,
305    {
306        pub(crate) fn new(config: H2Config, io: S) -> Self {
307            let settings = create_initial_settings(&config);
308
309            let mut flow = FlowControl::new(DEFAULT_WINDOW_SIZE, DEFAULT_WINDOW_SIZE);
310            flow.setup_recv_window(config.conn_window_size());
311
312            let streams = Streams::new(config.stream_window_size(), DEFAULT_WINDOW_SIZE, flow);
313            let shutdown_flag = Arc::new(AtomicBool::new(false));
314            let controller = StreamController::new(streams, shutdown_flag.clone());
315
316            let (input_tx, input_rx) = unbounded_channel();
317            let (req_tx, req_rx) = unbounded_channel();
318
319            // Error is not possible, so it is not handled for the time
320            // being.
321            let mut handles = Vec::with_capacity(3);
322            if input_tx.send(settings).is_ok() {
323                Self::launch(
324                    config.allowed_cache_frame_size(),
325                    config.use_huffman_coding(),
326                    controller,
327                    (input_tx, input_rx),
328                    req_rx,
329                    &mut handles,
330                    io,
331                );
332            }
333            Self {
334                allowed_cache: config.allowed_cache_frame_size(),
335                sender: req_tx,
336                io_shutdown: shutdown_flag,
337                handles,
338                _mark: PhantomData,
339            }
340        }
341
342        fn launch(
343            allow_num: usize,
344            use_huffman: bool,
345            controller: StreamController,
346            input_channel: (UnboundedSender<Frame>, UnboundedReceiver<Frame>),
347            req_rx: UnboundedReceiver<ReqMessage>,
348            handles: &mut Vec<crate::runtime::JoinHandle<()>>,
349            io: S,
350        ) {
351            let (resp_tx, resp_rx) = bounded_channel(allow_num);
352            let (read, write) = crate::runtime::split(io);
353            let settings_sync = Arc::new(Mutex::new(SettingsSync::default()));
354            let send_settings_sync = settings_sync.clone();
355            let send = crate::runtime::spawn(async move {
356                let mut writer = write;
357                if async_send_preface(&mut writer).await.is_ok() {
358                    let encoder = FrameEncoder::new(DEFAULT_MAX_FRAME_SIZE, use_huffman);
359                    let mut send =
360                        SendData::new(encoder, send_settings_sync, writer, input_channel.1);
361                    let _ = Pin::new(&mut send).await;
362                }
363            });
364            handles.push(send);
365
366            let recv_settings_sync = settings_sync.clone();
367            let recv = crate::runtime::spawn(async move {
368                let decoder = FrameDecoder::new();
369                let mut recv = RecvData::new(decoder, recv_settings_sync, read, resp_tx);
370                let _ = Pin::new(&mut recv).await;
371            });
372            handles.push(recv);
373
374            let manager = crate::runtime::spawn(async move {
375                let mut conn_manager =
376                    ConnManager::new(settings_sync, input_channel.0, resp_rx, req_rx, controller);
377                let _ = Pin::new(&mut conn_manager).await;
378            });
379            handles.push(manager);
380        }
381    }
382
383    impl<S> Dispatcher for Http2Dispatcher<S> {
384        type Handle = Http2Conn<S>;
385
386        fn dispatch(&self) -> Option<Self::Handle> {
387            let sender = self.sender.clone();
388            let handle = Http2Conn::new(self.allowed_cache, self.io_shutdown.clone(), sender);
389            Some(handle)
390        }
391
392        fn is_shutdown(&self) -> bool {
393            self.io_shutdown.load(Ordering::Relaxed)
394        }
395
396        fn is_goaway(&self) -> bool {
397            // todo: goaway and shutdown
398            false
399        }
400    }
401
402    impl<S> Drop for Http2Dispatcher<S> {
403        fn drop(&mut self) {
404            for handle in &self.handles {
405                #[cfg(feature = "ylong_base")]
406                handle.cancel();
407                #[cfg(feature = "tokio_base")]
408                handle.abort();
409            }
410        }
411    }
412
413    impl<S> Http2Conn<S> {
414        pub(crate) fn new(
415            allow_cached_num: usize,
416            io_shutdown: Arc<AtomicBool>,
417            sender: UnboundedSender<ReqMessage>,
418        ) -> Self {
419            Self {
420                allow_cached_frames: allow_cached_num,
421                sender,
422                receiver: RespReceiver::default(),
423                io_shutdown,
424                _mark: PhantomData,
425            }
426        }
427
428        pub(crate) fn send_frame_to_controller(
429            &mut self,
430            request: RequestWrapper,
431        ) -> Result<(), HttpClientError> {
432            let (tx, rx) = bounded_channel::<RespMessage>(self.allow_cached_frames);
433            self.receiver.set_receiver(rx);
434            self.sender
435                .send(ReqMessage {
436                    sender: tx,
437                    request,
438                })
439                .map_err(|_| {
440                    HttpClientError::from_str(ErrorKind::Request, "Request Sender Closed !")
441                })
442        }
443    }
444
445    impl StreamController {
446        pub(crate) fn new(streams: Streams, shutdown: Arc<AtomicBool>) -> Self {
447            Self {
448                io_shutdown: shutdown,
449                senders: HashMap::new(),
450                curr_message: HashMap::new(),
451                streams,
452                recved_go_away: None,
453                go_away_sync: GoAwaySync::default(),
454            }
455        }
456
457        pub(crate) fn shutdown(&self) {
458            self.io_shutdown.store(true, Ordering::Release);
459        }
460
461        pub(crate) fn get_unsent_streams(
462            &mut self,
463            last_stream_id: StreamId,
464        ) -> Result<Vec<StreamId>, H2Error> {
465            // The last-stream-id in the subsequent GO_AWAY frame
466            // cannot be greater than the last-stream-id in the previous GO_AWAY frame.
467            if self.streams.max_send_id < last_stream_id {
468                return Err(H2Error::ConnectionError(ErrorCode::ProtocolError));
469            }
470            self.streams.max_send_id = last_stream_id;
471            Ok(self.streams.get_go_away_streams(last_stream_id))
472        }
473
474        pub(crate) fn send_message_to_stream(
475            &mut self,
476            cx: &mut Context<'_>,
477            stream_id: StreamId,
478            message: RespMessage,
479        ) -> Poll<Result<(), H2Error>> {
480            if let Some(sender) = self.senders.get(&stream_id) {
481                // If the client coroutine has exited, this frame is skipped.
482                let mut tx = {
483                    let sender = sender.clone();
484                    let ft = async move { sender.send(message).await };
485                    Box::pin(ft)
486                };
487
488                match tx.as_mut().poll(cx) {
489                    Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
490                    // The current coroutine sending the request exited prematurely.
491                    Poll::Ready(Err(_)) => {
492                        self.senders.remove(&stream_id);
493                        Poll::Ready(Err(H2Error::StreamError(stream_id, ErrorCode::NoError)))
494                    }
495                    Poll::Pending => {
496                        self.curr_message.insert(stream_id, tx);
497                        Poll::Pending
498                    }
499                }
500            } else {
501                Poll::Ready(Err(H2Error::StreamError(stream_id, ErrorCode::NoError)))
502            }
503        }
504
505        pub(crate) fn poll_blocked_message(
506            &mut self,
507            cx: &mut Context<'_>,
508            input_tx: &UnboundedSender<Frame>,
509        ) -> Poll<()> {
510            let keys: Vec<StreamId> = self.curr_message.keys().cloned().collect();
511            let mut blocked = false;
512
513            for key in keys {
514                if let Some(mut task) = self.curr_message.remove(&key) {
515                    match task.as_mut().poll(cx) {
516                        Poll::Ready(Ok(_)) => {}
517                        // The current coroutine sending the request exited prematurely.
518                        Poll::Ready(Err(_)) => {
519                            self.senders.remove(&key);
520                            if let Some(state) = self.streams.stream_state(key) {
521                                if !matches!(state, H2StreamState::Closed(_)) {
522                                    if let StreamEndState::OK = self.streams.send_local_reset(key) {
523                                        let rest_payload =
524                                            RstStream::new(ErrorCode::NoError.into_code());
525                                        let frame = Frame::new(
526                                            key,
527                                            FrameFlags::empty(),
528                                            Payload::RstStream(rest_payload),
529                                        );
530                                        // ignore the send error occurs here in order to finish all
531                                        // tasks.
532                                        let _ = input_tx.send(frame);
533                                    }
534                                }
535                            }
536                        }
537                        Poll::Pending => {
538                            self.curr_message.insert(key, task);
539                            blocked = true;
540                        }
541                    }
542                }
543            }
544            if blocked {
545                Poll::Pending
546            } else {
547                Poll::Ready(())
548            }
549        }
550    }
551
552    impl RespReceiver {
553        pub(crate) fn set_receiver(&mut self, receiver: BoundedReceiver<RespMessage>) {
554            self.receiver = Some(receiver);
555        }
556
557        pub(crate) async fn recv(&mut self) -> Result<Frame, HttpClientError> {
558            match self.receiver {
559                Some(ref mut receiver) => {
560                    #[cfg(feature = "tokio_base")]
561                    match receiver.recv().await {
562                        None => err_from_msg!(Request, "Response Receiver Closed !"),
563                        Some(message) => match message {
564                            RespMessage::Output(frame) => Ok(frame),
565                            RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
566                        },
567                    }
568
569                    #[cfg(feature = "ylong_base")]
570                    match receiver.recv().await {
571                        Err(err) => Err(HttpClientError::from_error(ErrorKind::Request, err)),
572                        Ok(message) => match message {
573                            RespMessage::Output(frame) => Ok(frame),
574                            RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
575                        },
576                    }
577                }
578                // this will not happen.
579                None => Err(HttpClientError::from_str(
580                    ErrorKind::Request,
581                    "Invalid Frame Receiver !",
582                )),
583            }
584        }
585
586        pub(crate) fn poll_recv(
587            &mut self,
588            cx: &mut Context<'_>,
589        ) -> Poll<Result<Frame, HttpClientError>> {
590            if let Some(ref mut receiver) = self.receiver {
591                #[cfg(feature = "tokio_base")]
592                match receiver.poll_recv(cx) {
593                    Poll::Ready(None) => {
594                        Poll::Ready(err_from_msg!(Request, "Error receive response !"))
595                    }
596                    Poll::Ready(Some(message)) => match message {
597                        RespMessage::Output(frame) => Poll::Ready(Ok(frame)),
598                        RespMessage::OutputExit(e) => Poll::Ready(Err(dispatch_client_error(e))),
599                    },
600                    Poll::Pending => Poll::Pending,
601                }
602
603                #[cfg(feature = "ylong_base")]
604                match receiver.poll_recv(cx) {
605                    Poll::Ready(Err(e)) => {
606                        Poll::Ready(Err(HttpClientError::from_error(ErrorKind::Request, e)))
607                    }
608                    Poll::Ready(Ok(message)) => match message {
609                        RespMessage::Output(frame) => Poll::Ready(Ok(frame)),
610                        RespMessage::OutputExit(e) => Poll::Ready(Err(dispatch_client_error(e))),
611                    },
612                    Poll::Pending => Poll::Pending,
613                }
614            } else {
615                Poll::Ready(err_from_msg!(Request, "Invalid Frame Receiver !"))
616            }
617        }
618    }
619
620    async fn async_send_preface<S>(writer: &mut WriteHalf<S>) -> Result<(), DispatchErrorKind>
621    where
622        S: AsyncWrite + Unpin,
623    {
624        const PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
625        writer
626            .write_all(PREFACE)
627            .await
628            .map_err(|e| DispatchErrorKind::Io(e.kind()))
629    }
630
631    pub(crate) fn create_initial_settings(config: &H2Config) -> Frame {
632        let settings = SettingsBuilder::new()
633            .max_header_list_size(config.max_header_list_size())
634            .max_frame_size(config.max_frame_size())
635            .header_table_size(config.header_table_size())
636            .enable_push(config.enable_push())
637            .initial_window_size(config.stream_window_size())
638            .build();
639
640        Frame::new(0, FrameFlags::new(0), Payload::Settings(settings))
641    }
642
643    impl From<std::io::Error> for DispatchErrorKind {
644        fn from(value: std::io::Error) -> Self {
645            DispatchErrorKind::Io(value.kind())
646        }
647    }
648
649    impl From<H2Error> for DispatchErrorKind {
650        fn from(err: H2Error) -> Self {
651            DispatchErrorKind::H2(err)
652        }
653    }
654
655    pub(crate) fn dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError {
656        match dispatch_error {
657            DispatchErrorKind::H2(e) => HttpClientError::from_error(Request, HttpError::from(e)),
658            DispatchErrorKind::Io(e) => {
659                HttpClientError::from_io_error(Request, std::io::Error::from(e))
660            }
661            DispatchErrorKind::ChannelClosed => {
662                HttpClientError::from_str(Request, "Coroutine channel closed.")
663            }
664            DispatchErrorKind::Disconnect => {
665                HttpClientError::from_str(Request, "remote peer closed.")
666            }
667        }
668    }
669}
670
671#[cfg(feature = "http3")]
672pub(crate) mod http3 {
673    use std::marker::PhantomData;
674    use std::pin::Pin;
675    use std::sync::atomic::{AtomicBool, Ordering};
676    use std::sync::{Arc, Mutex};
677
678    use ylong_http::error::HttpError;
679    use ylong_http::h3::{Frame, FrameDecoder, H3Error};
680
681    use crate::async_impl::{ConnInfo, QuicConn};
682    use crate::runtime::{
683        bounded_channel, unbounded_channel, AsyncRead, AsyncWrite, BoundedReceiver, BoundedSender,
684        UnboundedSender,
685    };
686    use crate::util::config::H3Config;
687    use crate::util::data_ref::BodyDataRef;
688    use crate::util::dispatcher::{ConnDispatcher, Dispatcher};
689    use crate::util::h3::io_manager::IOManager;
690    use crate::util::h3::stream_manager::StreamManager;
691    use crate::ErrorKind::Request;
692    use crate::{ErrorKind, HttpClientError};
693
    /// HTTP/3 connection manager. It hands out handles that share one request
    /// channel to the stream-manager task.
    pub(crate) struct Http3Dispatcher<S> {
        // Sends requests to the stream-manager task.
        pub(crate) req_tx: UnboundedSender<ReqMessage>,
        // Join handles of the spawned tasks; they are cancelled on drop.
        pub(crate) handles: Vec<crate::runtime::JoinHandle<()>>,
        pub(crate) _mark: PhantomData<S>,
        // Shutdown flag shared with the `StreamManager`.
        pub(crate) io_shutdown: Arc<AtomicBool>,
        // GOAWAY flag shared with the `StreamManager`; read by `is_goaway`.
        pub(crate) io_goaway: Arc<AtomicBool>,
    }

    /// Handle to an HTTP/3 connection, used to submit requests and receive
    /// their responses.
    pub(crate) struct Http3Conn<S> {
        // Sends requests to the stream-manager task.
        pub(crate) sender: UnboundedSender<ReqMessage>,
        // Receiving end of this handle's response channel.
        pub(crate) resp_receiver: BoundedReceiver<RespMessage>,
        // Sending end of the same channel; a clone accompanies every request.
        pub(crate) resp_sender: BoundedSender<RespMessage>,
        pub(crate) io_shutdown: Arc<AtomicBool>,
        pub(crate) _mark: PhantomData<S>,
    }

    /// An HTTP/3 request: its header frame plus a reference to its body data.
    pub(crate) struct RequestWrapper {
        pub(crate) header: Frame,
        pub(crate) data: BodyDataRef,
    }

    /// Internal error kinds raised while driving an HTTP/3 connection; see
    /// `dispatch_client_error` for how each maps to an [`HttpClientError`].
    #[derive(Debug, Clone)]
    pub(crate) enum DispatchErrorKind {
        /// An HTTP/3 protocol error.
        H3(H3Error),
        /// An I/O error, reduced to its kind.
        Io(std::io::ErrorKind),
        /// An error reported by `quiche`.
        Quic(quiche::Error),
        /// An internal channel between coroutines was closed.
        ChannelClosed,
        /// The stream has finished.
        StreamFinished,
        // todo: retry?
        /// The remote peer sent GOAWAY.
        GoawayReceived,
        /// The remote peer closed the connection.
        Disconnect,
    }

    /// Message delivered on a handle's response channel.
    pub(crate) enum RespMessage {
        /// A frame of the response.
        Output(Frame),
        /// An error to report to the request.
        OutputExit(DispatchErrorKind),
    }

    /// A request submitted by an [`Http3Conn`] to the stream manager.
    pub(crate) struct ReqMessage {
        /// The request itself.
        pub(crate) request: RequestWrapper,
        /// Where the response messages for this request are delivered.
        pub(crate) frame_tx: BoundedSender<RespMessage>,
    }
736
737    impl<S> Http3Dispatcher<S>
738    where
739        S: AsyncRead + AsyncWrite + ConnInfo + Sync + Send + Unpin + 'static,
740    {
741        pub(crate) fn new(config: H3Config, io: S, quic_connection: QuicConn) -> Self {
742            let (req_tx, req_rx) = unbounded_channel();
743            let (io_manager_tx, io_manager_rx) = unbounded_channel();
744            let (stream_manager_tx, stream_manager_rx) = unbounded_channel();
745            let mut handles = Vec::with_capacity(2);
746            let conn = Arc::new(Mutex::new(quic_connection));
747            let io_shutdown = Arc::new(AtomicBool::new(false));
748            let io_goaway = Arc::new(AtomicBool::new(false));
749            let mut stream_manager = StreamManager::new(
750                conn.clone(),
751                io_manager_tx,
752                stream_manager_rx,
753                req_rx,
754                FrameDecoder::new(
755                    config.qpack_blocked_streams() as usize,
756                    config.qpack_max_table_capacity() as usize,
757                ),
758                io_shutdown.clone(),
759                io_goaway.clone(),
760            );
761            let stream_handle = crate::runtime::spawn(async move {
762                if stream_manager.init(config).is_err() {
763                    return;
764                }
765                let _ = Pin::new(&mut stream_manager).await;
766            });
767            handles.push(stream_handle);
768
769            let io_handle = crate::runtime::spawn(async move {
770                let mut io_manager = IOManager::new(io, conn, io_manager_rx, stream_manager_tx);
771                let _ = Pin::new(&mut io_manager).await;
772            });
773            handles.push(io_handle);
774            // read_rx gets readable stream ids and writable client channels, then read
775            // stream and send to the corresponding channel
776            Self {
777                req_tx,
778                handles,
779                _mark: PhantomData,
780                io_shutdown,
781                io_goaway,
782            }
783        }
784    }
785
786    impl<S> Http3Conn<S> {
787        pub(crate) fn new(
788            sender: UnboundedSender<ReqMessage>,
789            io_shutdown: Arc<AtomicBool>,
790        ) -> Self {
791            const CHANNEL_SIZE: usize = 3;
792            let (resp_sender, resp_receiver) = bounded_channel(CHANNEL_SIZE);
793            Self {
794                sender,
795                resp_sender,
796                resp_receiver,
797                _mark: PhantomData,
798                io_shutdown,
799            }
800        }
801
802        pub(crate) fn send_frame_to_reader(
803            &mut self,
804            request: RequestWrapper,
805        ) -> Result<(), HttpClientError> {
806            self.sender
807                .send(ReqMessage {
808                    request,
809                    frame_tx: self.resp_sender.clone(),
810                })
811                .map_err(|_| {
812                    HttpClientError::from_str(ErrorKind::Request, "Request Sender Closed !")
813                })
814        }
815
816        pub(crate) async fn recv_resp(&mut self) -> Result<Frame, HttpClientError> {
817            #[cfg(feature = "tokio_base")]
818            match self.resp_receiver.recv().await {
819                None => err_from_msg!(Request, "Response Receiver Closed !"),
820                Some(message) => match message {
821                    RespMessage::Output(frame) => Ok(frame),
822                    RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
823                },
824            }
825
826            #[cfg(feature = "ylong_base")]
827            match self.resp_receiver.recv().await {
828                Err(err) => Err(HttpClientError::from_error(ErrorKind::Request, err)),
829                Ok(message) => match message {
830                    RespMessage::Output(frame) => Ok(frame),
831                    RespMessage::OutputExit(e) => Err(dispatch_client_error(e)),
832                },
833            }
834        }
835    }
836
837    impl<S> ConnDispatcher<S>
838    where
839        S: AsyncRead + AsyncWrite + ConnInfo + Sync + Send + Unpin + 'static,
840    {
841        pub(crate) fn http3(config: H3Config, io: S, quic_connection: QuicConn) -> Self {
842            Self::Http3(Http3Dispatcher::new(config, io, quic_connection))
843        }
844    }
845
846    impl<S> Dispatcher for Http3Dispatcher<S> {
847        type Handle = Http3Conn<S>;
848
849        fn dispatch(&self) -> Option<Self::Handle> {
850            let sender = self.req_tx.clone();
851            Some(Http3Conn::new(sender, self.io_shutdown.clone()))
852        }
853
854        fn is_shutdown(&self) -> bool {
855            self.io_shutdown.load(Ordering::Relaxed)
856        }
857
858        fn is_goaway(&self) -> bool {
859            self.io_goaway.load(Ordering::Relaxed)
860        }
861    }
862
863    impl<S> Drop for Http3Dispatcher<S> {
864        fn drop(&mut self) {
865            for handle in &self.handles {
866                #[cfg(feature = "tokio_base")]
867                handle.abort();
868                #[cfg(feature = "ylong_base")]
869                handle.cancel();
870            }
871        }
872    }
873
874    impl From<std::io::Error> for DispatchErrorKind {
875        fn from(value: std::io::Error) -> Self {
876            DispatchErrorKind::Io(value.kind())
877        }
878    }
879
880    impl From<H3Error> for DispatchErrorKind {
881        fn from(err: H3Error) -> Self {
882            DispatchErrorKind::H3(err)
883        }
884    }
885
886    impl From<quiche::Error> for DispatchErrorKind {
887        fn from(value: quiche::Error) -> Self {
888            DispatchErrorKind::Quic(value)
889        }
890    }
891
892    pub(crate) fn dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError {
893        match dispatch_error {
894            DispatchErrorKind::H3(e) => HttpClientError::from_error(Request, HttpError::from(e)),
895            DispatchErrorKind::Io(e) => {
896                HttpClientError::from_io_error(Request, std::io::Error::from(e))
897            }
898            DispatchErrorKind::ChannelClosed => {
899                HttpClientError::from_str(Request, "Coroutine channel closed.")
900            }
901            DispatchErrorKind::Quic(e) => HttpClientError::from_error(Request, e),
902            DispatchErrorKind::GoawayReceived => {
903                HttpClientError::from_str(Request, "received remote goaway.")
904            }
905            DispatchErrorKind::StreamFinished => {
906                HttpClientError::from_str(Request, "stream finished.")
907            }
908            DispatchErrorKind::Disconnect => {
909                HttpClientError::from_str(Request, "remote peer closed.")
910            }
911        }
912    }
913}
914
// The test only exercises the HTTP/1.1 dispatcher, which exists only with the
// `http1_1` feature; gating on it keeps `cargo test` building for other
// feature sets.
#[cfg(all(test, feature = "http1_1"))]
mod ut_dispatch {
    use crate::dispatcher::{ConnDispatcher, Dispatcher};

    /// UT test cases for `ConnDispatcher::is_shutdown` and
    /// `ConnDispatcher::dispatch`.
    ///
    /// # Brief
    /// 1. Creates a `ConnDispatcher`.
    /// 2. Calls `ConnDispatcher::is_shutdown` and checks that the result is false.
    /// 3. Calls `ConnDispatcher::dispatch` and checks that a handle is returned.
    /// 4. Calls `ConnDispatcher::dispatch` again while the first handle is held
    ///    and checks that no handle is returned.
    /// 5. Drops the first handle and checks that `dispatch` succeeds again.
    #[test]
    fn ut_is_shutdown() {
        let conn = ConnDispatcher::http1(b"Data");
        assert!(!conn.is_shutdown());

        let first = conn.dispatch();
        assert!(first.is_some());
        // An HTTP/1 connection hands out at most one handle at a time.
        assert!(conn.dispatch().is_none());

        drop(first);
        assert!(conn.dispatch().is_some());
    }
}
935