1// There's a lot of scary concurrent code in this module, but it is copied from 2// `std::sync::Once` with two changes: 3// * no poisoning 4// * init function can fail 5 6use std::{ 7 cell::{Cell, UnsafeCell}, 8 marker::PhantomData, 9 panic::{RefUnwindSafe, UnwindSafe}, 10 sync::atomic::{AtomicBool, AtomicPtr, Ordering}, 11 thread::{self, Thread}, 12}; 13 14#[derive(Debug)] 15pub(crate) struct OnceCell<T> { 16 // This `queue` field is the core of the implementation. It encodes two 17 // pieces of information: 18 // 19 // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`) 20 // * Linked list of threads waiting for the current cell. 21 // 22 // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states 23 // allow waiters. 24 queue: AtomicPtr<Waiter>, 25 _marker: PhantomData<*mut Waiter>, 26 value: UnsafeCell<Option<T>>, 27} 28 29// Why do we need `T: Send`? 30// Thread A creates a `OnceCell` and shares it with 31// scoped thread B, which fills the cell, which is 32// then destroyed by A. That is, destructor observes 33// a sent value. 34unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} 35unsafe impl<T: Send> Send for OnceCell<T> {} 36 37impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {} 38impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {} 39 40impl<T> OnceCell<T> { 41 pub(crate) const fn new() -> OnceCell<T> { 42 OnceCell { 43 queue: AtomicPtr::new(INCOMPLETE_PTR), 44 _marker: PhantomData, 45 value: UnsafeCell::new(None), 46 } 47 } 48 49 pub(crate) const fn with_value(value: T) -> OnceCell<T> { 50 OnceCell { 51 queue: AtomicPtr::new(COMPLETE_PTR), 52 _marker: PhantomData, 53 value: UnsafeCell::new(Some(value)), 54 } 55 } 56 57 /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst). 
58 #[inline] 59 pub(crate) fn is_initialized(&self) -> bool { 60 // An `Acquire` load is enough because that makes all the initialization 61 // operations visible to us, and, this being a fast path, weaker 62 // ordering helps with performance. This `Acquire` synchronizes with 63 // `SeqCst` operations on the slow path. 64 self.queue.load(Ordering::Acquire) == COMPLETE_PTR 65 } 66 67 /// Safety: synchronizes with store to value via SeqCst read from state, 68 /// writes value only once because we never get to INCOMPLETE state after a 69 /// successful write. 70 #[cold] 71 pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E> 72 where 73 F: FnOnce() -> Result<T, E>, 74 { 75 let mut f = Some(f); 76 let mut res: Result<(), E> = Ok(()); 77 let slot: *mut Option<T> = self.value.get(); 78 initialize_or_wait( 79 &self.queue, 80 Some(&mut || { 81 let f = unsafe { crate::unwrap_unchecked(f.take()) }; 82 match f() { 83 Ok(value) => { 84 unsafe { *slot = Some(value) }; 85 true 86 } 87 Err(err) => { 88 res = Err(err); 89 false 90 } 91 } 92 }), 93 ); 94 res 95 } 96 97 #[cold] 98 pub(crate) fn wait(&self) { 99 initialize_or_wait(&self.queue, None); 100 } 101 102 /// Get the reference to the underlying value, without checking if the cell 103 /// is initialized. 104 /// 105 /// # Safety 106 /// 107 /// Caller must ensure that the cell is in initialized state, and that 108 /// the contents are acquired by (synchronized to) this thread. 109 pub(crate) unsafe fn get_unchecked(&self) -> &T { 110 debug_assert!(self.is_initialized()); 111 let slot = &*self.value.get(); 112 crate::unwrap_unchecked(slot.as_ref()) 113 } 114 115 /// Gets the mutable reference to the underlying value. 116 /// Returns `None` if the cell is empty. 117 pub(crate) fn get_mut(&mut self) -> Option<&mut T> { 118 // Safe b/c we have a unique access. 119 unsafe { &mut *self.value.get() }.as_mut() 120 } 121 122 /// Consumes this `OnceCell`, returning the wrapped value. 
123 /// Returns `None` if the cell was empty. 124 #[inline] 125 pub(crate) fn into_inner(self) -> Option<T> { 126 // Because `into_inner` takes `self` by value, the compiler statically 127 // verifies that it is not currently borrowed. 128 // So, it is safe to move out `Option<T>`. 129 self.value.into_inner() 130 } 131} 132 133// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in 134// the OnceCell structure. 135const INCOMPLETE: usize = 0x0; 136const RUNNING: usize = 0x1; 137const COMPLETE: usize = 0x2; 138const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter; 139const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter; 140 141// Mask to learn about the state. All other bits are the queue of waiters if 142// this is in the RUNNING state. 143const STATE_MASK: usize = 0x3; 144 145/// Representation of a node in the linked list of waiters in the RUNNING state. 146/// A waiters is stored on the stack of the waiting threads. 147#[repr(align(4))] // Ensure the two lower bits are free to use as state bits. 148struct Waiter { 149 thread: Cell<Option<Thread>>, 150 signaled: AtomicBool, 151 next: *mut Waiter, 152} 153 154/// Drains and notifies the queue of waiters on drop. 155struct Guard<'a> { 156 queue: &'a AtomicPtr<Waiter>, 157 new_queue: *mut Waiter, 158} 159 160impl Drop for Guard<'_> { 161 fn drop(&mut self) { 162 let queue = self.queue.swap(self.new_queue, Ordering::AcqRel); 163 164 let state = strict::addr(queue) & STATE_MASK; 165 assert_eq!(state, RUNNING); 166 167 unsafe { 168 let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK); 169 while !waiter.is_null() { 170 let next = (*waiter).next; 171 let thread = (*waiter).thread.take().unwrap(); 172 (*waiter).signaled.store(true, Ordering::Release); 173 waiter = next; 174 thread.unpark(); 175 } 176 } 177 } 178} 179 180// Corresponds to `std::sync::Once::call_inner`. 181// 182// Originally copied from std, but since modified to remove poisoning and to 183// support wait. 
184// 185// Note: this is intentionally monomorphic 186#[inline(never)] 187fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) { 188 let mut curr_queue = queue.load(Ordering::Acquire); 189 190 loop { 191 let curr_state = strict::addr(curr_queue) & STATE_MASK; 192 match (curr_state, &mut init) { 193 (COMPLETE, _) => return, 194 (INCOMPLETE, Some(init)) => { 195 let exchange = queue.compare_exchange( 196 curr_queue, 197 strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING), 198 Ordering::Acquire, 199 Ordering::Acquire, 200 ); 201 if let Err(new_queue) = exchange { 202 curr_queue = new_queue; 203 continue; 204 } 205 let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR }; 206 if init() { 207 guard.new_queue = COMPLETE_PTR; 208 } 209 return; 210 } 211 (INCOMPLETE, None) | (RUNNING, _) => { 212 wait(&queue, curr_queue); 213 curr_queue = queue.load(Ordering::Acquire); 214 } 215 _ => debug_assert!(false), 216 } 217 } 218} 219 220fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) { 221 let curr_state = strict::addr(curr_queue) & STATE_MASK; 222 loop { 223 let node = Waiter { 224 thread: Cell::new(Some(thread::current())), 225 signaled: AtomicBool::new(false), 226 next: strict::map_addr(curr_queue, |q| q & !STATE_MASK), 227 }; 228 let me = &node as *const Waiter as *mut Waiter; 229 230 let exchange = queue.compare_exchange( 231 curr_queue, 232 strict::map_addr(me, |q| q | curr_state), 233 Ordering::Release, 234 Ordering::Relaxed, 235 ); 236 if let Err(new_queue) = exchange { 237 if strict::addr(new_queue) & STATE_MASK != curr_state { 238 return; 239 } 240 curr_queue = new_queue; 241 continue; 242 } 243 244 while !node.signaled.load(Ordering::Acquire) { 245 thread::park(); 246 } 247 break; 248 } 249} 250 251// Polyfill of strict provenance from https://crates.io/crates/sptr. 
//
// Use free-standing function rather than a trait to keep things simple and
// avoid any potential conflicts with future stable std API.
mod strict {
    /// Returns the address of `ptr` as an integer, discarding its provenance.
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    /// Returns a pointer with the address `addr` and the provenance of `ptr`.
    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the mean-time, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // This is the canonical desugaring of this operation,
        // but `pointer::cast` was only stabilized in 1.38.
        // self.cast::<u8>().wrapping_offset(offset).cast::<T>()
        (ptr as *mut u8).wrapping_offset(offset) as *mut T
    }

    /// Returns `ptr` with its address replaced by `f(address of ptr)`, keeping
    /// the provenance of `ptr`. Used to set and clear the state bits stored in
    /// the low bits of a pointer.
    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T {
        self::with_addr(ptr, f(addr(ptr)))
    }
}

// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        /// Test helper: infallible initialization, mirroring `Once::call_once`.
        fn init(&self, f: impl FnOnce() -> T) {
            // Uninhabited error type: the closure passed below can only
            // return `Ok`.
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    // The initializer runs exactly once, even if `init` is called again.
    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    // Many threads race to initialize; exactly one initializer may run, and
    // every thread must see its effect after `init` returns.
    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    // Name kept from the std test suite. Here a panicking initializer does not
    // poison the cell: it is left uninitialized and can be initialized later.
    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // panic inside the initializer
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // the cell is still empty, so the next initializer does run
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // the cell is now initialized; further calls are no-ops
        O.init(|| {});
    }

    // A thread that calls `init` while another initializer is running must
    // block until it finishes and must not run its own closure afterwards.
    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // panic inside the initializer; the cell stays uninitialized
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone is inside the initializer and stays there
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        // let the first initializer finish
        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}