1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::cmp::min;
15 use std::ops::Deref;
16 use std::pin::Pin;
17 use std::sync::atomic::Ordering;
18 use std::task::{Context, Poll};
19 
20 use ylong_http::error::HttpError;
21 use ylong_http::h2;
22 use ylong_http::h2::{ErrorCode, Frame, FrameFlags, H2Error, Payload, PseudoHeaders};
23 use ylong_http::headers::Headers;
24 use ylong_http::request::uri::Scheme;
25 use ylong_http::request::RequestPart;
26 use ylong_http::response::status::StatusCode;
27 use ylong_http::response::ResponsePart;
28 
29 use crate::async_impl::conn::StreamData;
30 use crate::async_impl::request::Message;
31 use crate::async_impl::{HttpBody, Response};
32 use crate::error::{ErrorKind, HttpClientError};
33 use crate::runtime::{AsyncRead, ReadBuf};
34 use crate::util::data_ref::BodyDataRef;
35 use crate::util::dispatcher::http2::Http2Conn;
36 use crate::util::h2::RequestWrapper;
37 use crate::util::normalizer::BodyLengthParser;
38 
/// Flag byte with no bits set; the starting value for the HEADERS frame flags
/// built in this module.
const UNUSED_FLAG: u8 = 0;
40 
41 pub(crate) async fn request<S>(
42     mut conn: Http2Conn<S>,
43     mut message: Message,
44 ) -> Result<Response, HttpClientError>
45 where
46     S: Sync + Send + Unpin + 'static,
47 {
48     message
49         .interceptor
50         .intercept_request(message.request.ref_mut())?;
51     let part = message.request.ref_mut().part().clone();
52 
53     // TODO Implement trailer.
54     let (flag, payload) = build_headers_payload(part, false)
55         .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?;
56     let data = BodyDataRef::new(message.request.clone());
57     let stream = RequestWrapper {
58         flag,
59         payload,
60         data,
61     };
62     conn.send_frame_to_controller(stream)?;
63     let frame = conn.receiver.recv().await?;
64     frame_2_response(conn, frame, message)
65 }
66 
frame_2_responsenull67 fn frame_2_response<S>(
68     conn: Http2Conn<S>,
69     headers_frame: Frame,
70     mut message: Message,
71 ) -> Result<Response, HttpClientError>
72 where
73     S: Sync + Send + Unpin + 'static,
74 {
75     let part = match headers_frame.payload() {
76         Payload::Headers(headers) => {
77             let (pseudo, fields) = headers.parts();
78             let status_code = match pseudo.status() {
79                 Some(status) => StatusCode::from_bytes(status.as_bytes())
80                     .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?,
81                 None => {
82                     return Err(HttpClientError::from_error(
83                         ErrorKind::Request,
84                         HttpError::from(H2Error::StreamError(
85                             headers_frame.stream_id(),
86                             ErrorCode::ProtocolError,
87                         )),
88                     ));
89                 }
90             };
91             ResponsePart {
92                 version: ylong_http::version::Version::HTTP2,
93                 status: status_code,
94                 headers: fields.clone(),
95             }
96         }
97         Payload::RstStream(reset) => {
98             return Err(HttpClientError::from_error(
99                 ErrorKind::Request,
100                 HttpError::from(H2Error::StreamError(
101                     headers_frame.stream_id(),
102                     ErrorCode::try_from(reset.error_code()).unwrap_or(ErrorCode::ProtocolError),
103                 )),
104             ));
105         }
106         _ => {
107             return Err(HttpClientError::from_error(
108                 ErrorKind::Request,
109                 HttpError::from(H2Error::StreamError(
110                     headers_frame.stream_id(),
111                     ErrorCode::ProtocolError,
112                 )),
113             ));
114         }
115     };
116 
117     let text_io = TextIo::new(conn);
118     // TODO Can http2 have no content-length header field and rely only on the
119     // end_stream flag? flag has a Body
120     let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() {
121         Ok(length) => length,
122         Err(e) => {
123             return Err(e);
124         }
125     };
126     let body = HttpBody::new(message.interceptor, length, Box::new(text_io), &[0u8; 0])?;
127 
128     Ok(Response::new(
129         ylong_http::response::Response::from_raw_parts(part, body),
130     ))
131 }
132 
133 pub(crate) fn build_headers_payload(
134     mut part: RequestPart,
135     is_end_stream: bool,
136 ) -> Result<(FrameFlags, Payload), HttpError> {
137     remove_connection_specific_headers(&mut part.headers)?;
138     let pseudo = build_pseudo_headers(&mut part)?;
139     let mut header_part = h2::Parts::new();
140     header_part.set_header_lines(part.headers);
141     header_part.set_pseudo(pseudo);
142     let headers_payload = h2::Headers::new(header_part);
143 
144     let mut flag = FrameFlags::new(UNUSED_FLAG);
145     flag.set_end_headers(true);
146     if is_end_stream {
147         flag.set_end_stream(true);
148     }
149     Ok((flag, Payload::Headers(headers_payload)))
150 }
151 
152 // Illegal headers validation in http2.
153 // [`Connection-Specific Headers`] implementation.
154 //
155 // [`Connection-Specific Headers`]: https://www.rfc-editor.org/rfc/rfc9113.html#name-connection-specific-header-
remove_connection_specific_headersnull156 fn remove_connection_specific_headers(headers: &mut Headers) -> Result<(), HttpError> {
157     const CONNECTION_SPECIFIC_HEADERS: &[&str; 5] = &[
158         "connection",
159         "keep-alive",
160         "proxy-connection",
161         "upgrade",
162         "transfer-encoding",
163     ];
164     for specific_header in CONNECTION_SPECIFIC_HEADERS.iter() {
165         headers.remove(*specific_header);
166     }
167 
168     if let Some(te_ref) = headers.get("te") {
169         let te = te_ref.to_string()?;
170         if te.as_str() != "trailers" {
171             headers.remove("te");
172         }
173     }
174     Ok(())
175 }
176 
build_pseudo_headersnull177 fn build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError> {
178     let mut pseudo = PseudoHeaders::default();
179     match request_part.uri.scheme() {
180         Some(scheme) => {
181             pseudo.set_scheme(Some(String::from(scheme.as_str())));
182         }
183         None => pseudo.set_scheme(Some(String::from(Scheme::HTTP.as_str()))),
184     }
185     pseudo.set_method(Some(String::from(request_part.method.as_str())));
186     pseudo.set_path(
187         request_part
188             .uri
189             .path_and_query()
190             .or_else(|| Some(String::from("/"))),
191     );
192     let host = request_part
193         .headers
194         .remove("host")
195         .and_then(|auth| auth.to_string().ok());
196     pseudo.set_authority(host);
197     Ok(pseudo)
198 }
199 
200 struct TextIo<S> {
201     pub(crate) handle: Http2Conn<S>,
202     pub(crate) offset: usize,
203     pub(crate) remain: Option<Frame>,
204     pub(crate) is_closed: bool,
205 }
206 
207 struct HttpReadBuf<'a, 'b> {
208     buf: &'a mut ReadBuf<'b>,
209 }
210 
211 impl<'a, 'b> HttpReadBuf<'a, 'b> {
212     pub(crate) fn append_slice(&mut self, buf: &[u8]) {
213         #[cfg(feature = "ylong_base")]
214         self.buf.append(buf);
215 
216         #[cfg(feature = "tokio_base")]
217         self.buf.put_slice(buf);
218     }
219 }
220 
221 impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> {
222     type Target = ReadBuf<'b>;
223 
derefnull224     fn deref(&self) -> &Self::Target {
225         self.buf
226     }
227 }
228 
229 impl<S> TextIo<S>
230 where
231     S: Sync + Send + Unpin + 'static,
232 {
233     pub(crate) fn new(handle: Http2Conn<S>) -> Self {
234         Self {
235             handle,
236             offset: 0,
237             remain: None,
238             is_closed: false,
239         }
240     }
241 
match_channel_messagenull242     fn match_channel_message(
243         poll_result: Poll<Frame>,
244         text_io: &mut TextIo<S>,
245         buf: &mut HttpReadBuf,
246     ) -> Option<Poll<std::io::Result<()>>> {
247         match poll_result {
248             Poll::Ready(frame) => match frame.payload() {
249                 Payload::Headers(_) => {
250                     text_io.remain = Some(frame);
251                     text_io.offset = 0;
252                     Some(Poll::Ready(Ok(())))
253                 }
254                 Payload::Data(data) => {
255                     let data = data.data();
256                     let unfilled_len = buf.remaining();
257                     let data_len = data.len();
258                     let fill_len = min(data_len, unfilled_len);
259                     if unfilled_len < data_len {
260                         buf.append_slice(&data[..fill_len]);
261                         text_io.offset += fill_len;
262                         text_io.remain = Some(frame);
263                         Some(Poll::Ready(Ok(())))
264                     } else {
265                         buf.append_slice(&data[..fill_len]);
266                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
267                     }
268                 }
269                 Payload::RstStream(reset) => {
270                     if reset.is_no_error() {
271                         text_io.is_closed = true;
272                         Some(Poll::Ready(Ok(())))
273                     } else {
274                         Some(Poll::Ready(Err(std::io::Error::new(
275                             std::io::ErrorKind::Other,
276                             HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
277                         ))))
278                     }
279                 }
280                 _ => Some(Poll::Ready(Err(std::io::Error::new(
281                     std::io::ErrorKind::Other,
282                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
283                 )))),
284             },
285             Poll::Pending => Some(Poll::Pending),
286         }
287     }
288 
end_readnull289     fn end_read(
290         text_io: &mut TextIo<S>,
291         end_stream: bool,
292         data_len: usize,
293     ) -> Option<Poll<std::io::Result<()>>> {
294         text_io.offset = 0;
295         text_io.remain = None;
296         if end_stream {
297             text_io.is_closed = true;
298             Some(Poll::Ready(Ok(())))
299         } else if data_len == 0 {
300             // no data read and is not end stream.
301             None
302         } else {
303             Some(Poll::Ready(Ok(())))
304         }
305     }
306 
read_remaining_datanull307     fn read_remaining_data(
308         text_io: &mut TextIo<S>,
309         buf: &mut HttpReadBuf,
310     ) -> Option<Poll<std::io::Result<()>>> {
311         if let Some(frame) = &text_io.remain {
312             return match frame.payload() {
313                 Payload::Headers(_) => Some(Poll::Ready(Ok(()))),
314                 Payload::Data(data) => {
315                     let data = data.data();
316                     let unfilled_len = buf.remaining();
317                     let data_len = data.len() - text_io.offset;
318                     let fill_len = min(unfilled_len, data_len);
319                     // The peripheral function already ensures that the remaing of buf will not be
320                     // 0.
321                     if unfilled_len < data_len {
322                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
323                         text_io.offset += fill_len;
324                         Some(Poll::Ready(Ok(())))
325                     } else {
326                         buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
327                         Self::end_read(text_io, frame.flags().is_end_stream(), data_len)
328                     }
329                 }
330                 _ => Some(Poll::Ready(Err(std::io::Error::new(
331                     std::io::ErrorKind::Other,
332                     HttpError::from(H2Error::ConnectionError(ErrorCode::ProtocolError)),
333                 )))),
334             };
335         }
336         None
337     }
338 }
339 
340 impl<S: Sync + Send + Unpin + 'static> StreamData for TextIo<S> {
shutdownnull341     fn shutdown(&self) {
342         self.handle.io_shutdown.store(true, Ordering::Release);
343     }
344 
is_stream_closablenull345     fn is_stream_closable(&self) -> bool {
346         self.is_closed
347     }
348 }
349 
350 impl<S: Sync + Send + Unpin + 'static> AsyncRead for TextIo<S> {
poll_readnull351     fn poll_read(
352         self: Pin<&mut Self>,
353         cx: &mut Context<'_>,
354         buf: &mut ReadBuf<'_>,
355     ) -> Poll<std::io::Result<()>> {
356         let text_io = self.get_mut();
357         let mut buf = HttpReadBuf { buf };
358 
359         if buf.remaining() == 0 || text_io.is_closed {
360             return Poll::Ready(Ok(()));
361         }
362         while buf.remaining() != 0 {
363             if let Some(result) = Self::read_remaining_data(text_io, &mut buf) {
364                 return result;
365             }
366 
367             let poll_result = text_io
368                 .handle
369                 .receiver
370                 .poll_recv(cx)
371                 .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?;
372 
373             if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) {
374                 return result;
375             }
376         }
377         Poll::Ready(Ok(()))
378     }
379 }
380 
#[cfg(feature = "http2")]
#[cfg(test)]
mod ut_http2 {
    use ylong_http::body::TextBody;
    use ylong_http::h2::Payload;
    use ylong_http::request::RequestBuilder;

