diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 607 |
1 file changed, 607 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9fd6d19 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,607 @@ +// Copyright 2017 Amanieu d'Antras +// +// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or +// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or +// http://opensource.org/licenses/MIT>, at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +//! Per-object thread-local storage +//! +//! This library provides the `ThreadLocal` type which allows a separate copy of +//! an object to be used for each thread. This allows for per-object +//! thread-local storage, unlike the standard library's `thread_local!` macro +//! which only allows static thread-local storage. +//! +//! Per-thread objects are not destroyed when a thread exits. Instead, objects +//! are only destroyed when the `ThreadLocal` containing them is destroyed. +//! +//! You can also iterate over the thread-local values of all thread in a +//! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. This can +//! only be done if you have mutable access to the `ThreadLocal` object, which +//! guarantees that you are the only thread currently accessing it. +//! +//! A `CachedThreadLocal` type is also provided which wraps a `ThreadLocal` but +//! also uses a special fast path for the first thread that writes into it. The +//! fast path has very low overhead (<1ns per access) while keeping the same +//! performance as `ThreadLocal` for other threads. +//! +//! Note that since thread IDs are recycled when a thread exits, it is possible +//! for one thread to retrieve the object of another thread. Since this can only +//! occur after a thread has exited this does not lead to any race conditions. +//! +//! # Examples +//! +//! Basic usage of `ThreadLocal`: +//! +//! ```rust +//! use thread_local::ThreadLocal; +//! let tls: ThreadLocal<u32> = ThreadLocal::new(); +//! 
assert_eq!(tls.get(), None); +//! assert_eq!(tls.get_or(|| 5), &5); +//! assert_eq!(tls.get(), Some(&5)); +//! ``` +//! +//! Combining thread-local values into a single result: +//! +//! ```rust +//! use thread_local::ThreadLocal; +//! use std::sync::Arc; +//! use std::cell::Cell; +//! use std::thread; +//! +//! let tls = Arc::new(ThreadLocal::new()); +//! +//! // Create a bunch of threads to do stuff +//! for _ in 0..5 { +//! let tls2 = tls.clone(); +//! thread::spawn(move || { +//! // Increment a counter to count some event... +//! let cell = tls2.get_or(|| Cell::new(0)); +//! cell.set(cell.get() + 1); +//! }).join().unwrap(); +//! } +//! +//! // Once all threads are done, collect the counter values and return the +//! // sum of all thread-local counter values. +//! let tls = Arc::try_unwrap(tls).unwrap(); +//! let total = tls.into_iter().fold(0, |x, y| x + y.get()); +//! assert_eq!(total, 5); +//! ``` + +#![warn(missing_docs)] + +#[macro_use] +extern crate lazy_static; + +mod thread_id; +mod unreachable; +mod cached; + +pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal}; + +use std::cell::UnsafeCell; +use std::fmt; +use std::marker::PhantomData; +use std::panic::UnwindSafe; +use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering}; +use std::sync::Mutex; +use unreachable::{UncheckedOptionExt, UncheckedResultExt}; + +/// Thread-local variable wrapper +/// +/// See the [module-level documentation](index.html) for more. +pub struct ThreadLocal<T: Send> { + // Pointer to the current top-level hash table + table: AtomicPtr<Table<T>>, + + // Lock used to guard against concurrent modifications. This is only taken + // while writing to the table, not when reading from it. This also guards + // the counter for the total number of values in the hash table. 
+ lock: Mutex<usize>, +} + +struct Table<T: Send> { + // Hash entries for the table + entries: Box<[TableEntry<T>]>, + + // Number of bits used for the hash function + hash_bits: usize, + + // Previous table, half the size of the current one + prev: Option<Box<Table<T>>>, +} + +struct TableEntry<T: Send> { + // Current owner of this entry, or 0 if this is an empty entry + owner: AtomicUsize, + + // The object associated with this entry. This is only ever accessed by the + // owner of the entry. + data: UnsafeCell<Option<Box<T>>>, +} + +// ThreadLocal is always Sync, even if T isn't +unsafe impl<T: Send> Sync for ThreadLocal<T> {} + +impl<T: Send> Default for ThreadLocal<T> { + fn default() -> ThreadLocal<T> { + ThreadLocal::new() + } +} + +impl<T: Send> Drop for ThreadLocal<T> { + fn drop(&mut self) { + unsafe { + Box::from_raw(self.table.load(Ordering::Relaxed)); + } + } +} + +// Implementation of Clone for TableEntry, needed to make vec![] work +impl<T: Send> Clone for TableEntry<T> { + fn clone(&self) -> TableEntry<T> { + TableEntry { + owner: AtomicUsize::new(0), + data: UnsafeCell::new(None), + } + } +} + +// Hash function for the thread id +#[cfg(target_pointer_width = "32")] +#[inline] +fn hash(id: usize, bits: usize) -> usize { + id.wrapping_mul(0x9E3779B9) >> (32 - bits) +} +#[cfg(target_pointer_width = "64")] +#[inline] +fn hash(id: usize, bits: usize) -> usize { + id.wrapping_mul(0x9E37_79B9_7F4A_7C15) >> (64 - bits) +} + +impl<T: Send> ThreadLocal<T> { + /// Creates a new empty `ThreadLocal`. + pub fn new() -> ThreadLocal<T> { + let entry = TableEntry { + owner: AtomicUsize::new(0), + data: UnsafeCell::new(None), + }; + let table = Table { + entries: vec![entry; 2].into_boxed_slice(), + hash_bits: 1, + prev: None, + }; + ThreadLocal { + table: AtomicPtr::new(Box::into_raw(Box::new(table))), + lock: Mutex::new(0), + } + } + + /// Returns the element for the current thread, if it exists. 
+ pub fn get(&self) -> Option<&T> { + let id = thread_id::get(); + self.get_fast(id) + } + + /// Returns the element for the current thread, or creates it if it doesn't + /// exist. + pub fn get_or<F>(&self, create: F) -> &T + where + F: FnOnce() -> T, + { + unsafe { + self.get_or_try(|| Ok::<T, ()>(create())) + .unchecked_unwrap_ok() + } + } + + /// Returns the element for the current thread, or creates it if it doesn't + /// exist. If `create` fails, that error is returned and no element is + /// added. + pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E> + where + F: FnOnce() -> Result<T, E>, + { + let id = thread_id::get(); + match self.get_fast(id) { + Some(x) => Ok(x), + None => Ok(self.insert(id, Box::new(create()?), true)), + } + } + + // Simple hash table lookup function + fn lookup(id: usize, table: &Table<T>) -> Option<&UnsafeCell<Option<Box<T>>>> { + // Because we use a Mutex to prevent concurrent modifications (but not + // reads) of the hash table, we can avoid any memory barriers here. No + // elements between our hash bucket and our value can have been modified + // since we inserted our thread-local value into the table. + for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) { + let owner = entry.owner.load(Ordering::Relaxed); + if owner == id { + return Some(&entry.data); + } + if owner == 0 { + return None; + } + } + unreachable!(); + } + + // Fast path: try to find our thread in the top-level hash table + fn get_fast(&self, id: usize) -> Option<&T> { + let table = unsafe { &*self.table.load(Ordering::Acquire) }; + match Self::lookup(id, table) { + Some(x) => unsafe { Some((*x.get()).as_ref().unchecked_unwrap()) }, + None => self.get_slow(id, table), + } + } + + // Slow path: try to find our thread in the other hash tables, and then + // move it to the top-level hash table. 
+ #[cold] + fn get_slow(&self, id: usize, table_top: &Table<T>) -> Option<&T> { + let mut current = &table_top.prev; + while let Some(ref table) = *current { + if let Some(x) = Self::lookup(id, table) { + let data = unsafe { (*x.get()).take().unchecked_unwrap() }; + return Some(self.insert(id, data, false)); + } + current = &table.prev; + } + None + } + + #[cold] + fn insert(&self, id: usize, data: Box<T>, new: bool) -> &T { + // Lock the Mutex to ensure only a single thread is modify the hash + // table at once. + let mut count = self.lock.lock().unwrap(); + if new { + *count += 1; + } + let table_raw = self.table.load(Ordering::Relaxed); + let table = unsafe { &*table_raw }; + + // If the current top-level hash table is more than 75% full, add a new + // level with 2x the capacity. Elements will be moved up to the new top + // level table as they are accessed. + let table = if *count > table.entries.len() * 3 / 4 { + let entry = TableEntry { + owner: AtomicUsize::new(0), + data: UnsafeCell::new(None), + }; + let new_table = Box::into_raw(Box::new(Table { + entries: vec![entry; table.entries.len() * 2].into_boxed_slice(), + hash_bits: table.hash_bits + 1, + prev: unsafe { Some(Box::from_raw(table_raw)) }, + })); + self.table.store(new_table, Ordering::Release); + unsafe { &*new_table } + } else { + table + }; + + // Insert the new element into the top-level hash table + for entry in table.entries.iter().cycle().skip(hash(id, table.hash_bits)) { + let owner = entry.owner.load(Ordering::Relaxed); + if owner == 0 { + unsafe { + entry.owner.store(id, Ordering::Relaxed); + *entry.data.get() = Some(data); + return (*entry.data.get()).as_ref().unchecked_unwrap(); + } + } + if owner == id { + // This can happen if create() inserted a value into this + // ThreadLocal between our calls to get_fast() and insert(). We + // just return the existing value and drop the newly-allocated + // Box. 
+ unsafe { + return (*entry.data.get()).as_ref().unchecked_unwrap(); + } + } + } + unreachable!(); + } + + fn raw_iter(&mut self) -> RawIter<T> { + RawIter { + remaining: *self.lock.get_mut().unwrap(), + index: 0, + table: self.table.load(Ordering::Relaxed), + } + } + + /// Returns a mutable iterator over the local values of all threads. + /// + /// Since this call borrows the `ThreadLocal` mutably, this operation can + /// be done safely---the mutable borrow statically guarantees no other + /// threads are currently accessing their associated values. + pub fn iter_mut(&mut self) -> IterMut<T> { + IterMut { + raw: self.raw_iter(), + marker: PhantomData, + } + } + + /// Removes all thread-specific values from the `ThreadLocal`, effectively + /// reseting it to its original state. + /// + /// Since this call borrows the `ThreadLocal` mutably, this operation can + /// be done safely---the mutable borrow statically guarantees no other + /// threads are currently accessing their associated values. + pub fn clear(&mut self) { + *self = ThreadLocal::new(); + } +} + +impl<T: Send> IntoIterator for ThreadLocal<T> { + type Item = T; + type IntoIter = IntoIter<T>; + + fn into_iter(mut self) -> IntoIter<T> { + IntoIter { + raw: self.raw_iter(), + _thread_local: self, + } + } +} + +impl<'a, T: Send + 'a> IntoIterator for &'a mut ThreadLocal<T> { + type Item = &'a mut T; + type IntoIter = IterMut<'a, T>; + + fn into_iter(self) -> IterMut<'a, T> { + self.iter_mut() + } +} + +impl<T: Send + Default> ThreadLocal<T> { + /// Returns the element for the current thread, or creates a default one if + /// it doesn't exist. 
+ pub fn get_or_default(&self) -> &T { + self.get_or(Default::default) + } +} + +impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get()) + } +} + +impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {} + +struct RawIter<T: Send> { + remaining: usize, + index: usize, + table: *const Table<T>, +} + +impl<T: Send> Iterator for RawIter<T> { + type Item = *mut Option<Box<T>>; + + fn next(&mut self) -> Option<*mut Option<Box<T>>> { + if self.remaining == 0 { + return None; + } + + loop { + let entries = unsafe { &(*self.table).entries[..] }; + while self.index < entries.len() { + let val = entries[self.index].data.get(); + self.index += 1; + if unsafe { (*val).is_some() } { + self.remaining -= 1; + return Some(val); + } + } + self.index = 0; + self.table = unsafe { &**(*self.table).prev.as_ref().unchecked_unwrap() }; + } + } + + fn size_hint(&self) -> (usize, Option<usize>) { + (self.remaining, Some(self.remaining)) + } +} + +/// Mutable iterator over the contents of a `ThreadLocal`. +pub struct IterMut<'a, T: Send + 'a> { + raw: RawIter<T>, + marker: PhantomData<&'a mut ThreadLocal<T>>, +} + +impl<'a, T: Send + 'a> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option<&'a mut T> { + self.raw + .next() + .map(|x| unsafe { &mut **(*x).as_mut().unchecked_unwrap() }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.raw.size_hint() + } +} + +impl<'a, T: Send + 'a> ExactSizeIterator for IterMut<'a, T> {} + +/// An iterator that moves out of a `ThreadLocal`. 
+pub struct IntoIter<T: Send> { + raw: RawIter<T>, + _thread_local: ThreadLocal<T>, +} + +impl<T: Send> Iterator for IntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.raw + .next() + .map(|x| unsafe { *(*x).take().unchecked_unwrap() }) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.raw.size_hint() + } +} + +impl<T: Send> ExactSizeIterator for IntoIter<T> {} + +#[cfg(test)] +mod tests { + use super::{CachedThreadLocal, ThreadLocal}; + use std::cell::RefCell; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering::Relaxed; + use std::sync::Arc; + use std::thread; + + fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> { + let count = AtomicUsize::new(0); + Arc::new(move || count.fetch_add(1, Relaxed)) + } + + #[test] + fn same_thread() { + let create = make_create(); + let mut tls = ThreadLocal::new(); + assert_eq!(None, tls.get()); + assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); + tls.clear(); + assert_eq!(None, tls.get()); + } + + #[test] + fn same_thread_cached() { + let create = make_create(); + let mut tls = CachedThreadLocal::new(); + assert_eq!(None, tls.get()); + assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls)); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls)); + tls.clear(); + assert_eq!(None, tls.get()); + } + + #[test] + fn different_thread() { + let create = make_create(); + let tls = 
Arc::new(ThreadLocal::new()); + assert_eq!(None, tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + + let tls2 = tls.clone(); + let create2 = create.clone(); + thread::spawn(move || { + assert_eq!(None, tls2.get()); + assert_eq!(1, *tls2.get_or(|| create2())); + assert_eq!(Some(&1), tls2.get()); + }) + .join() + .unwrap(); + + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + } + + #[test] + fn different_thread_cached() { + let create = make_create(); + let tls = Arc::new(CachedThreadLocal::new()); + assert_eq!(None, tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + assert_eq!(Some(&0), tls.get()); + + let tls2 = tls.clone(); + let create2 = create.clone(); + thread::spawn(move || { + assert_eq!(None, tls2.get()); + assert_eq!(1, *tls2.get_or(|| create2())); + assert_eq!(Some(&1), tls2.get()); + }) + .join() + .unwrap(); + + assert_eq!(Some(&0), tls.get()); + assert_eq!(0, *tls.get_or(|| create())); + } + + #[test] + fn iter() { + let tls = Arc::new(ThreadLocal::new()); + tls.get_or(|| Box::new(1)); + + let tls2 = tls.clone(); + thread::spawn(move || { + tls2.get_or(|| Box::new(2)); + let tls3 = tls2.clone(); + thread::spawn(move || { + tls3.get_or(|| Box::new(3)); + }) + .join() + .unwrap(); + }) + .join() + .unwrap(); + + let mut tls = Arc::try_unwrap(tls).unwrap(); + let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>(); + v.sort(); + assert_eq!(vec![1, 2, 3], v); + let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>(); + v.sort(); + assert_eq!(vec![1, 2, 3], v); + } + + #[test] + fn iter_cached() { + let tls = Arc::new(CachedThreadLocal::new()); + tls.get_or(|| Box::new(1)); + + let tls2 = tls.clone(); + thread::spawn(move || { + tls2.get_or(|| Box::new(2)); + let tls3 = tls2.clone(); + thread::spawn(move || { + tls3.get_or(|| Box::new(3)); + }) + .join() + .unwrap(); + }) + .join() + .unwrap(); + + let mut tls = Arc::try_unwrap(tls).unwrap(); + let mut v = 
tls.iter_mut().map(|x| **x).collect::<Vec<i32>>(); + v.sort(); + assert_eq!(vec![1, 2, 3], v); + let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>(); + v.sort(); + assert_eq!(vec![1, 2, 3], v); + } + + #[test] + fn is_sync() { + fn foo<T: Sync>() {} + foo::<ThreadLocal<String>>(); + foo::<ThreadLocal<RefCell<String>>>(); + foo::<CachedThreadLocal<String>>(); + foo::<CachedThreadLocal<RefCell<String>>>(); + } +} |