1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::cmp::min;
15 use std::ops::Deref;
16 use std::pin::Pin;
17 use std::sync::atomic::Ordering;
18 use std::task::{Context, Poll};
19
20 use ylong_http::error::HttpError;
21 use ylong_http::h3::{
22 Frame, H3Error, H3ErrorCode, Headers, Parts, Payload, PseudoHeaders, HEADERS_FRAME_TYPE,
23 };
24 use ylong_http::request::uri::Scheme;
25 use ylong_http::request::RequestPart;
26 use ylong_http::response::status::StatusCode;
27 use ylong_http::response::ResponsePart;
28 use ylong_runtime::io::ReadBuf;
29
30 use crate::async_impl::conn::StreamData;
31 use crate::async_impl::request::Message;
32 use crate::async_impl::{HttpBody, Response};
33 use crate::runtime::AsyncRead;
34 use crate::util::data_ref::BodyDataRef;
35 use crate::util::dispatcher::http3::{DispatchErrorKind, Http3Conn, RequestWrapper, RespMessage};
36 use crate::util::normalizer::BodyLengthParser;
37 use crate::{ErrorKind, HttpClientError};
38
39 pub(crate) async fn request<S>(
40 mut conn: Http3Conn<S>,
41 mut message: Message,
42 ) -> Result<Response, HttpClientError>
43 where
44 S: Sync + Send + Unpin + 'static,
45 {
46 message
47 .interceptor
48 .intercept_request(message.request.ref_mut())?;
49 let part = message.request.ref_mut().part().clone();
50
51 // TODO Implement trailer.
52 let headers = build_headers_frame(part)
53 .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?;
54 let data = BodyDataRef::new(message.request.clone());
55 let stream = RequestWrapper {
56 header: headers,
57 data,
58 };
59 conn.send_frame_to_reader(stream)?;
60 let frame = conn.recv_resp().await?;
61 frame_2_response(conn, frame, message)
62 }
63
64 pub(crate) fn build_headers_frame(mut part: RequestPart) -> Result<Frame, HttpError> {
65 // todo: check rfc to see if any headers should be removed
66 let pseudo = build_pseudo_headers(&mut part)?;
67 let mut header_part = Parts::new();
68 header_part.set_header_lines(part.headers);
69 header_part.set_pseudo(pseudo);
70 let headers_payload = Headers::new(header_part);
71
72 Ok(Frame::new(
73 HEADERS_FRAME_TYPE,
74 Payload::Headers(headers_payload),
75 ))
76 }
77
78 // todo: error if headers not enough, should meet rfc
build_pseudo_headersnull79 fn build_pseudo_headers(request_part: &mut RequestPart) -> Result<PseudoHeaders, HttpError> {
80 let mut pseudo = PseudoHeaders::default();
81 match request_part.uri.scheme() {
82 Some(scheme) => {
83 pseudo.set_scheme(Some(String::from(scheme.as_str())));
84 }
85 None => pseudo.set_scheme(Some(String::from(Scheme::HTTPS.as_str()))),
86 }
87 pseudo.set_method(Some(String::from(request_part.method.as_str())));
88 pseudo.set_path(
89 request_part
90 .uri
91 .path_and_query()
92 .or_else(|| Some(String::from("/"))),
93 );
94 let host = request_part
95 .headers
96 .remove("host")
97 .and_then(|auth| auth.to_string().ok());
98 pseudo.set_authority(host);
99 Ok(pseudo)
100 }
101
frame_2_responsenull102 fn frame_2_response<S>(
103 conn: Http3Conn<S>,
104 headers_frame: Frame,
105 mut message: Message,
106 ) -> Result<Response, HttpClientError>
107 where
108 S: Sync + Send + Unpin + 'static,
109 {
110 let part = match headers_frame.payload() {
111 Payload::Headers(headers) => {
112 let part = headers.get_part();
113 let (pseudo, fields) = part.parts();
114 let status_code = match pseudo.status() {
115 Some(status) => StatusCode::from_bytes(status.as_bytes())
116 .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?,
117 None => {
118 return Err(HttpClientError::from_str(
119 ErrorKind::Request,
120 "status code not found",
121 ));
122 }
123 };
124 ResponsePart {
125 version: ylong_http::version::Version::HTTP3,
126 status: status_code,
127 headers: fields.clone(),
128 }
129 }
130 Payload::PushPromise(_) => {
131 todo!();
132 }
133 _ => {
134 return Err(HttpClientError::from_str(ErrorKind::Request, "bad frame"));
135 }
136 };
137
138 let data_io = TextIo::new(conn);
139 let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() {
140 Ok(length) => length,
141 Err(e) => {
142 return Err(e);
143 }
144 };
145 let body = HttpBody::new(message.interceptor, length, Box::new(data_io), &[0u8; 0])?;
146
147 Ok(Response::new(
148 ylong_http::response::Response::from_raw_parts(part, body),
149 ))
150 }
151
152 struct TextIo<S> {
153 pub(crate) handle: Http3Conn<S>,
154 pub(crate) offset: usize,
155 pub(crate) remain: Option<Frame>,
156 pub(crate) is_closed: bool,
157 }
158
159 struct HttpReadBuf<'a, 'b> {
160 buf: &'a mut ReadBuf<'b>,
161 }
162
163 impl<'a, 'b> HttpReadBuf<'a, 'b> {
164 pub(crate) fn append_slice(&mut self, buf: &[u8]) {
165 #[cfg(feature = "ylong_base")]
166 self.buf.append(buf);
167
168 #[cfg(feature = "tokio_base")]
169 self.buf.put_slice(buf);
170 }
171 }
172
173 impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> {
174 type Target = ReadBuf<'b>;
175
derefnull176 fn deref(&self) -> &Self::Target {
177 self.buf
178 }
179 }
180
181 impl<S> TextIo<S>
182 where
183 S: Sync + Send + Unpin + 'static,
184 {
185 pub(crate) fn new(handle: Http3Conn<S>) -> Self {
186 Self {
187 handle,
188 offset: 0,
189 remain: None,
190 is_closed: false,
191 }
192 }
193
match_channel_messagenull194 fn match_channel_message(
195 poll_result: Poll<RespMessage>,
196 text_io: &mut TextIo<S>,
197 buf: &mut HttpReadBuf,
198 ) -> Option<Poll<std::io::Result<()>>> {
199 match poll_result {
200 Poll::Ready(RespMessage::Output(frame)) => match frame.payload() {
201 Payload::Headers(_) => {
202 text_io.remain = Some(frame);
203 text_io.offset = 0;
204 Some(Poll::Ready(Ok(())))
205 }
206 Payload::Data(data) => {
207 let data = data.data();
208 let unfilled_len = buf.remaining();
209 let data_len = data.len();
210 let fill_len = min(data_len, unfilled_len);
211 if unfilled_len < data_len {
212 buf.append_slice(&data[..fill_len]);
213 text_io.offset += fill_len;
214 text_io.remain = Some(frame);
215 Some(Poll::Ready(Ok(())))
216 } else {
217 buf.append_slice(&data[..fill_len]);
218 Self::end_read(text_io, data_len)
219 }
220 }
221 _ => Some(Poll::Ready(Err(std::io::Error::new(
222 std::io::ErrorKind::Other,
223 HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)),
224 )))),
225 },
226 Poll::Ready(RespMessage::OutputExit(e)) => match e {
227 DispatchErrorKind::H3(H3Error::Connection(H3ErrorCode::H3NoError))
228 | DispatchErrorKind::StreamFinished => {
229 text_io.is_closed = true;
230 Some(Poll::Ready(Ok(())))
231 }
232 _ => Some(Poll::Ready(Err(std::io::Error::new(
233 std::io::ErrorKind::Other,
234 HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)),
235 )))),
236 },
237 Poll::Pending => Some(Poll::Pending),
238 }
239 }
240
end_readnull241 fn end_read(text_io: &mut TextIo<S>, data_len: usize) -> Option<Poll<std::io::Result<()>>> {
242 text_io.offset = 0;
243 text_io.remain = None;
244 if data_len == 0 {
245 // no data read and is not end stream.
246 None
247 } else {
248 Some(Poll::Ready(Ok(())))
249 }
250 }
251
read_remaining_datanull252 fn read_remaining_data(
253 text_io: &mut TextIo<S>,
254 buf: &mut HttpReadBuf,
255 ) -> Option<Poll<std::io::Result<()>>> {
256 if let Some(frame) = &text_io.remain {
257 return match frame.payload() {
258 Payload::Headers(_) => Some(Poll::Ready(Ok(()))),
259 Payload::Data(data) => {
260 let data = data.data();
261 let unfilled_len = buf.remaining();
262 let data_len = data.len() - text_io.offset;
263 let fill_len = min(unfilled_len, data_len);
264 // The peripheral function already ensures that the remaing of buf will not be
265 // 0.
266 if unfilled_len < data_len {
267 buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
268 text_io.offset += fill_len;
269 Some(Poll::Ready(Ok(())))
270 } else {
271 buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]);
272 Self::end_read(text_io, data_len)
273 }
274 }
275 _ => Some(Poll::Ready(Err(std::io::Error::new(
276 std::io::ErrorKind::Other,
277 HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)),
278 )))),
279 };
280 }
281 None
282 }
283 }
284
285 impl<S: Sync + Send + Unpin + 'static> StreamData for TextIo<S> {
shutdownnull286 fn shutdown(&self) {
287 self.handle.io_shutdown.store(true, Ordering::Relaxed);
288 }
289
is_stream_closablenull290 fn is_stream_closable(&self) -> bool {
291 self.is_closed
292 }
293 }
294
295 impl<S: Sync + Send + Unpin + 'static> AsyncRead for TextIo<S> {
poll_readnull296 fn poll_read(
297 self: Pin<&mut Self>,
298 cx: &mut Context<'_>,
299 buf: &mut ReadBuf<'_>,
300 ) -> Poll<std::io::Result<()>> {
301 let text_io = self.get_mut();
302 let mut buf = HttpReadBuf { buf };
303
304 if buf.remaining() == 0 || text_io.is_closed {
305 return Poll::Ready(Ok(()));
306 }
307 while buf.remaining() != 0 {
308 if let Some(result) = Self::read_remaining_data(text_io, &mut buf) {
309 return result;
310 }
311
312 let poll_result = text_io
313 .handle
314 .resp_receiver
315 .poll_recv(cx)
316 .map_err(|_e| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?;
317
318 if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) {
319 return result;
320 }
321 }
322 Poll::Ready(Ok(()))
323 }
324 }
325