    use crate::async_impl::conn::http2::build_headers_payload;

    // Builds a request with a text body from a method, URI, version and an
    // optional list of header name/value pairs.
    macro_rules! build_request {
        (
            Request: {
                Method: $method: expr,
                Uri: $uri:expr,
                Version: $version: expr,
                $(
                    Header: $req_n: expr, $req_v: expr,
                )*
                Body: $req_body: expr,
            }
        ) => {
            RequestBuilder::new()
                .method($method)
                .url($uri)
                .version($version)
                $(.header($req_n, $req_v))*
                .body(TextBody::from_bytes($req_body.as_bytes()))
                .expect("Request build failed")
        }
    }

    /// UT for `build_headers_payload`.
    ///
    /// # Brief
    /// 1. Builds a GET request with `te` and `host` headers.
    /// 2. Builds the HEADERS payload without and with the end-stream flag.
    /// 3. Checks the frame flags and the generated pseudo headers.
    #[test]
    fn ut_http2_build_headers_payload() {
        let request = build_request!(
            Request: {
            Method: "GET",
            Uri: "http://127.0.0.1:0/data",
            Version: "HTTP/2.0",
            Header: "te", "trailers",
            Header: "host", "127.0.0.1:0",
            Body: "Hi",
        }
        );

        let (open_flag, _) = build_headers_payload(request.part().clone(), false).unwrap();
        assert_eq!(open_flag.bits(), 0x4);

        let (end_flag, payload) = build_headers_payload(request.part().clone(), true).unwrap();
        assert_eq!(end_flag.bits(), 0x5);

        let headers = match payload {
            Payload::Headers(headers) => headers,
            _ => panic!("Unexpected frame type"),
        };
        let (pseudo, _headers) = headers.parts();
        assert_eq!(pseudo.status(), None);
        assert_eq!(pseudo.scheme().unwrap(), "http");
        assert_eq!(pseudo.method().unwrap(), "GET");
        assert_eq!(pseudo.authority().unwrap(), "127.0.0.1:0");
        assert_eq!(pseudo.path().unwrap(), "/data");
    }

