From c37c85bc7a34f7f632aa8776e15c0d1d387b78e2 Mon Sep 17 00:00:00 2001
From: Haibo Huang
Date: Tue, 9 Feb 2021 18:14:19 -0800
Subject: Upgrade rust/crates/thread_local to 1.1.3

Test: make
Change-Id: I4e23358c5912509a3598e502aa0d92791e71e435
---
 src/lib.rs | 284 ++++++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 207 insertions(+), 77 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 78bdcc3..f26f6ed 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -66,9 +66,6 @@
 #![warn(missing_docs)]
 #![allow(clippy::mutex_atomic)]
 
-#[macro_use]
-extern crate lazy_static;
-
 mod cached;
 mod thread_id;
 mod unreachable;
@@ -78,14 +75,15 @@ pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
 
 use std::cell::UnsafeCell;
 use std::fmt;
-use std::marker::PhantomData;
+use std::iter::FusedIterator;
 use std::mem;
+use std::mem::MaybeUninit;
 use std::panic::UnwindSafe;
 use std::ptr;
-use std::sync::atomic::{AtomicPtr, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
 use std::sync::Mutex;
 use thread_id::Thread;
-use unreachable::{UncheckedOptionExt, UncheckedResultExt};
+use unreachable::UncheckedResultExt;
 
 // Use usize::BITS once it has stabilized and the MSRV has been bumped.
 #[cfg(target_pointer_width = "16")]
@@ -104,13 +102,31 @@ const BUCKETS: usize = (POINTER_WIDTH + 1) as usize;
 pub struct ThreadLocal<T: Send> {
     /// The buckets in the thread local. The nth bucket contains `2^(n-1)`
     /// elements. Each bucket is lazily allocated.
-    buckets: [AtomicPtr<UnsafeCell<Option<T>>>; BUCKETS],
+    buckets: [AtomicPtr<Entry<T>>; BUCKETS],
+
+    /// The number of values in the thread local. This can be less than the real number of values,
+    /// but is never more.
+    values: AtomicUsize,
 
     /// Lock used to guard against concurrent modifications. This is taken when
     /// there is a possibility of allocating a new bucket, which only occurs
-    /// when inserting values. This also guards the counter for the total number
-    /// of values in the thread local.
-    lock: Mutex<usize>,
+    /// when inserting values.
+    lock: Mutex<()>,
+}
+
+struct Entry<T> {
+    present: AtomicBool,
+    value: UnsafeCell<MaybeUninit<T>>,
+}
+
+impl<T> Drop for Entry<T> {
+    fn drop(&mut self) {
+        unsafe {
+            if *self.present.get_mut() {
+                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
+            }
+        }
+    }
 }
 
 // ThreadLocal is always Sync, even if T isn't
@@ -173,7 +189,8 @@ impl<T: Send> ThreadLocal<T> {
             // Safety: AtomicPtr has the same representation as a pointer and arrays have the same
             // representation as a sequence of their inner type.
             buckets: unsafe { mem::transmute(buckets) },
-            lock: Mutex::new(0),
+            values: AtomicUsize::new(0),
+            lock: Mutex::new(()),
         }
     }
 
@@ -215,14 +232,21 @@ impl<T: Send> ThreadLocal<T> {
         if bucket_ptr.is_null() {
             return None;
         }
-        unsafe { (&*(&*bucket_ptr.add(thread.index)).get()).as_ref() }
+        unsafe {
+            let entry = &*bucket_ptr.add(thread.index);
+            // Read without atomic operations as only this thread can set the value.
+            if (&entry.present as *const _ as *const bool).read() {
+                Some(&*(&*entry.value.get()).as_ptr())
+            } else {
+                None
+            }
+        }
     }
 
     #[cold]
     fn insert(&self, thread: Thread, data: T) -> &T {
         // Lock the Mutex to ensure only a single thread is allocating buckets at once
-        let mut count = self.lock.lock().unwrap();
-        *count += 1;
+        let _guard = self.lock.lock().unwrap();
 
         let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
 
@@ -236,25 +260,30 @@ impl<T: Send> ThreadLocal<T> {
             bucket_ptr
         };
 
-        drop(count);
+        drop(_guard);
 
         // Insert the new element into the bucket
-        unsafe {
-            let value_ptr = (&*bucket_ptr.add(thread.index)).get();
-            *value_ptr = Some(data);
-            (&*value_ptr).as_ref().unchecked_unwrap()
-        }
+        let entry = unsafe { &*bucket_ptr.add(thread.index) };
+        let value_ptr = entry.value.get();
+        unsafe { value_ptr.write(MaybeUninit::new(data)) };
+        entry.present.store(true, Ordering::Release);
+
+        self.values.fetch_add(1, Ordering::Release);
+
+        unsafe { &*(&*value_ptr).as_ptr() }
     }
 
-    fn raw_iter(&mut self) -> RawIter<T> {
-        RawIter {
-            remaining: *self.lock.get_mut().unwrap(),
-            buckets: unsafe {
-                *(&self.buckets as *const _ as *const [*const UnsafeCell<Option<T>>; BUCKETS])
-            },
-            bucket: 0,
-            bucket_size: 1,
-            index: 0,
+    /// Returns an iterator over the local values of all threads in unspecified
+    /// order.
+    ///
+    /// This call can be done safely, as `T` is required to implement [`Sync`].
+    pub fn iter(&self) -> Iter<'_, T>
+    where
+        T: Sync,
+    {
+        Iter {
+            thread_local: self,
+            raw: RawIter::new(),
         }
     }
 
@@ -266,8 +295,8 @@ impl<T: Send> ThreadLocal<T> {
     /// threads are currently accessing their associated values.
     pub fn iter_mut(&mut self) -> IterMut<T> {
         IterMut {
-            raw: self.raw_iter(),
-            marker: PhantomData,
+            thread_local: self,
+            raw: RawIter::new(),
         }
     }
 
@@ -286,15 +315,24 @@ impl<T: Send> IntoIterator for ThreadLocal<T> {
     type Item = T;
     type IntoIter = IntoIter<T>;
 
-    fn into_iter(mut self) -> IntoIter<T> {
+    fn into_iter(self) -> IntoIter<T> {
         IntoIter {
-            raw: self.raw_iter(),
-            _thread_local: self,
+            thread_local: self,
+            raw: RawIter::new(),
         }
     }
 }
 
-impl<'a, T: Send + 'a> IntoIterator for &'a mut ThreadLocal<T> {
+impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
+    type Item = &'a T;
+    type IntoIter = Iter<'a, T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
     type Item = &'a mut T;
     type IntoIter = IterMut<'a, T>;
 
@@ -319,102 +357,171 @@ impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
 
 impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
 
-struct RawIter<T: Send> {
-    remaining: usize,
-    buckets: [*const UnsafeCell<Option<T>>; BUCKETS],
+#[derive(Debug)]
+struct RawIter {
+    yielded: usize,
     bucket: usize,
     bucket_size: usize,
     index: usize,
 }
+impl RawIter {
+    #[inline]
+    fn new() -> Self {
+        Self {
+            yielded: 0,
+            bucket: 0,
+            bucket_size: 1,
+            index: 0,
+        }
+    }
 
-impl<T: Send> Iterator for RawIter<T> {
-    type Item = *mut Option<T>;
+    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
+        while self.bucket < BUCKETS {
+            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
+            let bucket = bucket.load(Ordering::Relaxed);
 
-    fn next(&mut self) -> Option<Self::Item> {
-        if self.remaining == 0 {
+            if !bucket.is_null() {
+                while self.index < self.bucket_size {
+                    let entry = unsafe { &*bucket.add(self.index) };
+                    self.index += 1;
+                    if entry.present.load(Ordering::Acquire) {
+                        self.yielded += 1;
+                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
+                    }
+                }
+            }
+
+            self.next_bucket();
+        }
+        None
+    }
+    fn next_mut<'a, T: Send>(
+        &mut self,
+        thread_local: &'a mut ThreadLocal<T>,
+    ) -> Option<&'a mut Entry<T>> {
+        if *thread_local.values.get_mut() == self.yielded {
             return None;
         }
 
         loop {
-            let bucket = unsafe { *self.buckets.get_unchecked(self.bucket) };
+            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
+            let bucket = *bucket.get_mut();
 
             if !bucket.is_null() {
                 while self.index < self.bucket_size {
-                    let item = unsafe { (&*bucket.add(self.index)).get() };
-
+                    let entry = unsafe { &mut *bucket.add(self.index) };
                     self.index += 1;
-
-                    if unsafe { &*item }.is_some() {
-                        self.remaining -= 1;
-                        return Some(item);
+                    if *entry.present.get_mut() {
+                        self.yielded += 1;
+                        return Some(entry);
                     }
                 }
             }
 
-            if self.bucket != 0 {
-                self.bucket_size <<= 1;
-            }
-            self.bucket += 1;
+            self.next_bucket();
+        }
+    }
 
-            self.index = 0;
+    #[inline]
+    fn next_bucket(&mut self) {
+        if self.bucket != 0 {
+            self.bucket_size <<= 1;
         }
+        self.bucket += 1;
+        self.index = 0;
+    }
+
+    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
+        let total = thread_local.values.load(Ordering::Acquire);
+        (total - self.yielded, None)
+    }
+    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
+        let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
+        let remaining = total - self.yielded;
+        (remaining, Some(remaining))
     }
+}
+
+/// Iterator over the contents of a `ThreadLocal`.
+#[derive(Debug)]
+pub struct Iter<'a, T: Send + Sync> {
+    thread_local: &'a ThreadLocal<T>,
+    raw: RawIter,
+}
 
+impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
+    type Item = &'a T;
+    fn next(&mut self) -> Option<Self::Item> {
+        self.raw.next(self.thread_local)
+    }
     fn size_hint(&self) -> (usize, Option<usize>) {
-        (self.remaining, Some(self.remaining))
+        self.raw.size_hint(self.thread_local)
     }
 }
+impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
 
 /// Mutable iterator over the contents of a `ThreadLocal`.
-pub struct IterMut<'a, T: Send + 'a> {
-    raw: RawIter<T>,
-    marker: PhantomData<&'a mut ThreadLocal<T>>,
+pub struct IterMut<'a, T: Send> {
+    thread_local: &'a mut ThreadLocal<T>,
    raw: RawIter,
 }
 
-impl<'a, T: Send + 'a> Iterator for IterMut<'a, T> {
+impl<'a, T: Send> Iterator for IterMut<'a, T> {
     type Item = &'a mut T;
-
     fn next(&mut self) -> Option<&'a mut T> {
         self.raw
-            .next()
-            .map(|x| unsafe { &mut *(*x).as_mut().unchecked_unwrap() })
+            .next_mut(self.thread_local)
+            .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
     }
-
     fn size_hint(&self) -> (usize, Option<usize>) {
-        self.raw.size_hint()
+        self.raw.size_hint_frozen(self.thread_local)
     }
 }
 
-impl<'a, T: Send + 'a> ExactSizeIterator for IterMut<'a, T> {}
+impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
+impl<T: Send> FusedIterator for IterMut<'_, T> {}
+
+// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
+// this thread's value that potentially aliases with a mutable reference we have given out.
+impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("IterMut").field("raw", &self.raw).finish()
    }
+}
 
 /// An iterator that moves out of a `ThreadLocal`.
+#[derive(Debug)]
 pub struct IntoIter<T: Send> {
-    raw: RawIter<T>,
-    _thread_local: ThreadLocal<T>,
+    thread_local: ThreadLocal<T>,
+    raw: RawIter,
 }
 
 impl<T: Send> Iterator for IntoIter<T> {
     type Item = T;
-
     fn next(&mut self) -> Option<T> {
-        self.raw
-            .next()
-            .map(|x| unsafe { (*x).take().unchecked_unwrap() })
+        self.raw.next_mut(&mut self.thread_local).map(|entry| {
+            *entry.present.get_mut() = false;
+            unsafe {
+                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
+            }
+        })
     }
-
     fn size_hint(&self) -> (usize, Option<usize>) {
-        self.raw.size_hint()
+        self.raw.size_hint_frozen(&self.thread_local)
     }
 }
 
 impl<T: Send> ExactSizeIterator for IntoIter<T> {}
+impl<T: Send> FusedIterator for IntoIter<T> {}
 
-fn allocate_bucket<T>(size: usize) -> *mut UnsafeCell<Option<T>> {
+fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
     Box::into_raw(
         (0..size)
-            .map(|_| UnsafeCell::new(None::<T>))
-            .collect::<Vec<_>>()
-            .into_boxed_slice(),
+            .map(|_| Entry::<T> {
+                present: AtomicBool::new(false),
+                value: UnsafeCell::new(MaybeUninit::uninit()),
+            })
+            .collect(),
     ) as *mut _
 }
 
@@ -491,14 +598,37 @@ mod tests {
         .unwrap();
 
         let mut tls = Arc::try_unwrap(tls).unwrap();
+
+        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
+        v.sort_unstable();
+        assert_eq!(vec![1, 2, 3], v);
+
         let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
         v.sort_unstable();
         assert_eq!(vec![1, 2, 3], v);
+
         let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
         v.sort_unstable();
         assert_eq!(vec![1, 2, 3], v);
     }
 
+    #[test]
+    fn test_drop() {
+        let local = ThreadLocal::new();
+        struct Dropped(Arc<AtomicUsize>);
+        impl Drop for Dropped {
+            fn drop(&mut self) {
+                self.0.fetch_add(1, Relaxed);
+            }
+        }
+
+        let dropped = Arc::new(AtomicUsize::new(0));
+        local.get_or(|| Dropped(dropped.clone()));
+        assert_eq!(dropped.load(Relaxed), 0);
+        drop(local);
+        assert_eq!(dropped.load(Relaxed), 1);
+    }
+
     #[test]
     fn is_sync() {
         fn foo<T: Sync>() {}
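
Note (illustrative sketch, not part of the upstream patch): the headline user-visible addition in this upgrade is `ThreadLocal::iter()`, which iterates all threads' values through a shared reference and therefore requires `T: Sync`; previously only `iter_mut()` and `into_iter()` existed. The sketch below uses only `get_or` and `iter` as they appear in the diff; the thread count and values are made up for illustration.

    use std::sync::Arc;
    use std::thread;

    use thread_local::ThreadLocal;

    fn main() {
        let tls = Arc::new(ThreadLocal::new());

        // Each spawned thread lazily initializes its own slot via `get_or`.
        let handles: Vec<_> = (1..=4)
            .map(|i| {
                let tls = Arc::clone(&tls);
                thread::spawn(move || {
                    tls.get_or(|| i);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        // New in 1.1.x: iterate every thread's value through a shared
        // reference. This compiles because `i32: Sync`.
        let sum: i32 = tls.iter().sum();
        assert_eq!(sum, 1 + 2 + 3 + 4);
    }

The diff makes this shared iteration sound by replacing the old `UnsafeCell<Option<T>>` slots with `Entry<T>` (an `AtomicBool` `present` flag next to a `MaybeUninit<T>` payload): `insert` writes the value first and only then publishes it with a `Release` store, while `RawIter::next` reads the flag with `Acquire`, so an iterator can never observe a partially written value. The new `values` counter is only a lower bound under concurrent inserts, which is why `Iter::size_hint` reports no upper bound, whereas `iter_mut` and `into_iter` hold exclusive access and can report an exact remaining count.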