1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::io;
15 use std::io::{IoSlice, SeekFrom};
16 use std::pin::Pin;
17 use std::task::{Context, Poll};
18 
19 use crate::io::buffered::DEFAULT_BUF_SIZE;
20 use crate::io::{poll_ready, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
21 
/// An asynchronous counterpart of [`std::io::BufWriter`].
///
/// `AsyncBufWriter<W>` adds buffering to any writer that implements
/// [`AsyncWrite`]. Writes are collected in an in-memory buffer and passed on
/// to the underlying [`AsyncWrite`] object in larger, less frequent batches.
///
/// When the `AsyncBufWriter<W>` is dropped, the contents inside its buffer will
/// be discarded, so flush it before dropping if the data matters. Creating
/// multiple instances of `AsyncBufWriter<W>` on the same [`AsyncWrite`] stream
/// may cause data loss.
pub struct AsyncBufWriter<W> {
    // The wrapped writer that ultimately receives the data.
    inner: W,
    // Bytes accepted by this writer that have not yet been removed from the
    // buffer by a completed flush.
    buf: Vec<u8>,
    // Number of bytes at the front of `buf` that an in-progress flush has
    // already written to `inner`; reset to 0 once they are drained from `buf`.
    written: usize,
}
37 
38 impl<W: AsyncWrite> AsyncBufWriter<W> {
39     /// Creates a new `AsyncBufWriter<W>` with a default buffer capacity.
40     /// The default buffer capacity is 8 KB, which is the same as
41     /// [`std::io::BufWriter`]
42     ///
43     /// # Examples
44     ///
45     /// ```no run
46     /// use ylong_runtime::fs::File;
47     ///
48     /// async fn main() -> std::io::Result<()> {
49     ///     use ylong_runtime::io::AsyncBufWriter;
50     ///     let f = File::open("test.txt").await?;
51     ///     let reader = AsyncBufWriter::new(f);
52     ///     Ok(())
53     /// }
54     /// ```
newnull55     pub fn new(inner: W) -> AsyncBufWriter<W> {
56         AsyncBufWriter::with_capacity(DEFAULT_BUF_SIZE, inner)
57     }
58 
59     /// Creates a new `AsyncBufWriter<W>` with a specific buffer capacity.
60     ///
61     /// # Examples
62     ///
63     /// ```no run
64     /// use ylong_runtime::fs::File;
65     ///
66     /// async fn main() -> std::io::Result<()> {
67     ///     use ylong_runtime::io::AsyncBufWriter;
68     ///     let f = File::open("test.txt").await?;
69     ///     let reader = AsyncBufWriter::with_capacity(1000, f);
70     ///     Ok(())
71     /// }
with_capacitynull72     pub fn with_capacity(cap: usize, inner: W) -> AsyncBufWriter<W> {
73         AsyncBufWriter {
74             inner,
75             buf: Vec::with_capacity(cap),
76             written: 0,
77         }
78     }
79 
80     /// Gets a reference to the inner writer.
81     ///
82     /// # Examples
83     ///
84     /// ```no run
85     /// use ylong_runtime::fs::File;
86     ///
87     /// async fn main() -> std::io::Result<()> {
88     ///     use ylong_runtime::io::AsyncBufWriter;
89     ///     let f = File::open("test.txt").await?;
90     ///     let writer = AsyncBufWriter::new(f);
91     ///     let writer_ref = writer.get_ref();
92     ///     Ok(())
93     /// }
94     /// ```
get_refnull95     pub fn get_ref(&self) -> &W {
96         &self.inner
97     }
98 
99     /// Gets the mutable reference to the inner writer.
100     ///
101     /// # Examples
102     ///
103     /// ```no run
104     /// use ylong_runtime::fs::File;
105     ///
106     /// async fn main() -> std::io::Result<()> {
107     ///     use ylong_runtime::io::AsyncBufWriter;
108     ///     let f = File::open("test.txt").await?;
109     ///     let mut writer = AsyncBufWriter::new(f);
110     ///     let writer_ref = writer.get_mut();
111     ///     Ok(())
112     /// }
113     /// ```
get_mutnull114     pub fn get_mut(&mut self) -> &mut W {
115         &mut self.inner
116     }
117 
118     /// Unwraps this `AsyncBufWriter<R>`, returning the underlying writer.
119     ///
120     /// Any leftover data inside the internal buffer of the `AsyncBufWriter` is
121     /// lost.
into_innernull122     pub fn into_inner(self) -> W {
123         self.inner
124     }
125 
126     /// Returns a reference to the internally buffered data.
127     ///
128     /// Only returns the filled part of the buffer instead of the whole buffer.
129     ///
130     /// # Examples
131     ///
132     /// ```no run
133     /// use ylong_runtime::fs::File;
134     ///
135     /// async fn main() -> std::io::Result<()> {
136     ///     use ylong_runtime::io::AsyncBufWriter;
137     ///     let f = File::open("test.txt").await?;
138     ///     let writer = AsyncBufWriter::new(f);
139     ///     let writer_buf = writer.buffer();
140     ///     assert!(writer_buf.is_empty());
141     ///     Ok(())
142     /// }
143     /// ```
buffernull144     pub fn buffer(&self) -> &[u8] {
145         &self.buf
146     }
147 
148     /// Returns the number of bytes the internal buffer can hold without
149     /// flushing.
150     ///
151     /// # Examples
152     ///
153     /// ```no run
154     /// use std::net::SocketAddr;
155     ///
156     /// use ylong_runtime::io::AsyncBufWriter;
157     /// use ylong_runtime::net::TcpStream;
158     ///
159     /// async fn async_io() -> std::io::Result<()> {
160     ///     let addr: SocketAddr = "127.0.0.1:8081".parse().unwrap();
161     ///     let buf_writer = AsyncBufWriter::new(TcpStream::connect(addr).await.unwrap());
162     ///     // Checks the capacity of the inner buffer
163     ///     let capacity = buf_writer.capacity();
164     ///     // Calculates how many bytes can be written without flushing
165     ///     let without_flush = capacity - buf_writer.buffer().len();
166     ///     Ok(())
167     /// }
168     /// ```
capacitynull169     pub fn capacity(&self) -> usize {
170         self.buf.capacity()
171     }
172 
flushnull173     fn flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
174         let this = unsafe { self.get_unchecked_mut() };
175         let len = this.buf.len();
176         let mut res = Ok(());
177         while this.written < len {
178             unsafe {
179                 match poll_ready!(
180                     Pin::new_unchecked(&mut this.inner).poll_write(cx, &this.buf[this.written..])
181                 ) {
182                     Ok(0) => {
183                         res = Err(io::Error::new(
184                             io::ErrorKind::WriteZero,
185                             "unwritten data remains in buf",
186                         ));
187                         break;
188                     }
189                     Ok(n) => this.written += n,
190                     Err(e) => {
191                         res = Err(e);
192                         break;
193                     }
194                 }
195             }
196         }
197         if this.written > 0 {
198             this.buf.drain(..this.written);
199             this.written = 0;
200         }
201         Poll::Ready(res)
202     }
203 }
204 
205 impl<W: AsyncWrite> AsyncWrite for AsyncBufWriter<W> {
poll_writenull206     fn poll_write(
207         mut self: Pin<&mut Self>,
208         cx: &mut Context<'_>,
209         buf: &[u8],
210     ) -> Poll<io::Result<usize>> {
211         if self.buf.len() + buf.len() > self.buf.capacity() {
212             poll_ready!(self.as_mut().flush(cx))?;
213         }
214 
215         let this = unsafe { self.get_unchecked_mut() };
216         if buf.len() >= this.buf.capacity() {
217             unsafe { Pin::new_unchecked(&mut this.inner).poll_write(cx, buf) }
218         } else {
219             this.buf.extend_from_slice(buf);
220             Poll::Ready(Ok(buf.len()))
221         }
222     }
223 
poll_write_vectorednull224     fn poll_write_vectored(
225         mut self: Pin<&mut Self>,
226         cx: &mut Context<'_>,
227         mut bufs: &[IoSlice<'_>],
228     ) -> Poll<io::Result<usize>> {
229         if self.inner.is_write_vectored() {
230             let mut len: usize = 0;
231             for buf in bufs {
232                 len = len.saturating_add(buf.len());
233             }
234             if len + self.buf.len() > self.buf.capacity() {
235                 poll_ready!(self.as_mut().flush(cx))?;
236             }
237 
238             let this = unsafe { self.get_unchecked_mut() };
239             if len >= this.buf.capacity() {
240                 unsafe { Pin::new_unchecked(&mut this.inner).poll_write_vectored(cx, bufs) }
241             } else {
242                 for buf in bufs {
243                     this.buf.extend_from_slice(buf);
244                 }
245                 Poll::Ready(Ok(len))
246             }
247         } else {
248             if bufs.is_empty() {
249                 return Poll::Ready(Ok(0));
250             }
251             while bufs[0].len() == 0 {
252                 bufs = &bufs[1..];
253             }
254             let mut len = bufs[0].len();
255             if len + self.buf.len() > self.buf.capacity() {
256                 poll_ready!(self.as_mut().flush(cx))?;
257             }
258 
259             let this = unsafe { self.get_unchecked_mut() };
260             if len >= this.buf.capacity() {
261                 return unsafe { Pin::new_unchecked(&mut this.inner).poll_write(cx, &bufs[0]) };
262             } else {
263                 this.buf.extend_from_slice(&bufs[0]);
264                 bufs = &bufs[1..];
265             }
266             for buf in bufs {
267                 if buf.len() + this.buf.len() >= this.buf.capacity() {
268                     break;
269                 } else {
270                     this.buf.extend_from_slice(buf);
271                     len += buf.len()
272                 }
273             }
274             Poll::Ready(Ok(len))
275         }
276     }
277 
is_write_vectorednull278     fn is_write_vectored(&self) -> bool {
279         true
280     }
281 
poll_flushnull282     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
283         poll_ready!(self.as_mut().flush(cx))?;
284         let this = unsafe { self.get_unchecked_mut() };
285         unsafe { Pin::new_unchecked(&mut this.inner).poll_flush(cx) }
286     }
287 
poll_shutdownnull288     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
289         poll_ready!(self.as_mut().flush(cx))?;
290         let this = unsafe { self.get_unchecked_mut() };
291         unsafe { Pin::new_unchecked(&mut this.inner).poll_shutdown(cx) }
292     }
293 }
294 
295 impl<R: AsyncWrite + AsyncSeek> AsyncSeek for AsyncBufWriter<R> {
poll_seeknull296     fn poll_seek(
297         mut self: Pin<&mut Self>,
298         cx: &mut Context<'_>,
299         pos: SeekFrom,
300     ) -> Poll<io::Result<u64>> {
301         poll_ready!(self.as_mut().flush(cx))?;
302         let this = unsafe { self.get_unchecked_mut() };
303         unsafe { Pin::new_unchecked(&mut this.inner).poll_seek(cx, pos) }
304     }
305 }
306 
307 impl<W: AsyncWrite + AsyncRead> AsyncRead for AsyncBufWriter<W> {
poll_readnull308     fn poll_read(
309         self: Pin<&mut Self>,
310         cx: &mut Context<'_>,
311         buf: &mut ReadBuf<'_>,
312     ) -> Poll<io::Result<()>> {
313         let this = unsafe { self.get_unchecked_mut() };
314         unsafe { Pin::new_unchecked(&mut this.inner).poll_read(cx, buf) }
315     }
316 }
317 
318 impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for AsyncBufWriter<W> {
poll_fill_bufnull319     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
320         let this = unsafe { self.get_unchecked_mut() };
321         unsafe { Pin::new_unchecked(&mut this.inner).poll_fill_buf(cx) }
322     }
323 
consumenull324     fn consume(self: Pin<&mut Self>, amt: usize) {
325         let this = unsafe { self.get_unchecked_mut() };
326         unsafe { Pin::new_unchecked(&mut this.inner).consume(amt) }
327     }
328 }
329