    /// UT to ensure that reading the response body (data frames) ends
    /// normally.
    ///
    /// # Brief
    /// 1. Creates three data frames, one greater than buf, one less than buf,
    ///    and the last one equal to buf and carrying the end-stream flag.
    /// 2. The response body data is read from TextIo using a buf of 10 bytes.
    /// 3. The body is read completely.
    /// 4. Checks that the total size read equals the total size sent.
    #[cfg(feature = "ylong_base")]
    #[test]
    fn ut_http2_body_poll_read() {
        use std::pin::Pin;
        use std::sync::atomic::AtomicBool;
        use std::sync::Arc;

        use ylong_http::h2::{Data, Frame, FrameFlags};
        use ylong_runtime::futures::poll_fn;
        use ylong_runtime::io::{AsyncRead, ReadBuf};

        use crate::async_impl::conn::http2::TextIo;
        use crate::util::dispatcher::http2::Http2Conn;

        let (resp_tx, resp_rx) = ylong_runtime::sync::mpsc::bounded_channel(20);
        let (req_tx, _req_rx) = crate::runtime::unbounded_channel();
        let shutdown = Arc::new(AtomicBool::new(false));
        let mut conn: Http2Conn<()> = Http2Conn::new(20, shutdown, req_tx);
        conn.receiver.set_receiver(resp_rx);
        let mut text_io = TextIo::new(conn);

        // Data frame on stream 1 filled with `len` bytes of b'a'.
        let data_frame = |flags: u8, len: usize| {
            Frame::new(
                1,
                FrameFlags::new(flags),
                Payload::Data(Data::new(vec![b'a'; len])),
            )
        };
        let frames = vec![data_frame(0, 128), data_frame(0, 2), data_frame(1, 10)];

        ylong_runtime::block_on(async {
            for frame in frames {
                let _ = resp_tx
                    .send(crate::util::dispatcher::http2::RespMessage::Output(frame))
                    .await;
            }
        });

        ylong_runtime::block_on(async {
            let mut chunk = [0_u8; 10];
            let mut output_vec = Vec::new();

            loop {
                let mut buffer = ReadBuf::new(chunk.as_mut_slice());
                poll_fn(|cx| Pin::new(&mut text_io).poll_read(cx, &mut buffer))
                    .await
                    .unwrap();
                let size = buffer.filled_len();
                output_vec.extend_from_slice(&chunk[..size]);
                // The `>= 1024` bound lets the loop exit even if the reader
                // misbehaves and never reports an empty read.
                if size == 0 || output_vec.len() >= 1024 {
                    break;
                }
            }
            assert_eq!(output_vec.len(), 140);
        })
    }
}
515