// Copyright 2017 Amanieu d'Antras // // Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be // copied, modified, or distributed except according to those terms. //! Per-object thread-local storage //! //! This library provides the `ThreadLocal` type which allows a separate copy of //! an object to be used for each thread. This allows for per-object //! thread-local storage, unlike the standard library's `thread_local!` macro //! which only allows static thread-local storage. //! //! Per-thread objects are not destroyed when a thread exits. Instead, objects //! are only destroyed when the `ThreadLocal` containing them is destroyed. //! //! You can also iterate over the thread-local values of all thread in a //! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. This can //! only be done if you have mutable access to the `ThreadLocal` object, which //! guarantees that you are the only thread currently accessing it. //! //! Note that since thread IDs are recycled when a thread exits, it is possible //! for one thread to retrieve the object of another thread. Since this can only //! occur after a thread has exited this does not lead to any race conditions. //! //! # Examples //! //! Basic usage of `ThreadLocal`: //! //! ```rust //! use thread_local::ThreadLocal; //! let tls: ThreadLocal = ThreadLocal::new(); //! assert_eq!(tls.get(), None); //! assert_eq!(tls.get_or(|| 5), &5); //! assert_eq!(tls.get(), Some(&5)); //! ``` //! //! Combining thread-local values into a single result: //! //! ```rust //! use thread_local::ThreadLocal; //! use std::sync::Arc; //! use std::cell::Cell; //! use std::thread; //! //! let tls = Arc::new(ThreadLocal::new()); //! //! // Create a bunch of threads to do stuff //! for _ in 0..5 { //! let tls2 = tls.clone(); //! thread::spawn(move || { //! // Increment a counter to count some event... //! let cell = tls2.get_or(|| Cell::new(0)); //! cell.set(cell.get() + 1); //! }).join().unwrap(); //! } //! //! // Once all threads are done, collect the counter values and return the //! // sum of all thread-local counter values. //! let tls = Arc::try_unwrap(tls).unwrap(); //! let total = tls.into_iter().fold(0, |x, y| x + y.get()); //! assert_eq!(total, 5); //! ``` #![warn(missing_docs)] #![allow(clippy::mutex_atomic)] #[macro_use] extern crate lazy_static; mod cached; mod thread_id; mod unreachable; #[allow(deprecated)] pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal}; use std::cell::UnsafeCell; use std::fmt; use std::marker::PhantomData; use std::mem; use std::panic::UnwindSafe; use std::ptr; use std::sync::atomic::{AtomicPtr, Ordering}; use std::sync::Mutex; use thread_id::Thread; use unreachable::{UncheckedOptionExt, UncheckedResultExt}; // Use usize::BITS once it has stabilized and the MSRV has been bumped. #[cfg(target_pointer_width = "16")] const POINTER_WIDTH: u8 = 16; #[cfg(target_pointer_width = "32")] const POINTER_WIDTH: u8 = 32; #[cfg(target_pointer_width = "64")] const POINTER_WIDTH: u8 = 64; /// The total number of buckets stored in each thread local. const BUCKETS: usize = (POINTER_WIDTH + 1) as usize; /// Thread-local variable wrapper /// /// See the [module-level documentation](index.html) for more. pub struct ThreadLocal { /// The buckets in the thread local. The nth bucket contains `2^(n-1)` /// elements. Each bucket is lazily allocated. buckets: [AtomicPtr>>; BUCKETS], /// Lock used to guard against concurrent modifications. This is taken when /// there is a possibility of allocating a new bucket, which only occurs /// when inserting values. This also guards the counter for the total number /// of values in the thread local. lock: Mutex, } // ThreadLocal is always Sync, even if T isn't unsafe impl Sync for ThreadLocal {} impl Default for ThreadLocal { fn default() -> ThreadLocal { ThreadLocal::new() } } impl Drop for ThreadLocal { fn drop(&mut self) { let mut bucket_size = 1; // Free each non-null bucket for (i, bucket) in self.buckets.iter_mut().enumerate() { let bucket_ptr = *bucket.get_mut(); let this_bucket_size = bucket_size; if i != 0 { bucket_size <<= 1; } if bucket_ptr.is_null() { continue; } unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket_ptr, this_bucket_size)) }; } } } impl ThreadLocal { /// Creates a new empty `ThreadLocal`. pub fn new() -> ThreadLocal { Self::with_capacity(2) } /// Creates a new `ThreadLocal` with an initial capacity. If less than the capacity threads /// access the thread local it will never reallocate. The capacity may be rounded up to the /// nearest power of two. pub fn with_capacity(capacity: usize) -> ThreadLocal { let allocated_buckets = capacity .checked_sub(1) .map(|c| usize::from(POINTER_WIDTH) - (c.leading_zeros() as usize) + 1) .unwrap_or(0); let mut buckets = [ptr::null_mut(); BUCKETS]; let mut bucket_size = 1; for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() { *bucket = allocate_bucket::(bucket_size); if i != 0 { bucket_size <<= 1; } } ThreadLocal { // Safety: AtomicPtr has the same representation as a pointer and arrays have the same // representation as a sequence of their inner type. buckets: unsafe { mem::transmute(buckets) }, lock: Mutex::new(0), } } /// Returns the element for the current thread, if it exists. pub fn get(&self) -> Option<&T> { let thread = thread_id::get(); self.get_inner(thread) } /// Returns the element for the current thread, or creates it if it doesn't /// exist. pub fn get_or(&self, create: F) -> &T where F: FnOnce() -> T, { unsafe { self.get_or_try(|| Ok::(create())) .unchecked_unwrap_ok() } } /// Returns the element for the current thread, or creates it if it doesn't /// exist. If `create` fails, that error is returned and no element is /// added. pub fn get_or_try(&self, create: F) -> Result<&T, E> where F: FnOnce() -> Result, { let thread = thread_id::get(); match self.get_inner(thread) { Some(x) => Ok(x), None => Ok(self.insert(thread, create()?)), } } fn get_inner(&self, thread: Thread) -> Option<&T> { let bucket_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire); if bucket_ptr.is_null() { return None; } unsafe { (&*(&*bucket_ptr.add(thread.index)).get()).as_ref() } } #[cold] fn insert(&self, thread: Thread, data: T) -> &T { // Lock the Mutex to ensure only a single thread is allocating buckets at once let mut count = self.lock.lock().unwrap(); *count += 1; let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) }; let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire); let bucket_ptr = if bucket_ptr.is_null() { // Allocate a new bucket let bucket_ptr = allocate_bucket(thread.bucket_size); bucket_atomic_ptr.store(bucket_ptr, Ordering::Release); bucket_ptr } else { bucket_ptr }; drop(count); // Insert the new element into the bucket unsafe { let value_ptr = (&*bucket_ptr.add(thread.index)).get(); *value_ptr = Some(data); (&*value_ptr).as_ref().unchecked_unwrap() } } fn raw_iter(&mut self) -> RawIter { RawIter { remaining: *self.lock.get_mut().unwrap(), buckets: unsafe { *(&self.buckets as *const _ as *const [*const UnsafeCell>; BUCKETS]) }, bucket: 0, bucket_size: 1, index: 0, } } /// Returns a mutable iterator over the local values of all threads in /// unspecified order. /// /// Since this call borrows the `ThreadLocal` mutably, this operation can /// be done safely---the mutable borrow statically guarantees no other /// threads are currently accessing their associated values. pub fn iter_mut(&mut self) -> IterMut { IterMut { raw: self.raw_iter(), marker: PhantomData, } } /// Removes all thread-specific values from the `ThreadLocal`, effectively /// reseting it to its original state. /// /// Since this call borrows the `ThreadLocal` mutably, this operation can /// be done safely---the mutable borrow statically guarantees no other /// threads are currently accessing their associated values. pub fn clear(&mut self) { *self = ThreadLocal::new(); } } impl IntoIterator for ThreadLocal { type Item = T; type IntoIter = IntoIter; fn into_iter(mut self) -> IntoIter { IntoIter { raw: self.raw_iter(), _thread_local: self, } } } impl<'a, T: Send + 'a> IntoIterator for &'a mut ThreadLocal { type Item = &'a mut T; type IntoIter = IterMut<'a, T>; fn into_iter(self) -> IterMut<'a, T> { self.iter_mut() } } impl ThreadLocal { /// Returns the element for the current thread, or creates a default one if /// it doesn't exist. pub fn get_or_default(&self) -> &T { self.get_or(Default::default) } } impl fmt::Debug for ThreadLocal { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get()) } } impl UnwindSafe for ThreadLocal {} struct RawIter { remaining: usize, buckets: [*const UnsafeCell>; BUCKETS], bucket: usize, bucket_size: usize, index: usize, } impl Iterator for RawIter { type Item = *mut Option; fn next(&mut self) -> Option { if self.remaining == 0 { return None; } loop { let bucket = unsafe { *self.buckets.get_unchecked(self.bucket) }; if !bucket.is_null() { while self.index < self.bucket_size { let item = unsafe { (&*bucket.add(self.index)).get() }; self.index += 1; if unsafe { &*item }.is_some() { self.remaining -= 1; return Some(item); } } } if self.bucket != 0 { self.bucket_size <<= 1; } self.bucket += 1; self.index = 0; } } fn size_hint(&self) -> (usize, Option) { (self.remaining, Some(self.remaining)) } } /// Mutable iterator over the contents of a `ThreadLocal`. pub struct IterMut<'a, T: Send + 'a> { raw: RawIter, marker: PhantomData<&'a mut ThreadLocal>, } impl<'a, T: Send + 'a> Iterator for IterMut<'a, T> { type Item = &'a mut T; fn next(&mut self) -> Option<&'a mut T> { self.raw .next() .map(|x| unsafe { &mut *(*x).as_mut().unchecked_unwrap() }) } fn size_hint(&self) -> (usize, Option) { self.raw.size_hint() } } impl<'a, T: Send + 'a> ExactSizeIterator for IterMut<'a, T> {} /// An iterator that moves out of a `ThreadLocal`. pub struct IntoIter { raw: RawIter, _thread_local: ThreadLocal, } impl Iterator for IntoIter { type Item = T; fn next(&mut self) -> Option { self.raw .next() .map(|x| unsafe { (*x).take().unchecked_unwrap() }) } fn size_hint(&self) -> (usize, Option) { self.raw.size_hint() } } impl ExactSizeIterator for IntoIter {} fn allocate_bucket(size: usize) -> *mut UnsafeCell> { Box::into_raw( (0..size) .map(|_| UnsafeCell::new(None::)) .collect::>() .into_boxed_slice(), ) as *mut _ } #[cfg(test)] mod tests { use super::ThreadLocal; use std::cell::RefCell; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::thread; fn make_create() -> Arc usize + Send + Sync> { let count = AtomicUsize::new(0); Arc::new(move || count.fetch_add(1, Relaxed)) } #[test] fn same_thread() { let create = make_create(); let mut tls = ThreadLocal::new(); assert_eq!(None, tls.get()); assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); assert_eq!(0, *tls.get_or(|| create())); assert_eq!(Some(&0), tls.get()); assert_eq!(0, *tls.get_or(|| create())); assert_eq!(Some(&0), tls.get()); assert_eq!(0, *tls.get_or(|| create())); assert_eq!(Some(&0), tls.get()); assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); tls.clear(); assert_eq!(None, tls.get()); } #[test] fn different_thread() { let create = make_create(); let tls = Arc::new(ThreadLocal::new()); assert_eq!(None, tls.get()); assert_eq!(0, *tls.get_or(|| create())); assert_eq!(Some(&0), tls.get()); let tls2 = tls.clone(); let create2 = create.clone(); thread::spawn(move || { assert_eq!(None, tls2.get()); assert_eq!(1, *tls2.get_or(|| create2())); assert_eq!(Some(&1), tls2.get()); }) .join() .unwrap(); assert_eq!(Some(&0), tls.get()); assert_eq!(0, *tls.get_or(|| create())); } #[test] fn iter() { let tls = Arc::new(ThreadLocal::new()); tls.get_or(|| Box::new(1)); let tls2 = tls.clone(); thread::spawn(move || { tls2.get_or(|| Box::new(2)); let tls3 = tls2.clone(); thread::spawn(move || { tls3.get_or(|| Box::new(3)); }) .join() .unwrap(); drop(tls2); }) .join() .unwrap(); let mut tls = Arc::try_unwrap(tls).unwrap(); let mut v = tls.iter_mut().map(|x| **x).collect::>(); v.sort_unstable(); assert_eq!(vec![1, 2, 3], v); let mut v = tls.into_iter().map(|x| *x).collect::>(); v.sort_unstable(); assert_eq!(vec![1, 2, 3], v); } #[test] fn is_sync() { fn foo() {} foo::>(); foo::>>(); } }