// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
//   * no poisoning
//   * init function can fail

use std::{
    cell::{Cell, UnsafeCell},
    marker::PhantomData,
    panic::{RefUnwindSafe, UnwindSafe},
    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
    thread::{self, Thread},
};

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
    // allow waiters.
    queue: AtomicPtr<Waiter>,
    _marker: PhantomData<*mut Waiter>,
    value: UnsafeCell<Option<T>>,
}

// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with
// scoped thread B, which fills the cell. The cell is
// then destroyed by A, so A's destructor observes a
// value that was sent from B.
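//
// Why do we need `T: Sync`?
// A shared `&OnceCell<T>` lets multiple threads obtain `&T` (for example via
// `get_unchecked`), so the value itself must be safe to share between threads.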
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}

impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

impl<T> OnceCell<T> {
    pub(crate) const fn new() -> OnceCell<T> {
        OnceCell {
            queue: AtomicPtr::new(INCOMPLETE_PTR),
            _marker: PhantomData,
            value: UnsafeCell::new(None),
        }
    }

    pub(crate) const fn with_value(value: T) -> OnceCell<T> {
        OnceCell {
            queue: AtomicPtr::new(COMPLETE_PTR),
            _marker: PhantomData,
            value: UnsafeCell::new(Some(value)),
        }
    }

    /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
    #[inline]
    pub(crate) fn is_initialized(&self) -> bool {
        // An `Acquire` load is enough because that makes all the initialization
        // operations visible to us, and, this being a fast path, weaker
        // ordering helps with performance. This `Acquire` synchronizes with
        // `SeqCst` operations on the slow path.
        self.queue.load(Ordering::Acquire) == COMPLETE_PTR
    }

    /// Safety: synchronizes with store to value via SeqCst read from state,
    /// writes value only once because we never get to INCOMPLETE state after a
    /// successful write.
    #[cold]
    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let mut f = Some(f);
        let mut res: Result<(), E> = Ok(());
        let slot: *mut Option<T> = self.value.get();
        initialize_or_wait(
            &self.queue,
            Some(&mut || {
                let f = unsafe { crate::unwrap_unchecked(f.take()) };
                match f() {
                    Ok(value) => {
                        unsafe { *slot = Some(value) };
                        true
                    }
                    Err(err) => {
                        res = Err(err);
                        false
                    }
                }
            }),
        );
        res
    }
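
    // Illustrative sketch (not code from this module): a public wrapper would
    // typically combine `is_initialized`, `initialize`, and `get_unchecked`
    // roughly like this; the name `get_or_try_init` is only for illustration.
    //
    //     pub fn get_or_try_init<F, E>(&self, f: F) -> Result<&T, E>
    //     where
    //         F: FnOnce() -> Result<T, E>,
    //     {
    //         if !self.is_initialized() {
    //             self.initialize(f)?;
    //         }
    //         // Safety: either the `Acquire` load in `is_initialized` saw
    //         // `COMPLETE`, or `initialize` just returned `Ok`; both
    //         // synchronize the written value to this thread.
    //         Ok(unsafe { self.get_unchecked() })
    //     }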

    #[cold]
    pub(crate) fn wait(&self) {
        initialize_or_wait(&self.queue, None);
    }

    /// Gets a reference to the underlying value, without checking if the cell
    /// is initialized.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the cell is in the initialized state, and that
    /// the contents have been acquired by (synchronized to) this thread.
    pub(crate) unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(self.is_initialized());
        let slot = &*self.value.get();
        crate::unwrap_unchecked(slot.as_ref())
    }

    /// Gets the mutable reference to the underlying value.
    /// Returns `None` if the cell is empty.
    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
        // Safe because we have unique access.
        unsafe { &mut *self.value.get() }.as_mut()
    }

    /// Consumes this `OnceCell`, returning the wrapped value.
    /// Returns `None` if the cell was empty.
    #[inline]
    pub(crate) fn into_inner(self) -> Option<T> {
        // Because `into_inner` takes `self` by value, the compiler statically
        // verifies that it is not currently borrowed.
        // So, it is safe to move out `Option<T>`.
        self.value.into_inner()
    }
}

// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;
const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;
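
// For example, a `queue` value whose address is `waiter_addr | RUNNING` encodes
// the RUNNING state with `waiter_addr` (the address with the two low bits
// cleared) as the head of the waiter list.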

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiter is stored on the stack of the waiting thread.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    thread: Cell<Option<Thread>>,
    signaled: AtomicBool,
    next: *mut Waiter,
}
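
// Note: keeping the node on the waiting thread's stack is sound because the
// waiting thread does not leave `wait` (and so does not pop the node) until
// `signaled` is set, and the notifying thread reads `next` and takes the
// `Thread` handle *before* setting `signaled`, so it never touches a node that
// may already be gone.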

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
    queue: &'a AtomicPtr<Waiter>,
    new_queue: *mut Waiter,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        let state = strict::addr(queue) & STATE_MASK;
        assert_eq!(state, RUNNING);

        unsafe {
            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
            while !waiter.is_null() {
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                thread.unpark();
            }
        }
    }
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic.
#[inline(never)]
fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
    let mut curr_queue = queue.load(Ordering::Acquire);

    loop {
        let curr_state = strict::addr(curr_queue) & STATE_MASK;
        match (curr_state, &mut init) {
            (COMPLETE, _) => return,
            (INCOMPLETE, Some(init)) => {
                let exchange = queue.compare_exchange(
                    curr_queue,
                    strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
                    Ordering::Acquire,
                    Ordering::Acquire,
                );
                if let Err(new_queue) = exchange {
                    curr_queue = new_queue;
                    continue;
                }
                let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
                if init() {
                    guard.new_queue = COMPLETE_PTR;
                }
                return;
            }
            (INCOMPLETE, None) | (RUNNING, _) => {
                wait(queue, curr_queue);
                curr_queue = queue.load(Ordering::Acquire);
            }
            _ => debug_assert!(false),
        }
    }
}

fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
    let curr_state = strict::addr(curr_queue) & STATE_MASK;
    loop {
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
        };
        let me = &node as *const Waiter as *mut Waiter;

        let exchange = queue.compare_exchange(
            curr_queue,
            strict::map_addr(me, |q| q | curr_state),
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
            // The state changed under us (for example, initialization finished
            // or failed), so return to the caller, which re-examines the cell,
            // instead of enqueueing onto a stale list.
            if strict::addr(new_queue) & STATE_MASK != curr_state {
                return;
            }
            curr_queue = new_queue;
            continue;
        }

        // Park until the notifying thread sets `signaled`; re-checking the
        // flag handles spurious wakeups from `park`.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}

// Polyfill of strict provenance from https://crates.io/crates/sptr.
//
// Use free-standing functions rather than a trait to keep things simple and
// avoid any potential conflicts with future stable std API.
mod strict {
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the meantime, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // This is the canonical desugaring of this operation,
        // but `pointer::cast` was only stabilized in 1.38.
        // self.cast::<u8>().wrapping_offset(offset).cast::<T>()
        (ptr as *mut u8).wrapping_offset(offset) as *mut T
    }

    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        self::with_addr(ptr, f(addr(ptr)))
    }
}

// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}