1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cmp::min;
15 use std::ops::Deref;
16 use std::pin::Pin;
17 use std::sync::atomic::Ordering;
18 use std::task::{Context, Poll};
19 
20 use ylong_http::error::HttpError;
21 use ylong_http::h3::{
22     Frame, H3Error, H3ErrorCode, Headers, Parts, Payload, PseudoHeaders, HEADERS_FRAME_TYPE,
23 };
24 use ylong_http::request::uri::Scheme;
25 use ylong_http::request::RequestPart;
26 use ylong_http::response::status::StatusCode;
27 use ylong_http::response::ResponsePart;
28 use ylong_runtime::io::ReadBuf;
29 
30 use crate::async_impl::conn::StreamData;
31 use crate::async_impl::request::Message;
32 use crate::async_impl::{HttpBody, Response};
33 use crate::runtime::AsyncRead;
34 use crate::util::data_ref::BodyDataRef;
35 use crate::util::dispatcher::http3::{DispatchErrorKind, Http3Conn, RequestWrapper, RespMessage};
36 use crate::util::normalizer::BodyLengthParser;
37 use crate::{ErrorKind, HttpClientError};
38 
39 pub(crate) async fn request<S>(
40     mut conn: Http3Conn<S>,
41     mut message: Message,
42 ) -> Result<Response, HttpClientError>
43 where
44     S: Sync + Send + Unpin + 'static,
45 {
46     message
47         .interceptor
48         .intercept_request(message.request.ref_mut())?;
49     let part = message.request.ref_mut().part().clone();
50 
51     // TODO Implement trailer.
52     let headers = build_headers_frame(part)
53         .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?;
54     let data = BodyDataRef::new(message.request.clone());
55     let stream = RequestWrapper {
56         header: headers,
57         data,
58     };
59     conn.send_frame_to_reader(stream)?;
60     let frame = conn.recv_resp().await?;
61     frame_2_response(conn, frame, message)
62 }
63 
64 pub(crate) fn build_headers_frame(mut part: RequestPart) -> Result<Frame, HttpError> {
65     // todo: check rfc to see if any headers should be removed
66     let pseudo = build_pseudo_headers(&mut part)?;
67     let mut header_part = Parts::new();
68     header_part.set_header_lines(part.headers);
69     header_part.set_pseudo(pseudo);
70     let headers_payload = Headers::new(header_part);
71 
72     Ok(Frame::new(
73         HEADERS_FRAME_TYPE,
74         Payload::Headers(headers_payload),
75     ))
76 }
77 
78 // todo: error if headers not enough, should meet rfc
build_pseudo_headersnull79 fn build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError> {
80     let mut pseudo = PseudoHeaders::default();
81     match request_part.uri.scheme() {
82         Some(scheme) => {
83             pseudo.set_scheme(Some(String::from(scheme.as_str())));
84         }
85         None => pseudo.set_scheme(Some(String::from(Scheme::HTTPS.as_str()))),
86     }
87     pseudo.set_method(Some(String::from(request_part.method.as_str())));
88     pseudo.set_path(
89         request_part
90             .uri
91             .path_and_query()
92             .or_else(|| Some(String::from("/"))),
93     );
94     let host = request_part
95         .headers
96         .remove("host")
97         .and_then(|auth| auth.to_string().ok());
98     pseudo.set_authority(host);
99     Ok(pseudo)
100 }
101 
frame_2_responsenull102 fn frame_2_response<S>(
103     conn: Http3Conn<S>,
104     headers_frame: Frame,
105     mut message: Message,
106 ) -> Result<Response, HttpClientError>
107 where
108     S: Sync + Send + Unpin + 'static,
109 {
110     let part = match headers_frame.payload() {
111         Payload::Headers(headers) => {
112             let part = headers.get_part();
113             let (pseudo, fields) = part.parts();
114             let status_code = match pseudo.status() {
115                 Some(status) => StatusCode::from_bytes(status.as_bytes())
116                     .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?,
117                 None => {
118                     return Err(HttpClientError::from_str(
119                         ErrorKind::Request,
120                         "status code not found",
121                     ));
122                 }
123             };
124             ResponsePart {
125                 version: ylong_http::version::Version::HTTP3,
126                 status: status_code,
127                 headers: fields.clone(),
128             }
129         }
130         Payload::PushPromise(_) => {
131             todo!();
132         }
133         _ => {
134             return Err(HttpClientError::from_str(ErrorKind::Request, "bad frame"));
135         }
136     };
137 
138     let data_io = TextIo::new(conn);
139     let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() {
140         Ok(length) => length,
141         Err(e) => {
142             return Err(e);
143         }
144     };
145     let body = HttpBody::new(message.interceptor, length, Box::new(data_io), &[0u8; 0])?;
146 
147     Ok(Response::new(
148         ylong_http::response::Response::from_raw_parts(part, body),
149     ))
150 }
151 
152 struct TextIo<S> {
153     pub(crate) handle: Http3Conn<S>,
154     pub(crate) offset: usize,
155     pub(crate) remain: Option<Frame>,
156     pub(crate) is_closed: bool,
157 }
158 
159 struct HttpReadBuf<'a, 'b> {
160     buf: &'a mut ReadBuf<'b>,
161 }
162 
163 impl<'a, 'b> HttpReadBuf<'a, 'b> {
164     pub(crate) fn append_slice(&mut self, buf: &[u8]) {
165         #[cfg(feature = "ylong_base")]
166         self.buf.append(buf);
167 
168         #[cfg(feature = "tokio_base")]
169         self.buf.put_slice(buf);
170     }
171 }
172 
173 impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> {
174     type Target = ReadBuf<'b>;
175 
derefnull176     fn deref(&self) -> &Self::Target {
177         self.buf
178     }
179 }
180 
181 impl<S> TextIo<S>
182 where
183     S: Sync + Send + Unpin + 'static,
184 {
185     pub(crate) fn new(handle: Http3Conn<S>) -> Self {
186         Self {
187             handle,
188             offset: 0,
189             remain: None,
190             is_closed: false,
191         }
192     }
193 
match_channel_messagenull194     fn match_channel_message(
195         poll_result: Poll<RespMessage>,
196         text_io: &mut TextIo<S>,
197         buf: &mut HttpReadBuf,
198     ) -> Option<Poll<std::io::Result<()>>> {
199         match poll_result {
200             Poll::Ready(RespMessage::Output(frame)) => match frame.payload() {
201                 Payload::Headers(_) => {
202                     text_io.remain = Some(frame);
203                     text_io.offset = 0;
204                     Some(Poll::Ready(Ok(())))
205                 }
206                 Payload::Data(data) => {
207                     let data = data.data();
208                     let unfilled_len = buf.remaining();
209                     let data_len = data.len();
210                     let fill_len = min(data_len, unfilled_len);
211                     if unfilled_len < data_len {
212                         buf.append_slice(&data[..fill_len]);
213                         text_io.offset += fill_len;
214                         text_io.remain = Some(frame);
215                         Some(Poll::Ready(Ok(())))
216                     } else {
217                         buf.append_slice(&data[..fill_len]);
218                         Self::end_read(text_io, data_len)
219                     }
220                 }
221                 _ => Some(Poll::Ready(Err(std::io::Error::new(
222                     std::io::ErrorKind::Other,
223                     HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)),
224                 )))),
225             },
226             Poll::Ready(RespMessage::OutputExit(e)) => match e {
227                 DispatchErrorKind::H3(H3Error::Connection(H3ErrorCode::H3NoError))
228                 | DispatchErrorKind::StreamFinished => {
229                     text_io.is_closed = true;
230                     Some(Poll::Ready(Ok(())))
231                 }
232                 _ => Some(Poll::Ready(Err(std::io::Error::new(
233                     std::io::ErrorKind::Other,
234                     HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)),
235                 )))),
236             },
237             Poll::Pending => Some(Poll::Pending),
238         }
239     }
240 
end_readnull241     fn end_read(text_io: &mut TextIo<S>, data_len: usize) -> Option<Poll<std::io::Result<()>>> {
242         text_io.offset = 0;
243         text_io.remain = None;
244         if data_len == 0 {
245             // no data read and is not end stream.
246             None
247         } else {
248             Some(Poll::Ready(Ok(())))
249         }
250     }
251 
read_remaining_datanull252     fn read_remaining_data(
253         text_io: &mut TextIo<S>,
254         buf: &mut HttpReadBuf,
255     ) -> Option<Poll<std::io::Result<()>>> {
256         if let Some(frame) = &text_io.remain {
257             return match frame.payload() {
258                 Payload::Headers(_) => Some(Poll::Ready(Ok(()))),
259                 Payload::Data(data) => {
260                     let data = data.data();
261                     let unfilled_len = buf.remaining();
262                     let data_len = data.len() - text_io.offset;
263                     let fill_len = min(unfilled_len, data_len);
264                     // The peripheral function already ensures that the remaing of buf will not be
265                     // 0.
266                     if unfilled_len < data_len {
267                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
268                         text_io.offset += fill_len;
269                         Some(Poll::Ready(Ok(())))
270                     } else {
271                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
272                         Self::end_read(text_io, data_len)
273                     }
274                 }
275                 _ => Some(Poll::Ready(Err(std::io::Error::new(
276                     std::io::ErrorKind::Other,
277                     HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)),
278                 )))),
279             };
280         }
281         None
282     }
283 }
284 
285 impl<S: Sync + Send + Unpin + 'static> StreamData for TextIo<S> {
shutdownnull286     fn shutdown(&self) {
287         self.handle.io_shutdown.store(true, Ordering::Relaxed);
288     }
289 
is_stream_closablenull290     fn is_stream_closable(&self) -> bool {
291         self.is_closed
292     }
293 }
294 
295 impl<S: Sync + Send + Unpin + 'static> AsyncRead for TextIo<S> {
poll_readnull296     fn poll_read(
297         self: Pin<&mut Self>,
298         cx: &mut Context<'_>,
299         buf: &mut ReadBuf<'_>,
300     ) -> Poll<std::io::Result<()>> {
301         let text_io = self.get_mut();
302         let mut buf = HttpReadBuf { buf };
303 
304         if buf.remaining() == 0 || text_io.is_closed {
305             return Poll::Ready(Ok(()));
306         }
307         while buf.remaining() != 0 {
308             if let Some(result) = Self::read_remaining_data(text_io, &mut buf) {
309                 return result;
310             }
311 
312             let poll_result = text_io
313                 .handle
314                 .resp_receiver
315                 .poll_recv(cx)
316                 .map_err(|_e| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?;
317 
318             if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) {
319                 return result;
320             }
321         }
322         Poll::Ready(Ok(()))
323     }
324 }
325