1// Copyright (c) 2023 Huawei Device Co., Ltd.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
//! One-shot channel is used to send a single message from a single sender to a
//! single receiver. The [`channel`] function returns a [`Sender`] and
//! [`Receiver`] handle pair that controls the channel.
//!
//! The `Sender` handle is used by the producer to send a message.
//! The `Receiver` handle is used by the consumer to receive the message. It
//! implements the `Future` trait, so the message is received by `.await`ing it.
//!
//! The `send` method is not async, so it can be called from a non-async context.
23//!
24//! # Examples
25//!
26//! ```
27//! use ylong_runtime::sync::oneshot;
28//! async fn io_func() {
29//!     let (tx, rx) = oneshot::channel();
30//!     ylong_runtime::spawn(async move {
31//!         if let Err(_) = tx.send(6) {
32//!             println!("Receiver dropped");
33//!         }
34//!     });
35//!
36//!     match rx.await {
37//!         Ok(v) => println!("received : {:?}", v),
38//!         Err(_) => println!("Sender dropped"),
39//!     }
40//! }
41//! ```
42use std::cell::RefCell;
43use std::fmt::{Debug, Formatter};
44use std::future::Future;
45use std::pin::Pin;
46use std::sync::atomic::AtomicUsize;
47use std::sync::atomic::Ordering::{AcqRel, Acquire, Release, SeqCst};
48use std::sync::Arc;
49use std::task::Poll::{Pending, Ready};
50use std::task::{Context, Poll};
51
52use super::atomic_waker::AtomicWaker;
53use super::error::{RecvError, TryRecvError};
54
/// Initial state: nothing has been sent yet and the channel is still open.
const INIT: usize = 0;
/// The sender side is finished: either `Sender::send` stored a value, or the
/// sender was dropped without sending (in which case the value slot is empty).
const SENT: usize = 1;
/// The channel is closed: the receiver closed or dropped its handle, or it has
/// already taken the value.
const CLOSED: usize = 2;
61
62/// Creates a new one-shot channel with a `Sender` and `Receiver` handle pair.
63///
64/// The `Sender` can send a single value to the `Receiver`.
65///
66/// # Examples
67///
68/// ```
69/// use ylong_runtime::sync::oneshot;
70/// async fn io_func() {
71///     let (tx, rx) = oneshot::channel();
72///     ylong_runtime::spawn(async move {
73///         if let Err(_) = tx.send(6) {
74///             println!("Receiver dropped");
75///         }
76///     });
77///
78///     match rx.await {
79///         Ok(v) => println!("received : {:?}", v),
80///         Err(_) => println!("Sender dropped"),
81///     }
82/// }
83/// ```
84pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
85    let channel = Arc::new(Channel::new());
86    let tx = Sender {
87        channel: channel.clone(),
88    };
89    let rx = Receiver { channel };
90    (tx, rx)
91}
92
93/// Sends a single value to the associated [`Receiver`].
94/// A [`Sender`] and [`Receiver`] handle pair is created by the [`channel`]
95/// function.
96///
97/// # Examples
98///
99/// ```
100/// use ylong_runtime::sync::oneshot;
101/// async fn io_func() {
102///     let (tx, rx) = oneshot::channel();
103///     ylong_runtime::spawn(async move {
104///         if let Err(_) = tx.send(6) {
105///             println!("Receiver dropped");
106///         }
107///     });
108///
109///     match rx.await {
110///         Ok(v) => println!("received : {:?}", v),
111///         Err(_) => println!("Sender dropped"),
112///     }
113/// }
114/// ```
115///
116/// The receiver will fail with a [`RecvError`] if the sender is dropped without
117/// sending a value.
118///
119/// # Examples
120///
121/// ```
122/// use ylong_runtime::sync::oneshot;
123/// async fn io_func() {
124///     let (tx, rx) = oneshot::channel::<()>();
125///     ylong_runtime::spawn(async move {
126///         drop(tx);
127///     });
128///
129///     match rx.await {
130///         Ok(v) => panic!("This won't happen"),
131///         Err(_) => println!("Sender dropped"),
132///     }
133/// }
134/// ```
135#[derive(Debug)]
136pub struct Sender<T> {
137    channel: Arc<Channel<T>>,
138}
139
140impl<T> Sender<T> {
141    /// Sends a single value to the associated [`Receiver`], returns the value
142    /// back if it fails to send.
143    ///
144    /// The sender will consume itself when calling this method. It can send a
145    /// single value in synchronous code as it doesn't need waiting.
146    ///
147    /// # Examples
148    ///
149    /// ```
150    /// use ylong_runtime::sync::oneshot;
151    /// async fn io_func() {
152    ///     let (tx, rx) = oneshot::channel();
153    ///     ylong_runtime::spawn(async move {
154    ///         if let Err(_) = tx.send(6) {
155    ///             println!("Receiver dropped");
156    ///         }
157    ///     });
158    ///
159    ///     match rx.await {
160    ///         Ok(v) => println!("received : {:?}", v),
161    ///         Err(_) => println!("Sender dropped"),
162    ///     }
163    /// }
164    /// ```
165    pub fn send(self, value: T) -> Result<(), T> {
166        self.channel.value.borrow_mut().replace(value);
167
168        loop {
169            match self.channel.state.load(Acquire) {
170                INIT => {
171                    if self
172                        .channel
173                        .state
174                        .compare_exchange(INIT, SENT, AcqRel, Acquire)
175                        .is_ok()
176                    {
177                        self.channel.waker.wake();
178                        return Ok(());
179                    }
180                }
181                CLOSED => {
182                    // value is stored in this function before.
183                    return Err(self.channel.take_value().unwrap());
184                }
185                _ => unreachable!(),
186            }
187        }
188    }
189
190    /// Checks whether channel is closed. if so, the sender could not
191    /// send any value anymore. It returns true if the [`Receiver`] is dropped
192    /// or calls the [`close`] method.
193    ///
194    /// [`close`]: Receiver::close
195    ///
196    /// # Examples
197    ///
198    /// ```
199    /// use ylong_runtime::sync::oneshot;
200    /// async fn io_func() {
201    ///     let (tx, rx) = oneshot::channel();
202    ///     assert!(!tx.is_closed());
203    ///
204    ///     drop(rx);
205    ///
206    ///     assert!(tx.is_closed());
207    ///     assert!(tx.send("no receive").is_err());
208    /// }
209    /// ```
210    pub fn is_closed(&self) -> bool {
211        self.channel.state.load(Acquire) == CLOSED
212    }
213}
214
215impl<T> Drop for Sender<T> {
216    fn drop(&mut self) {
217        if self.channel.state.swap(SENT, SeqCst) == INIT {
218            self.channel.waker.wake();
219        }
220    }
221}
222
223/// Receives a single value from the associated [`Sender`].
224/// A [`Sender`] and [`Receiver`] handle pair is created by the [`channel`]
225/// function.
226///
227/// There is no `recv` method to receive the message because the receiver itself
228/// implements the [`Future`] trait. To receive a value, `.await` the `Receiver`
229/// object directly.
230///
231/// # Examples
232///
233/// ```
234/// use ylong_runtime::sync::oneshot;
235/// async fn io_func() {
236///     let (tx, rx) = oneshot::channel();
237///     ylong_runtime::spawn(async move {
238///         if let Err(_) = tx.send(6) {
239///             println!("Receiver dropped");
240///         }
241///     });
242///
243///     match rx.await {
244///         Ok(v) => println!("received : {:?}", v),
245///         Err(_) => println!("Sender dropped"),
246///     }
247/// }
248/// ```
249///
250/// The receiver will fail with [`RecvError`], if the sender is dropped without
251/// sending a value.
252///
253/// # Examples
254///
255/// ```
256/// use ylong_runtime::sync::oneshot;
257/// async fn io_func() {
258///     let (tx, rx) = oneshot::channel::<u32>();
259///     ylong_runtime::spawn(async move {
260///         drop(tx);
261///     });
262///
263///     match rx.await {
264///         Ok(v) => panic!("This won't happen"),
265///         Err(_) => println!("Sender dropped"),
266///     }
267/// }
268/// ```
269#[derive(Debug)]
270pub struct Receiver<T> {
271    channel: Arc<Channel<T>>,
272}
273
274impl<T> Receiver<T> {
275    /// Attempts to receive a value from the associated [`Sender`].
276    ///
277    /// The method will still receive the result if the `Sender` gets dropped
278    /// after sending the message.
279    ///
280    /// # Return value
281    /// The function returns:
282    ///  * `Ok(T)` if receiving a value successfully.
283    ///  * `Err(TryRecvError::Empty)` if no value has been sent yet.
284    ///  * `Err(TryRecvError::Closed)` if the sender has dropped without sending
285    ///   a value, or if the message has already been received.
286    ///
287    /// # Examples
288    ///
289    /// `try_recv` before a value is sent, then after.
290    ///
291    /// ```
292    /// use ylong_runtime::sync::error::TryRecvError;
293    /// use ylong_runtime::sync::oneshot;
294    /// async fn io_func() {
295    ///     let (tx, mut rx) = oneshot::channel();
296    ///     match rx.try_recv() {
297    ///         Err(TryRecvError::Empty) => {}
298    ///         _ => panic!("This won't happen"),
299    ///     }
300    ///
301    ///     // Send a value
302    ///     tx.send("Hello").unwrap();
303    ///
304    ///     match rx.try_recv() {
305    ///         Ok(value) => assert_eq!(value, "Hello"),
306    ///         _ => panic!("This won't happen"),
307    ///     }
308    /// }
309    /// ```
310    ///
311    /// `try_recv` when the sender dropped before sending a value
312    ///
313    /// ```
314    /// use ylong_runtime::sync::error::TryRecvError;
315    /// use ylong_runtime::sync::oneshot;
316    /// async fn io_func() {
317    ///     let (tx, mut rx) = oneshot::channel::<()>();
318    ///     drop(tx);
319    ///
320    ///     match rx.try_recv() {
321    ///         Err(TryRecvError::Closed) => {}
322    ///         _ => panic!("This won't happen"),
323    ///     }
324    /// }
325    /// ```
326    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
327        match self.channel.state.load(Acquire) {
328            INIT => Err(TryRecvError::Empty),
329            SENT => self
330                .channel
331                .take_value_sent()
332                .map_err(|_| TryRecvError::Closed),
333            CLOSED => Err(TryRecvError::Closed),
334            _ => unreachable!(),
335        }
336    }
337
338    /// Closes the channel, prevents the `Sender` from sending a value.
339    ///
340    /// The `Sender` will fail to call [`send`] after the `Receiver` called
341    /// `close`. It will do nothing if the channel is already closed or the
342    /// message has been already received.
343    ///
344    /// [`send`]: Sender::send
345    /// [`try_recv`]: Receiver::try_recv
346    ///
347    /// # Examples
348    /// ```
349    /// use ylong_runtime::sync::oneshot;
350    /// async fn io_func() {
351    ///     let (tx, mut rx) = oneshot::channel();
352    ///     assert!(!tx.is_closed());
353    ///
354    ///     rx.close();
355    ///
356    ///     assert!(tx.is_closed());
357    ///     assert!(tx.send("no receive").is_err());
358    /// }
359    /// ```
360    ///
361    /// Receive a value sent **before** calling `close`
362    ///
363    /// ```
364    /// use ylong_runtime::sync::oneshot;
365    /// async fn io_func() {
366    ///     let (tx, mut rx) = oneshot::channel();
367    ///     assert!(tx.send("Hello").is_ok());
368    ///
369    ///     rx.close();
370    ///
371    ///     let msg = rx.try_recv().unwrap();
372    ///     assert_eq!(msg, "Hello");
373    /// }
374    /// ```
375    pub fn close(&mut self) {
376        let _ = self
377            .channel
378            .state
379            .compare_exchange(INIT, CLOSED, AcqRel, Acquire);
380    }
381}
382
383impl<T> Drop for Receiver<T> {
384    fn drop(&mut self) {
385        self.close();
386    }
387}
388
389impl<T> Future for Receiver<T> {
390    type Output = Result<T, RecvError>;
391
392    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
393        match self.channel.state.load(Acquire) {
394            INIT => {
395                self.channel.waker.register_by_ref(cx.waker());
396                if self.channel.state.load(Acquire) == SENT {
397                    Ready(self.channel.take_value_sent())
398                } else {
399                    Pending
400                }
401            }
402            SENT => Ready(self.channel.take_value_sent()),
403            CLOSED => Ready(Err(RecvError)),
404            _ => unreachable!(),
405        }
406    }
407}
408
409struct Channel<T> {
410    /// The state of the channel.
411    state: AtomicUsize,
412
413    /// The value passed by channel, it is set by `Sender` and read by
414    /// `Receiver`.
415    value: RefCell<Option<T>>,
416
417    /// The waker to notify the sender task or the receiver task.
418    waker: AtomicWaker,
419}
420
421impl<T> Channel<T> {
422    fn new() -> Channel<T> {
423        Channel {
424            state: AtomicUsize::new(INIT),
425            value: RefCell::new(None),
426            waker: AtomicWaker::new(),
427        }
428    }
429
430    fn take_value_sent(&self) -> Result<T, RecvError> {
431        match self.take_value() {
432            Some(val) => {
433                self.state.store(CLOSED, Release);
434                Ok(val)
435            }
436            None => Err(RecvError),
437        }
438    }
439
440    fn take_value(&self) -> Option<T> {
441        self.value.borrow_mut().take()
442    }
443}
444
445unsafe impl<T: Send> Send for Channel<T> {}
446unsafe impl<T: Send> Sync for Channel<T> {}
447
448impl<T> Drop for Channel<T> {
449    fn drop(&mut self) {
450        self.waker.take_waker();
451    }
452}
453
454impl<T: Debug> Debug for Channel<T> {
455    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
456        f.debug_struct("Channel")
457            .field("state", &self.state.load(Acquire))
458            .finish()
459    }
460}
461
#[cfg(test)]
mod tests {
    use crate::sync::error::TryRecvError;
    use crate::sync::oneshot;
    use crate::{block_on, spawn};

