1 // Copyright (c) 2023 Huawei Device Co., Ltd. 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 //! implement `split` fn for io, split it into `Reader` half and `Writer` half. 15 16 use std::io; 17 use std::io::IoSlice; 18 use std::pin::Pin; 19 use std::sync::Arc; 20 use std::task::{Context, Poll}; 21 22 use ylong_runtime::io::{AsyncRead, AsyncWrite, ReadBuf}; 23 use ylong_runtime::sync::{Mutex, MutexGuard}; 24 25 macro_rules! ready { 26 ($e:expr $(,)?) => { 27 match $e { 28 std::task::Poll::Ready(t) => t, 29 std::task::Poll::Pending => return std::task::Poll::Pending, 30 } 31 }; 32 } 33 34 pub(crate) struct Reader<T> { 35 inner: Arc<InnerLock<T>>, 36 } 37 38 pub(crate) struct Writer<T> { 39 inner: Arc<InnerLock<T>>, 40 } 41 42 struct InnerLock<T> { 43 stream: Mutex<T>, 44 is_write_vectored: bool, 45 } 46 47 struct StreamGuard<'a, T> { 48 inner: MutexGuard<'a, T>, 49 } 50 51 pub(crate) fn split<T>(stream: T) -> (Reader<T>, Writer<T>) 52 where 53 T: AsyncRead + AsyncWrite, 54 { 55 let is_write_vectored = stream.is_write_vectored(); 56 let inner = Arc::new(InnerLock { 57 stream: Mutex::new(stream), 58 is_write_vectored, 59 }); 60 61 let rd = Reader { 62 inner: inner.clone(), 63 }; 64 65 let wr = Writer { inner }; 66 67 (rd, wr) 68 } 69 70 impl<T: AsyncRead> AsyncRead for Reader<T> { poll_readnull71 fn poll_read( 72 self: Pin<&mut Self>, 73 cx: &mut Context<'_>, 74 buf: &mut ReadBuf<'_>, 75 ) -> Poll<io::Result<()>> { 76 let mut guard = ready!(self.inner.get_lock(cx)); 77 guard.stream().poll_read(cx, buf) 78 } 79 } 80 81 impl<T: AsyncWrite> AsyncWrite for Writer<T> { poll_writenull82 fn poll_write( 83 self: Pin<&mut Self>, 84 cx: &mut Context<'_>, 85 buf: &[u8], 86 ) -> Poll<Result<usize, io::Error>> { 87 let mut inner = ready!(self.inner.get_lock(cx)); 88 inner.stream().poll_write(cx, buf) 89 } 90 poll_write_vectorednull91 fn poll_write_vectored( 92 self: Pin<&mut Self>, 93 cx: &mut Context<'_>, 94 bufs: &[IoSlice<'_>], 95 ) -> Poll<std::io::Result<usize>> { 96 let mut inner = ready!(self.inner.get_lock(cx)); 97 inner.stream().poll_write_vectored(cx, bufs) 98 } 99 is_write_vectorednull100 fn is_write_vectored(&self) -> bool { 101 self.inner.is_write_vectored 102 } 103 poll_flushnull104 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 105 let mut inner = ready!(self.inner.get_lock(cx)); 106 inner.stream().poll_flush(cx) 107 } 108 poll_shutdownnull109 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { 110 let mut inner = ready!(self.inner.get_lock(cx)); 111 inner.stream().poll_shutdown(cx) 112 } 113 } 114 115 impl<'a, T> StreamGuard<'a, T> { streamnull116 fn stream(&mut self) -> Pin<&mut T> { 117 // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual 118 // exclusion. 119 unsafe { Pin::new_unchecked(&mut *self.inner) } 120 } 121 } 122 123 impl<T> InnerLock<T> { get_locknull124 fn get_lock(&self, cx: &mut Context<'_>) -> Poll<StreamGuard<T>> { 125 match self.stream.try_lock() { 126 Ok(guard) => Poll::Ready(StreamGuard { inner: guard }), 127 Err(_) => { 128 std::thread::yield_now(); 129 cx.waker().wake_by_ref(); 130 131 Poll::Pending 132 } 133 } 134 } 135 } 136