1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 use std::collections::{HashMap, HashSet, VecDeque}; 15 use std::future::Future; 16 use std::pin::Pin; 17 use std::sync::atomic::{AtomicU64, Ordering}; 18 use std::task::{Context, Poll}; 19 20 use ylong_http::h3::{Data, Frame, H3Error, H3ErrorCode, Payload, DATA_FRAME_TYPE}; 21 22 use crate::runtime::{BoundedSender, SendError}; 23 use crate::util::data_ref::BodyDataRef; 24 use crate::util::dispatcher::http3::{DispatchErrorKind, RespMessage}; 25 26 pub(crate) type OutputSendFut = 27 Pin<Box<dyn Future<Output = Result<(), SendError<RespMessage>>> + Send + Sync>>; 28 29 const HTTP3_FIRST_BIDI_STREAM_ID: u64 = 0u64; 30 const HTTP3_FIRST_UNI_STREAM_ID: u64 = 2u64; 31 const HTTP3_MAX_STREAM_ID: u64 = (1u64 << 62) - 1; 32 const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 100; 33 34 #[derive(PartialEq, Clone)] 35 pub(crate) enum H3StreamState { 36 Sending, 37 HeadersReceived, 38 BodyReceived, 39 TrailerReceived, 40 Shutdown, 41 } 42 43 #[derive(PartialEq, Clone)] 44 pub(crate) enum QUICStreamType { 45 ClientInitialBidirectional, 46 ServerInitialBidirectional, 47 ClientInitialUnidirectional, 48 ServerInitialUnidirectional, 49 } 50 51 impl QUICStreamType { 52 pub(crate) fn from(id: u64) -> Self { 53 match id % 4 { 54 0 => QUICStreamType::ClientInitialBidirectional, 55 1 => QUICStreamType::ServerInitialBidirectional, 56 2 => QUICStreamType::ClientInitialUnidirectional, 57 _ => QUICStreamType::ServerInitialUnidirectional, 58 } 59 } 60 } 61 62 // Unidirectional Streams 63 pub(crate) struct BidirectionalStream { 64 pub(crate) state: H3StreamState, 65 pub(crate) frame_tx: BoundedSender<RespMessage>, 66 pub(crate) header: Option<Frame>, 67 pub(crate) data: BodyDataRef, 68 pub(crate) pending_message: VecDeque<RespMessage>, 69 pub(crate) encoding: bool, 70 pub(crate) curr_message: Option<OutputSendFut>, 71 } 72 73 impl BidirectionalStream { newnull74 fn new(frame_tx: BoundedSender<RespMessage>, header: Frame, data: BodyDataRef) -> Self { 75 Self { 76 state: H3StreamState::Sending, 77 frame_tx, 78 header: Some(header), 79 data, 80 pending_message: VecDeque::new(), 81 encoding: false, 82 curr_message: None, 83 } 84 } 85 transmit_messagenull86 fn transmit_message( 87 &mut self, 88 cx: &mut Context<'_>, 89 message: RespMessage, 90 ) -> Poll<Result<(), DispatchErrorKind>> { 91 let mut task = { 92 let sender = self.frame_tx.clone(); 93 let ft = async move { sender.send(message).await }; 94 Box::pin(ft) 95 }; 96 97 match task.as_mut().poll(cx) { 98 Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), 99 // The current coroutine sending the request exited prematurely. 100 Poll::Ready(Err(_)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), 101 Poll::Pending => { 102 self.curr_message = Some(task); 103 Poll::Pending 104 } 105 } 106 } 107 } 108 109 pub(crate) struct Streams { 110 bidirectional_stream: HashMap<u64, BidirectionalStream>, 111 control_stream_id: Option<u64>, 112 peer_control_stream_id: Option<u64>, 113 qpack_encode_stream_id: Option<u64>, 114 qpack_decode_stream_id: Option<u64>, 115 peer_qpack_encode_stream_id: Option<u64>, 116 peer_qpack_decode_stream_id: Option<u64>, 117 // unused now 118 goaway_id: Option<u64>, 119 peer_goaway_id: Option<u64>, 120 // meet the sending conditions, waiting for sending 121 pending_send: VecDeque<u64>, 122 // cannot recv cause of stream blocks 123 pending_recv: HashSet<u64>, 124 // stream resumes and should decode again 125 resume_recv: VecDeque<u64>, 126 // too many working streams, pending for concurrency 127 pending_concurrency: VecDeque<u64>, 128 // cannot recv cause of channel blocked 129 pending_channel: HashSet<u64>, 130 working_stream_num: u32, 131 max_stream_concurrency: u32, 132 next_uni_stream_id: AtomicU64, 133 next_bidi_stream_id: AtomicU64, 134 } 135 136 impl Streams { 137 pub(crate) fn new() -> Self { 138 Self { 139 bidirectional_stream: HashMap::new(), 140 control_stream_id: None, 141 peer_control_stream_id: None, 142 qpack_encode_stream_id: None, 143 qpack_decode_stream_id: None, 144 peer_qpack_encode_stream_id: None, 145 peer_qpack_decode_stream_id: None, 146 goaway_id: None, 147 peer_goaway_id: None, 148 pending_send: VecDeque::new(), 149 pending_recv: HashSet::new(), 150 resume_recv: VecDeque::new(), 151 pending_concurrency: VecDeque::new(), 152 pending_channel: HashSet::new(), 153 working_stream_num: 0, 154 max_stream_concurrency: DEFAULT_MAX_CONCURRENT_STREAMS, 155 next_uni_stream_id: AtomicU64::new(HTTP3_FIRST_UNI_STREAM_ID), 156 next_bidi_stream_id: AtomicU64::new(HTTP3_FIRST_BIDI_STREAM_ID), 157 } 158 } 159 160 pub(crate) fn new_unidirectional_stream( 161 &mut self, 162 header: Frame, 163 data: BodyDataRef, 164 rx: BoundedSender<RespMessage>, 165 ) -> Result<(), DispatchErrorKind> { 166 let id = 167 self.get_next_bidi_stream_id() 168 .ok_or(DispatchErrorKind::H3(H3Error::Connection( 169 H3ErrorCode::H3GeneralProtocolError, 170 )))?; 171 self.bidirectional_stream 172 .insert(id, BidirectionalStream::new(rx, header, data)); 173 if self.reach_max_concurrency() { 174 self.push_back_pending_concurrency(id); 175 } else { 176 self.push_back_pending_send(id); 177 self.increase_current_concurrency(); 178 } 179 Ok(()) 180 } 181 182 pub(crate) fn send_frame( 183 &mut self, 184 cx: &mut Context<'_>, 185 id: u64, 186 frame: Frame, 187 ) -> Result<(), DispatchErrorKind> { 188 if let Some(stream) = self.bidirectional_stream.get_mut(&id) { 189 match stream.state { 190 H3StreamState::Sending => { 191 if let Payload::Headers(_) = frame.payload() { 192 stream.state = H3StreamState::HeadersReceived; 193 } else { 194 return Err(DispatchErrorKind::H3(H3Error::Connection( 195 H3ErrorCode::H3FrameUnexpected, 196 ))); 197 } 198 } 199 H3StreamState::HeadersReceived => { 200 if let Payload::Headers(_) = frame.payload() { 201 return Err(DispatchErrorKind::H3(H3Error::Connection( 202 H3ErrorCode::H3FrameUnexpected, 203 ))); 204 } else { 205 stream.state = H3StreamState::BodyReceived; 206 } 207 } 208 H3StreamState::BodyReceived => { 209 if let Payload::Headers(_) = frame.payload() { 210 stream.state = H3StreamState::TrailerReceived; 211 } 212 } 213 H3StreamState::TrailerReceived => { 214 return Err(DispatchErrorKind::H3(H3Error::Connection( 215 H3ErrorCode::H3FrameUnexpected, 216 ))); 217 } 218 H3StreamState::Shutdown => { 219 // stream has been shutdown, drop frame 220 return Ok(()); 221 } 222 } 223 if stream.curr_message.is_some() { 224 stream.pending_message.push_back(RespMessage::Output(frame)); 225 return Ok(()); 226 } 227 if let Poll::Ready(ret) = stream.transmit_message(cx, RespMessage::Output(frame)) { 228 ret 229 } else { 230 self.stream_pend_channel(id); 231 Ok(()) 232 } 233 } else { 234 Err(DispatchErrorKind::ChannelClosed) 235 } 236 } 237 238 pub(crate) fn send_error( 239 &mut self, 240 cx: &mut Context<'_>, 241 id: u64, 242 error: DispatchErrorKind, 243 ) -> Result<(), DispatchErrorKind> { 244 if let Some(stream) = self.bidirectional_stream.get_mut(&id) { 245 stream.pending_message.clear(); 246 if let Poll::Ready(ret) = stream.transmit_message(cx, RespMessage::OutputExit(error)) { 247 ret 248 } else { 249 self.stream_pend_channel(id); 250 Ok(()) 251 } 252 } else { 253 Err(DispatchErrorKind::ChannelClosed) 254 } 255 } 256 257 pub(crate) fn control_stream_id(&mut self) -> Option<u64> { 258 if self.control_stream_id.is_some() { 259 self.control_stream_id 260 } else { 261 self.control_stream_id = self.get_next_uni_stream_id(); 262 self.control_stream_id 263 } 264 } 265 266 pub(crate) fn qpack_decode_stream_id(&mut self) -> Option<u64> { 267 if self.qpack_decode_stream_id.is_some() { 268 self.qpack_decode_stream_id 269 } else { 270 self.qpack_decode_stream_id = self.get_next_uni_stream_id(); 271 self.qpack_decode_stream_id 272 } 273 } 274 275 pub(crate) fn qpack_encode_stream_id(&mut self) -> Option<u64> { 276 if self.qpack_encode_stream_id.is_some() { 277 self.qpack_encode_stream_id 278 } else { 279 self.qpack_encode_stream_id = self.get_next_uni_stream_id(); 280 self.qpack_encode_stream_id 281 } 282 } 283 284 pub(crate) fn peer_qpack_encode_stream_id(&self) -> Option<u64> { 285 self.peer_qpack_encode_stream_id 286 } 287 288 pub(crate) fn peer_goaway_id(&self) -> Option<u64> { 289 self.peer_goaway_id 290 } 291 292 #[allow(unused)] 293 pub(crate) fn goaway_id(&self) -> Option<u64> { 294 self.goaway_id 295 } 296 297 pub(crate) fn peer_control_stream_id(&self) -> Option<u64> { 298 self.peer_control_stream_id 299 } 300 301 pub(crate) fn peer_qpack_decode_stream_id(&self) -> Option<u64> { 302 self.peer_qpack_decode_stream_id 303 } 304 305 pub(crate) fn set_peer_qpack_encode_stream_id( 306 &mut self, 307 id: u64, 308 ) -> Result<(), DispatchErrorKind> { 309 if let Some(old_id) = self.peer_qpack_encode_stream_id { 310 if old_id != id { 311 return Err(DispatchErrorKind::H3(H3Error::Connection( 312 H3ErrorCode::H3StreamCreationError, 313 ))); 314 } 315 } else { 316 self.peer_qpack_encode_stream_id = Some(id); 317 } 318 Ok(()) 319 } 320 321 pub(crate) fn set_peer_control_stream_id(&mut self, id: u64) -> Result<(), DispatchErrorKind> { 322 if let Some(old_id) = self.peer_control_stream_id { 323 if old_id != id { 324 return Err(DispatchErrorKind::H3(H3Error::Connection( 325 H3ErrorCode::H3StreamCreationError, 326 ))); 327 } 328 } else { 329 self.peer_control_stream_id = Some(id); 330 } 331 Ok(()) 332 } 333 334 pub(crate) fn set_peer_qpack_decode_stream_id( 335 &mut self, 336 id: u64, 337 ) -> Result<(), DispatchErrorKind> { 338 if let Some(old_id) = self.peer_qpack_decode_stream_id { 339 if old_id != id { 340 return Err(DispatchErrorKind::H3(H3Error::Connection( 341 H3ErrorCode::H3StreamCreationError, 342 ))); 343 } 344 } else { 345 self.peer_qpack_decode_stream_id = Some(id); 346 } 347 Ok(()) 348 } 349 350 #[allow(unused)] 351 pub(crate) fn set_goaway_id(&mut self, id: u64) -> Result<(), DispatchErrorKind> { 352 if let Some(old_goaway_id) = self.goaway_id { 353 if id > old_goaway_id { 354 return Err(DispatchErrorKind::H3(H3Error::Connection( 355 H3ErrorCode::H3InternalError, 356 ))); 357 } 358 } 359 self.goaway_id = Some(id); 360 Ok(()) 361 } 362 363 pub(crate) fn get_header(&mut self, id: u64) -> Result<Option<Frame>, DispatchErrorKind> { 364 if let Some(stream) = self.bidirectional_stream.get_mut(&id) { 365 Ok(stream.header.take()) 366 } else { 367 Err(DispatchErrorKind::H3(H3Error::Connection( 368 H3ErrorCode::H3InternalError, 369 ))) 370 } 371 } 372 373 pub(crate) fn frame_acceptable(&mut self, id: u64) -> bool { 374 !self.is_stream_recv_pending(id) && !self.is_stream_channel_pending(id) 375 } 376 377 pub(crate) fn decrease_current_concurrency(&mut self) { 378 self.working_stream_num -= 1; 379 } 380 381 pub(crate) fn increase_current_concurrency(&mut self) { 382 self.working_stream_num += 1; 383 } 384 385 pub(crate) fn current_concurrency(&mut self) -> u32 { 386 self.working_stream_num 387 } 388 389 pub(crate) fn reach_max_concurrency(&mut self) -> bool { 390 self.working_stream_num >= self.max_stream_concurrency 391 } 392 393 pub(crate) fn push_back_pending_send(&mut self, id: u64) { 394 self.pending_send.push_back(id); 395 } 396 397 pub(crate) fn next_stream(&mut self) -> Option<u64> { 398 self.pending_send.pop_front() 399 } 400 401 pub(crate) fn pending_stream_len(&mut self) -> u64 { 402 self.pending_send.len() as u64 403 } 404 405 pub(crate) fn push_back_pending_concurrency(&mut self, id: u64) { 406 self.pending_concurrency.push_back(id); 407 } 408 409 pub(crate) fn pop_front_pending_concurrency(&mut self) -> Option<u64> { 410 self.pending_concurrency.pop_front() 411 } 412 413 pub(crate) fn stream_pend_channel(&mut self, id: u64) { 414 self.pending_channel.insert(id); 415 } 416 417 pub(crate) fn is_stream_channel_pending(&self, id: u64) -> bool { 418 self.pending_channel.contains(&id) 419 } 420 421 pub(crate) fn try_consume_pending_concurrency(&mut self) { 422 while !self.reach_max_concurrency() { 423 match self.pop_front_pending_concurrency() { 424 Some(id) => { 425 self.push_back_pending_send(id); 426 self.increase_current_concurrency(); 427 } 428 None => { 429 return; 430 } 431 } 432 } 433 } 434 435 pub(crate) fn get_next_uni_stream_id(&self) -> Option<u64> { 436 let id = self.next_uni_stream_id.fetch_add(4, Ordering::Relaxed); 437 if id > HTTP3_MAX_STREAM_ID { 438 None 439 } else { 440 Some(id) 441 } 442 } 443 444 pub(crate) fn get_next_bidi_stream_id(&self) -> Option<u64> { 445 let id = self.next_bidi_stream_id.fetch_add(4, Ordering::Relaxed); 446 if id > HTTP3_MAX_STREAM_ID { 447 None 448 } else { 449 Some(id) 450 } 451 } 452 453 pub(crate) fn pend_stream_recv(&mut self, id: u64) { 454 self.pending_recv.insert(id); 455 } 456 457 pub(crate) fn resume_stream_recv(&mut self, id: u64) { 458 self.pending_recv.remove(&id); 459 self.resume_recv.push_back(id); 460 } 461 462 pub(crate) fn is_stream_recv_pending(&self, id: u64) -> bool { 463 self.pending_recv.contains(&id) 464 } 465 466 pub(crate) fn get_resume_stream_id(&mut self) -> Option<u64> { 467 self.resume_recv.pop_front() 468 } 469 470 pub(crate) fn poll_sized_data( 471 &mut self, 472 cx: &mut Context<'_>, 473 id: u64, 474 buf: &mut [u8], 475 ) -> Result<DataReadState, DispatchErrorKind> { 476 let stream = self 477 .bidirectional_stream 478 .get_mut(&id) 479 .ok_or(DispatchErrorKind::H3(H3Error::Connection( 480 H3ErrorCode::H3InternalError, 481 )))?; 482 483 if stream.state == H3StreamState::Shutdown { 484 return Ok(DataReadState::Closed); 485 } 486 487 match stream.data.poll_read(cx, buf) { 488 Poll::Ready(Some(size)) => { 489 if size > 0 { 490 let data_vec = Vec::from(&buf[..size]); 491 Ok(DataReadState::Ready(Box::new(Frame::new( 492 DATA_FRAME_TYPE, 493 Payload::Data(Data::new(data_vec)), 494 )))) 495 } else { 496 Ok(DataReadState::Finish) 497 } 498 } 499 Poll::Ready(None) => Err(DispatchErrorKind::H3(H3Error::Connection( 500 H3ErrorCode::H3InternalError, 501 ))), 502 Poll::Pending => { 503 self.push_back_pending_send(id); 504 Ok(DataReadState::Pending) 505 } 506 } 507 } 508 509 pub(crate) fn shutdown_stream(&mut self, cx: &mut Context<'_>, id: u64, err: &H3ErrorCode) { 510 let Some(stream) = self.bidirectional_stream.get_mut(&id) else { 511 return; 512 }; 513 if stream 514 .transmit_message( 515 cx, 516 RespMessage::OutputExit(DispatchErrorKind::H3(H3Error::Stream(id, *err))), 517 ) 518 .is_pending() 519 { 520 self.stream_pend_channel(id); 521 } 522 self.decrease_current_concurrency(); 523 // stream.header = None; 524 // stream.pending_frame.clear(); 525 // stream.data.clear(); 526 // stream.state = H3StreamState::Shutdown; 527 } 528 529 pub(crate) fn goaway( 530 &mut self, 531 cx: &mut Context<'_>, 532 goaway_id: u64, 533 ) -> Result<(), DispatchErrorKind> { 534 if let Some(old_goaway_id) = self.peer_goaway_id() { 535 if goaway_id > old_goaway_id { 536 return Err(DispatchErrorKind::H3(H3Error::Connection( 537 H3ErrorCode::H3IdError, 538 ))); 539 } 540 } 541 if QUICStreamType::from(goaway_id) != QUICStreamType::ClientInitialBidirectional { 542 return Err(DispatchErrorKind::H3(H3Error::Connection( 543 H3ErrorCode::H3IdError, 544 ))); 545 } 546 self.goaway_id = Some(goaway_id); 547 let mut pending_channels = Vec::new(); 548 for (id, stream) in self.bidirectional_stream.iter_mut() { 549 if id > &goaway_id { 550 stream.state = H3StreamState::Shutdown; 551 stream.header = None; 552 stream.pending_message.clear(); 553 stream.data.clear(); 554 if stream 555 .transmit_message( 556 cx, 557 RespMessage::OutputExit(DispatchErrorKind::GoawayReceived), 558 ) 559 .is_pending() 560 { 561 pending_channels.push(*id); 562 } 563 } 564 } 565 for id in pending_channels { 566 self.stream_pend_channel(id); 567 } 568 Ok(()) 569 } 570 571 pub(crate) fn shutdown(&mut self, cx: &mut Context<'_>, err: &DispatchErrorKind) { 572 let mut pending_channels = Vec::new(); 573 for (id, stream) in self.bidirectional_stream.iter_mut() { 574 stream.state = H3StreamState::Shutdown; 575 stream.header = None; 576 stream.pending_message.clear(); 577 stream.data.clear(); 578 if stream 579 .transmit_message(cx, RespMessage::OutputExit(err.clone())) 580 .is_pending() 581 { 582 pending_channels.push(*id); 583 } 584 } 585 for id in pending_channels { 586 self.stream_pend_channel(id); 587 } 588 } 589 590 pub(crate) fn set_encoding( 591 &mut self, 592 id: u64, 593 encoding: bool, 594 ) -> Result<(), DispatchErrorKind> { 595 if let Some(stream) = self.bidirectional_stream.get_mut(&id) { 596 stream.encoding = encoding; 597 Ok(()) 598 } else { 599 Err(DispatchErrorKind::H3(H3Error::Connection( 600 H3ErrorCode::H3InternalError, 601 ))) 602 } 603 } 604 605 pub(crate) fn encoding(&mut self, id: u64) -> Result<bool, DispatchErrorKind> { 606 if let Some(stream) = self.bidirectional_stream.get_mut(&id) { 607 Ok(stream.encoding) 608 } else { 609 Err(DispatchErrorKind::H3(H3Error::Connection( 610 H3ErrorCode::H3InternalError, 611 ))) 612 } 613 } 614 615 pub(crate) fn finish_stream( 616 &mut self, 617 cx: &mut Context<'_>, 618 id: u64, 619 ) -> Result<(), DispatchErrorKind> { 620 if QUICStreamType::from(id) != QUICStreamType::ClientInitialBidirectional { 621 return if Some(id) == self.peer_control_stream_id() 622 || Some(id) == self.peer_qpack_encode_stream_id() 623 || Some(id) == self.peer_qpack_decode_stream_id() 624 { 625 Err(DispatchErrorKind::H3(H3Error::Connection( 626 H3ErrorCode::H3ClosedCriticalStream, 627 ))) 628 } else { 629 Ok(()) 630 }; 631 } 632 self.decrease_current_concurrency(); 633 if let Some(stream) = self.bidirectional_stream.get_mut(&id) { 634 stream.state = H3StreamState::Shutdown; 635 if stream.curr_message.is_none() { 636 if let Poll::Ready(ret) = stream.transmit_message( 637 cx, 638 RespMessage::OutputExit(DispatchErrorKind::StreamFinished), 639 ) { 640 ret 641 } else { 642 self.stream_pend_channel(id); 643 Ok(()) 644 } 645 } else { 646 stream 647 .pending_message 648 .push_back(RespMessage::OutputExit(DispatchErrorKind::StreamFinished)); 649 Ok(()) 650 } 651 } else { 652 Err(DispatchErrorKind::H3(H3Error::Connection( 653 H3ErrorCode::H3InternalError, 654 ))) 655 } 656 } 657 658 pub(crate) fn poll_blocked_message( 659 &mut self, 660 cx: &mut Context<'_>, 661 ) -> Poll<Result<(), DispatchErrorKind>> { 662 let mut new_set = HashSet::new(); 663 for id in &self.pending_channel { 664 let Some(stream) = self.bidirectional_stream.get_mut(id) else { 665 return Poll::Ready(Err(DispatchErrorKind::H3(H3Error::Connection( 666 H3ErrorCode::H3InternalError, 667 )))); 668 }; 669 if let Some(mut task) = stream.curr_message.take() { 670 match task.as_mut().poll(cx) { 671 Poll::Ready(Ok(_)) => {} 672 Poll::Ready(Err(_)) => { 673 // todo: shutdown 674 stream.state = H3StreamState::Shutdown; 675 } 676 Poll::Pending => { 677 stream.curr_message = Some(task); 678 new_set.insert(*id); 679 continue; 680 } 681 } 682 } 683 while let Some(message) = stream.pending_message.pop_front() { 684 match stream.transmit_message(cx, message) { 685 Poll::Ready(Ok(())) => {} 686 Poll::Pending => { 687 new_set.insert(*id); 688 break; 689 } 690 Poll::Ready(Err(_)) => { 691 stream.state = H3StreamState::Shutdown; 692 break; 693 } 694 } 695 } 696 } 697 self.pending_channel = new_set; 698 Poll::Pending 699 } 700 } 701 702 pub(crate) enum DataReadState { 703 Closed, 704 // Wait for poll_read or wait for window. 705 Pending, 706 Ready(Box<Frame>), 707 Finish, 708 } 709