    /// UT test cases for `send()` and `try_recv()`.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Receiver tries receiving a message before the sender sends one.
    /// 3. Receiver tries receiving a message after the sender sends one.
    /// 4. Check if the test results are correct.
    #[test]
    fn send_try_recv() {
        let (tx, mut rx) = oneshot::channel();
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        tx.send("hello").unwrap();

        assert_eq!(rx.try_recv().unwrap(), "hello");
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    /// UT test cases for `send()` and async receive.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Sender sends a message from the test thread.
    /// 3. Receiver awaits the message inside a spawned task.
    /// 4. Block on the task's handle so its assertions actually run (and any
    ///    failure propagates) before the test finishes.
    #[test]
    fn send_recv_await() {
        let (tx, rx) = oneshot::channel();
        if tx.send(6).is_err() {
            panic!("Receiver dropped");
        }
        let handle = spawn(async move {
            match rx.await {
                Ok(v) => assert_eq!(v, 6),
                Err(_) => panic!("Sender dropped"),
            }
        });
        block_on(handle).unwrap();
    }

    /// UT test cases for dropping the sender before it sends.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Drop the sender without sending.
    /// 3. Check that the receiver reports the channel as closed.
    #[test]
    fn drop_tx_before_send() {
        let (tx, mut rx) = oneshot::channel::<u32>();
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    /// UT test cases for dropping the receiver before the sender sends.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Drop the receiver.
    /// 3. Check that the sender sees the channel closed and gets its value back.
    #[test]
    fn drop_rx_before_send() {
        let (tx, rx) = oneshot::channel();
        drop(rx);
        assert!(tx.is_closed());
        assert_eq!(tx.send(1), Err(1));
    }

    /// UT test cases for `is_closed()` and `close`.
    ///
    /// # Brief
    /// 1. Call channel to create a sender and a receiver handle pair.
    /// 2. Check whether the sender is closed.
    /// 3. Close the receiver.
    /// 4. Check whether the receiver will receive the message sent before it
    ///    closed.
    /// 5. Check if the test results are correct.
    #[test]
    fn close_rx() {
        let (tx, mut rx) = oneshot::channel();
        assert!(!tx.is_closed());
        rx.close();

        assert!(tx.is_closed());
        assert!(tx.send("never received").is_err());

        let (tx, mut rx) = oneshot::channel();
        assert!(tx.send("will receive").is_ok());

        rx.close();

        let msg = rx.try_recv().unwrap();
        assert_eq!(msg, "will receive");
    }
}
533