diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/errors.rs | 14 | ||||
-rw-r--r-- | src/fragile.rs | 326 | ||||
-rw-r--r-- | src/lib.rs | 157 | ||||
-rw-r--r-- | src/registry.rs | 104 | ||||
-rw-r--r-- | src/semisticky.rs | 339 | ||||
-rw-r--r-- | src/sticky.rs | 423 | ||||
-rw-r--r-- | src/thread_id.rs | 12 |
7 files changed, 1375 insertions, 0 deletions
diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..bf3fb4c --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,14 @@ +use std::error; +use std::fmt; + +/// Returned when borrowing fails. +#[derive(Debug)] +pub struct InvalidThreadAccess; + +impl fmt::Display for InvalidThreadAccess { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "fragile value accessed from foreign thread") + } +} + +impl error::Error for InvalidThreadAccess {} diff --git a/src/fragile.rs b/src/fragile.rs new file mode 100644 index 0000000..92eb3d3 --- /dev/null +++ b/src/fragile.rs @@ -0,0 +1,326 @@ +use std::cmp; +use std::fmt; +use std::mem; +use std::num::NonZeroUsize; + +use crate::errors::InvalidThreadAccess; +use crate::thread_id; +use std::mem::ManuallyDrop; + +/// A [`Fragile<T>`] wraps a non sendable `T` to be safely send to other threads. +/// +/// Once the value has been wrapped it can be sent to other threads but access +/// to the value on those threads will fail. +/// +/// If the value needs destruction and the fragile wrapper is on another thread +/// the destructor will panic. Alternatively you can use +/// [`Sticky`](crate::Sticky) which is not going to panic but might temporarily +/// leak the value. +pub struct Fragile<T> { + // ManuallyDrop is necessary because we need to move out of here without running the + // Drop code in functions like `into_inner`. + value: ManuallyDrop<T>, + thread_id: NonZeroUsize, +} + +impl<T> Fragile<T> { + /// Creates a new [`Fragile`] wrapping a `value`. + /// + /// The value that is moved into the [`Fragile`] can be non `Send` and + /// will be anchored to the thread that created the object. If the + /// fragile wrapper type ends up being send from thread to thread + /// only the original thread can interact with the value. + pub fn new(value: T) -> Self { + Fragile { + value: ManuallyDrop::new(value), + thread_id: thread_id::get(), + } + } + + /// Returns `true` if the access is valid. + /// + /// This will be `false` if the value was sent to another thread. + pub fn is_valid(&self) -> bool { + thread_id::get() == self.thread_id + } + + #[inline(always)] + fn assert_thread(&self) { + if !self.is_valid() { + panic!("trying to access wrapped value in fragile container from incorrect thread."); + } + } + + /// Consumes the `Fragile`, returning the wrapped value. + /// + /// # Panics + /// + /// Panics if called from a different thread than the one where the + /// original value was created. + pub fn into_inner(self) -> T { + self.assert_thread(); + + let mut this = ManuallyDrop::new(self); + + // SAFETY: `this` is not accessed beyond this point, and because it's in a ManuallyDrop its + // destructor is not run. + unsafe { ManuallyDrop::take(&mut this.value) } + } + + /// Consumes the `Fragile`, returning the wrapped value if successful. + /// + /// The wrapped value is returned if this is called from the same thread + /// as the one where the original value was created, otherwise the + /// [`Fragile`] is returned as `Err(self)`. + pub fn try_into_inner(self) -> Result<T, Self> { + if thread_id::get() == self.thread_id { + Ok(self.into_inner()) + } else { + Err(self) + } + } + + /// Immutably borrows the wrapped value. + /// + /// # Panics + /// + /// Panics if the calling thread is not the one that wrapped the value. + /// For a non-panicking variant, use [`try_get`](Self::try_get). + pub fn get(&self) -> &T { + self.assert_thread(); + &*self.value + } + + /// Mutably borrows the wrapped value. + /// + /// # Panics + /// + /// Panics if the calling thread is not the one that wrapped the value. + /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut). + pub fn get_mut(&mut self) -> &mut T { + self.assert_thread(); + &mut *self.value + } + + /// Tries to immutably borrow the wrapped value. + /// + /// Returns `None` if the calling thread is not the one that wrapped the value. + pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> { + if thread_id::get() == self.thread_id { + Ok(&*self.value) + } else { + Err(InvalidThreadAccess) + } + } + + /// Tries to mutably borrow the wrapped value. + /// + /// Returns `None` if the calling thread is not the one that wrapped the value. + pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> { + if thread_id::get() == self.thread_id { + Ok(&mut *self.value) + } else { + Err(InvalidThreadAccess) + } + } +} + +impl<T> Drop for Fragile<T> { + fn drop(&mut self) { + if mem::needs_drop::<T>() { + if thread_id::get() == self.thread_id { + // SAFETY: `ManuallyDrop::drop` cannot be called after this point. + unsafe { ManuallyDrop::drop(&mut self.value) }; + } else { + panic!("destructor of fragile object ran on wrong thread"); + } + } + } +} + +impl<T> From<T> for Fragile<T> { + #[inline] + fn from(t: T) -> Fragile<T> { + Fragile::new(t) + } +} + +impl<T: Clone> Clone for Fragile<T> { + #[inline] + fn clone(&self) -> Fragile<T> { + Fragile::new(self.get().clone()) + } +} + +impl<T: Default> Default for Fragile<T> { + #[inline] + fn default() -> Fragile<T> { + Fragile::new(T::default()) + } +} + +impl<T: PartialEq> PartialEq for Fragile<T> { + #[inline] + fn eq(&self, other: &Fragile<T>) -> bool { + *self.get() == *other.get() + } +} + +impl<T: Eq> Eq for Fragile<T> {} + +impl<T: PartialOrd> PartialOrd for Fragile<T> { + #[inline] + fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> { + self.get().partial_cmp(other.get()) + } + + #[inline] + fn lt(&self, other: &Fragile<T>) -> bool { + *self.get() < *other.get() + } + + #[inline] + fn le(&self, other: &Fragile<T>) -> bool { + *self.get() <= *other.get() + } + + #[inline] + fn gt(&self, other: &Fragile<T>) -> bool { + *self.get() > *other.get() + } + + #[inline] + fn ge(&self, other: &Fragile<T>) -> bool { + *self.get() >= *other.get() + } +} + +impl<T: Ord> Ord for Fragile<T> { + #[inline] + fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering { + self.get().cmp(other.get()) + } +} + +impl<T: fmt::Display> fmt::Display for Fragile<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + fmt::Display::fmt(self.get(), f) + } +} + +impl<T: fmt::Debug> fmt::Debug for Fragile<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match self.try_get() { + Ok(value) => f.debug_struct("Fragile").field("value", value).finish(), + Err(..) => { + struct InvalidPlaceholder; + impl fmt::Debug for InvalidPlaceholder { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("<invalid thread>") + } + } + + f.debug_struct("Fragile") + .field("value", &InvalidPlaceholder) + .finish() + } + } + } +} + +// this type is sync because access can only ever happy from the same thread +// that created it originally. All other threads will be able to safely +// call some basic operations on the reference and they will fail. +unsafe impl<T> Sync for Fragile<T> {} + +// The entire point of this type is to be Send +#[allow(clippy::non_send_fields_in_send_ty)] +unsafe impl<T> Send for Fragile<T> {} + +#[test] +fn test_basic() { + use std::thread; + let val = Fragile::new(true); + assert_eq!(val.to_string(), "true"); + assert_eq!(val.get(), &true); + assert!(val.try_get().is_ok()); + thread::spawn(move || { + assert!(val.try_get().is_err()); + }) + .join() + .unwrap(); +} + +#[test] +fn test_mut() { + let mut val = Fragile::new(true); + *val.get_mut() = false; + assert_eq!(val.to_string(), "false"); + assert_eq!(val.get(), &false); +} + +#[test] +#[should_panic] +fn test_access_other_thread() { + use std::thread; + let val = Fragile::new(true); + thread::spawn(move || { + val.get(); + }) + .join() + .unwrap(); +} + +#[test] +fn test_noop_drop_elsewhere() { + use std::thread; + let val = Fragile::new(true); + thread::spawn(move || { + // force the move + val.try_get().ok(); + }) + .join() + .unwrap(); +} + +#[test] +fn test_panic_on_drop_elsewhere() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + let was_called = Arc::new(AtomicBool::new(false)); + struct X(Arc<AtomicBool>); + impl Drop for X { + fn drop(&mut self) { + self.0.store(true, Ordering::SeqCst); + } + } + let val = Fragile::new(X(was_called.clone())); + assert!(thread::spawn(move || { + val.try_get().ok(); + }) + .join() + .is_err()); + assert!(!was_called.load(Ordering::SeqCst)); +} + +#[test] +fn test_rc_sending() { + use std::rc::Rc; + use std::sync::mpsc::channel; + use std::thread; + + let val = Fragile::new(Rc::new(true)); + let (tx, rx) = channel(); + + let thread = thread::spawn(move || { + assert!(val.try_get().is_err()); + let here = val; + tx.send(here).unwrap(); + }); + + let rv = rx.recv().unwrap(); + assert!(**rv.get()); + + thread.join().unwrap(); +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..16edc9d --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,157 @@ +//! This library provides wrapper types that permit sending non `Send` types to +//! other threads and use runtime checks to ensure safety. +//! +//! It provides three types: [`Fragile`] and [`Sticky`] which are similar in nature +//! but have different behaviors with regards to how destructors are executed and +//! the extra [`SemiSticky`] type which uses [`Sticky`] if the value has a +//! destructor and [`Fragile`] if it does not. +//! +//! All three types wrap a value and provide a `Send` bound. Neither of the types permit +//! access to the enclosed value unless the thread that wrapped the value is attempting +//! to access it. The difference between the types starts playing a role once +//! destructors are involved. +//! +//! A [`Fragile`] will actually send the `T` from thread to thread but will only +//! permit the original thread to invoke the destructor. If the value gets dropped +//! in a different thread, the destructor will panic. +//! +//! A [`Sticky`] on the other hand does not actually send the `T` around but keeps +//! it stored in the original thread's thread local storage. If it gets dropped +//! in the originating thread it gets cleaned up immediately, otherwise it leaks +//! until the thread shuts down naturally. [`Sticky`] because it borrows into the +//! TLS also requires you to "prove" that you are not doing any funny business with +//! the borrowed value that lives for longer than the current stack frame which +//! results in a slightly more complex API. +//! +//! There is a third typed called [`SemiSticky`] which shares the API with [`Sticky`] +//! but internally uses a boxed [`Fragile`] if the type does not actually need a dtor +//! in which case [`Fragile`] is preferred. +//! +//! # Fragile Usage +//! +//! [`Fragile`] is the easiest type to use. It works almost like a cell. +//! +//! ``` +//! use std::thread; +//! use fragile::Fragile; +//! +//! // creating and using a fragile object in the same thread works +//! let val = Fragile::new(true); +//! assert_eq!(*val.get(), true); +//! assert!(val.try_get().is_ok()); +//! +//! // once send to another thread it stops working +//! thread::spawn(move || { +//! assert!(val.try_get().is_err()); +//! }).join() +//! .unwrap(); +//! ``` +//! +//! # Sticky Usage +//! +//! [`Sticky`] is similar to [`Fragile`] but because it places the value in the +//! thread local storage it comes with some extra restrictions to make it sound. +//! The advantage is it can be dropped from any thread but it comes with extra +//! restrictions. In particular it requires that values placed in it are `'static` +//! and that [`StackToken`]s are used to restrict lifetimes. +//! +//! ``` +//! use std::thread; +//! use fragile::Sticky; +//! +//! // creating and using a fragile object in the same thread works +//! fragile::stack_token!(tok); +//! let val = Sticky::new(true); +//! assert_eq!(*val.get(tok), true); +//! assert!(val.try_get(tok).is_ok()); +//! +//! // once send to another thread it stops working +//! thread::spawn(move || { +//! fragile::stack_token!(tok); +//! assert!(val.try_get(tok).is_err()); +//! }).join() +//! .unwrap(); +//! ``` +//! +//! # Why? +//! +//! Most of the time trying to use this crate is going to indicate some code smell. But +//! there are situations where this is useful. For instance you might have a bunch of +//! non `Send` types but want to work with a `Send` error type. In that case the non +//! sendable extra information can be contained within the error and in cases where the +//! error did not cross a thread boundary yet extra information can be obtained. +//! +//! # Drop / Cleanup Behavior +//! +//! All types will try to eagerly drop a value if they are dropped on the right thread. +//! [`Sticky`] and [`SemiSticky`] will however temporarily leak memory until a thread +//! shuts down if the value is dropped on the wrong thread. The benefit however is that +//! if you have that type of situation, and you can live with the consequences, the +//! type is not panicking. A [`Fragile`] dropped in the wrong thread will not just panic, +//! it will effectively also tear down the process because panicking in destructors is +//! non recoverable. +//! +//! # Features +//! +//! By default the crate has no dependencies. Optionally the `slab` feature can +//! be enabled which optimizes the internal storage of the [`Sticky`] type to +//! make it use a [`slab`](https://docs.rs/slab/latest/slab/) instead. +mod errors; +mod fragile; +mod registry; +mod semisticky; +mod sticky; +mod thread_id; + +use std::marker::PhantomData; + +pub use crate::errors::InvalidThreadAccess; +pub use crate::fragile::Fragile; +pub use crate::semisticky::SemiSticky; +pub use crate::sticky::Sticky; + +/// A token that is placed to the stack to constrain lifetimes. +/// +/// For more information about how these work see the documentation of +/// [`stack_token!`] which is the only way to create this token. +pub struct StackToken(PhantomData<*const ()>); + +impl StackToken { + /// Stack tokens must only be created on the stack. + #[doc(hidden)] + pub unsafe fn __private_new() -> StackToken { + // we place a const pointer in there to get a type + // that is neither Send nor Sync. + StackToken(PhantomData) + } +} + +/// Crates a token on the stack with a certain name for semi-sticky. +/// +/// The argument to the macro is the target name of a local variable +/// which holds a reference to a stack token. Because this is the +/// only way to create such a token, it acts as a proof to [`Sticky`] +/// or [`SemiSticky`] that can be used to constrain the lifetime of the +/// return values to the stack frame. +/// +/// This is necessary as otherwise a [`Sticky`] placed in a [`Box`] and +/// leaked with [`Box::leak`] (which creates a static lifetime) would +/// otherwise create a reference with `'static` lifetime. This is incorrect +/// as the actual lifetime is constrained to the lifetime of the thread. +/// For more information see [`issue 26`](https://github.com/mitsuhiko/fragile/issues/26). +/// +/// ```rust +/// let sticky = fragile::Sticky::new(true); +/// +/// // this places a token on the stack. +/// fragile::stack_token!(my_token); +/// +/// // the token needs to be passed to `get` and others. +/// let _ = sticky.get(my_token); +/// ``` +#[macro_export] +macro_rules! stack_token { + ($name:ident) => { + let $name = &unsafe { $crate::StackToken::__private_new() }; + }; +} diff --git a/src/registry.rs b/src/registry.rs new file mode 100644 index 0000000..1ee070d --- /dev/null +++ b/src/registry.rs @@ -0,0 +1,104 @@ +pub struct Entry { + /// The pointer to the object stored in the registry. This is a type-erased + /// `Box<T>`. + pub ptr: *mut (), + /// The function that can be called on the above pointer to drop the object + /// and free its allocation. + pub drop: unsafe fn(*mut ()), +} + +#[cfg(feature = "slab")] +mod slab_impl { + use std::cell::UnsafeCell; + use std::num::NonZeroUsize; + + use super::Entry; + + pub struct Registry(pub slab::Slab<Entry>); + + thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(slab::Slab::new()))); + + pub use usize as ItemId; + + pub fn insert(thread_id: NonZeroUsize, entry: Entry) -> ItemId { + let _ = thread_id; + REGISTRY.with(|registry| unsafe { (*registry.get()).0.insert(entry) }) + } + + pub fn with<R, F: FnOnce(&Entry) -> R>(item_id: ItemId, thread_id: NonZeroUsize, f: F) -> R { + let _ = thread_id; + REGISTRY.with(|registry| f(unsafe { &*registry.get() }.0.get(item_id).unwrap())) + } + + pub fn remove(item_id: ItemId, thread_id: NonZeroUsize) -> Entry { + let _ = thread_id; + REGISTRY.with(|registry| unsafe { (*registry.get()).0.remove(item_id) }) + } + + pub fn try_remove(item_id: ItemId, thread_id: NonZeroUsize) -> Option<Entry> { + let _ = thread_id; + REGISTRY.with(|registry| unsafe { (*registry.get()).0.try_remove(item_id) }) + } +} + +#[cfg(not(feature = "slab"))] +mod map_impl { + use std::cell::UnsafeCell; + use std::num::NonZeroUsize; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use super::Entry; + + pub struct Registry(pub std::collections::HashMap<(NonZeroUsize, NonZeroUsize), Entry>); + + thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(Default::default()))); + + pub type ItemId = NonZeroUsize; + + fn next_item_id() -> NonZeroUsize { + static COUNTER: AtomicUsize = AtomicUsize::new(1); + NonZeroUsize::new(COUNTER.fetch_add(1, Ordering::SeqCst)) + .expect("more than usize::MAX items") + } + + pub fn insert(thread_id: NonZeroUsize, entry: Entry) -> ItemId { + let item_id = next_item_id(); + REGISTRY + .with(|registry| unsafe { (*registry.get()).0.insert((thread_id, item_id), entry) }); + item_id + } + + pub fn with<R, F: FnOnce(&Entry) -> R>(item_id: ItemId, thread_id: NonZeroUsize, f: F) -> R { + REGISTRY.with(|registry| { + f(unsafe { &*registry.get() } + .0 + .get(&(thread_id, item_id)) + .unwrap()) + }) + } + + pub fn remove(item_id: ItemId, thread_id: NonZeroUsize) -> Entry { + REGISTRY + .with(|registry| unsafe { (*registry.get()).0.remove(&(thread_id, item_id)).unwrap() }) + } + + pub fn try_remove(item_id: ItemId, thread_id: NonZeroUsize) -> Option<Entry> { + REGISTRY.with(|registry| unsafe { (*registry.get()).0.remove(&(thread_id, item_id)) }) + } +} + +#[cfg(feature = "slab")] +pub use self::slab_impl::*; + +#[cfg(not(feature = "slab"))] +pub use self::map_impl::*; + +impl Drop for Registry { + fn drop(&mut self) { + for (_, value) in self.0.iter() { + // SAFETY: This function is only called once, and is called with the + // pointer it was created with. + unsafe { (value.drop)(value.ptr) }; + } + } +} diff --git a/src/semisticky.rs b/src/semisticky.rs new file mode 100644 index 0000000..2b6c0f4 --- /dev/null +++ b/src/semisticky.rs @@ -0,0 +1,339 @@ +use std::cmp; +use std::fmt; +use std::mem; + +use crate::errors::InvalidThreadAccess; +use crate::fragile::Fragile; +use crate::sticky::Sticky; +use crate::StackToken; + +enum SemiStickyImpl<T: 'static> { + Fragile(Box<Fragile<T>>), + Sticky(Sticky<T>), +} + +/// A [`SemiSticky<T>`] keeps a value T stored in a thread if it has a drop. +/// +/// This is a combined version of [`Fragile`] and [`Sticky`]. If the type +/// does not have a drop it will effectively be a [`Fragile`], otherwise it +/// will be internally behave like a [`Sticky`]. +/// +/// This type requires `T: 'static` for the same reasons as [`Sticky`] and +/// also uses [`StackToken`]s. +pub struct SemiSticky<T: 'static> { + inner: SemiStickyImpl<T>, +} + +impl<T> SemiSticky<T> { + /// Creates a new [`SemiSticky`] wrapping a `value`. + /// + /// The value that is moved into the `SemiSticky` can be non `Send` and + /// will be anchored to the thread that created the object. If the + /// sticky wrapper type ends up being send from thread to thread + /// only the original thread can interact with the value. In case the + /// value does not have `Drop` it will be stored in the [`SemiSticky`] + /// instead. + pub fn new(value: T) -> Self { + SemiSticky { + inner: if mem::needs_drop::<T>() { + SemiStickyImpl::Sticky(Sticky::new(value)) + } else { + SemiStickyImpl::Fragile(Box::new(Fragile::new(value))) + }, + } + } + + /// Returns `true` if the access is valid. + /// + /// This will be `false` if the value was sent to another thread. + pub fn is_valid(&self) -> bool { + match self.inner { + SemiStickyImpl::Fragile(ref inner) => inner.is_valid(), + SemiStickyImpl::Sticky(ref inner) => inner.is_valid(), + } + } + + /// Consumes the [`SemiSticky`], returning the wrapped value. + /// + /// # Panics + /// + /// Panics if called from a different thread than the one where the + /// original value was created. + pub fn into_inner(self) -> T { + match self.inner { + SemiStickyImpl::Fragile(inner) => inner.into_inner(), + SemiStickyImpl::Sticky(inner) => inner.into_inner(), + } + } + + /// Consumes the [`SemiSticky`], returning the wrapped value if successful. + /// + /// The wrapped value is returned if this is called from the same thread + /// as the one where the original value was created, otherwise the + /// [`SemiSticky`] is returned as `Err(self)`. + pub fn try_into_inner(self) -> Result<T, Self> { + match self.inner { + SemiStickyImpl::Fragile(inner) => inner.try_into_inner().map_err(|inner| SemiSticky { + inner: SemiStickyImpl::Fragile(Box::new(inner)), + }), + SemiStickyImpl::Sticky(inner) => inner.try_into_inner().map_err(|inner| SemiSticky { + inner: SemiStickyImpl::Sticky(inner), + }), + } + } + + /// Immutably borrows the wrapped value. + /// + /// # Panics + /// + /// Panics if the calling thread is not the one that wrapped the value. + /// For a non-panicking variant, use [`try_get`](Self::try_get). + pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T { + match self.inner { + SemiStickyImpl::Fragile(ref inner) => inner.get(), + SemiStickyImpl::Sticky(ref inner) => inner.get(_proof), + } + } + + /// Mutably borrows the wrapped value. + /// + /// # Panics + /// + /// Panics if the calling thread is not the one that wrapped the value. + /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut). + pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T { + match self.inner { + SemiStickyImpl::Fragile(ref mut inner) => inner.get_mut(), + SemiStickyImpl::Sticky(ref mut inner) => inner.get_mut(_proof), + } + } + + /// Tries to immutably borrow the wrapped value. + /// + /// Returns `None` if the calling thread is not the one that wrapped the value. + pub fn try_get<'stack>( + &'stack self, + _proof: &'stack StackToken, + ) -> Result<&'stack T, InvalidThreadAccess> { + match self.inner { + SemiStickyImpl::Fragile(ref inner) => inner.try_get(), + SemiStickyImpl::Sticky(ref inner) => inner.try_get(_proof), + } + } + + /// Tries to mutably borrow the wrapped value. + /// + /// Returns `None` if the calling thread is not the one that wrapped the value. + pub fn try_get_mut<'stack>( + &'stack mut self, + _proof: &'stack StackToken, + ) -> Result<&'stack mut T, InvalidThreadAccess> { + match self.inner { + SemiStickyImpl::Fragile(ref mut inner) => inner.try_get_mut(), + SemiStickyImpl::Sticky(ref mut inner) => inner.try_get_mut(_proof), + } + } +} + +impl<T> From<T> for SemiSticky<T> { + #[inline] + fn from(t: T) -> SemiSticky<T> { + SemiSticky::new(t) + } +} + +impl<T: Clone> Clone for SemiSticky<T> { + #[inline] + fn clone(&self) -> SemiSticky<T> { + crate::stack_token!(tok); + SemiSticky::new(self.get(tok).clone()) + } +} + +impl<T: Default> Default for SemiSticky<T> { + #[inline] + fn default() -> SemiSticky<T> { + SemiSticky::new(T::default()) + } +} + +impl<T: PartialEq> PartialEq for SemiSticky<T> { + #[inline] + fn eq(&self, other: &SemiSticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) == *other.get(tok) + } +} + +impl<T: Eq> Eq for SemiSticky<T> {} + +impl<T: PartialOrd> PartialOrd for SemiSticky<T> { + #[inline] + fn partial_cmp(&self, other: &SemiSticky<T>) -> Option<cmp::Ordering> { + crate::stack_token!(tok); + self.get(tok).partial_cmp(other.get(tok)) + } + + #[inline] + fn lt(&self, other: &SemiSticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) < *other.get(tok) + } + + #[inline] + fn le(&self, other: &SemiSticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) <= *other.get(tok) + } + + #[inline] + fn gt(&self, other: &SemiSticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) > *other.get(tok) + } + + #[inline] + fn ge(&self, other: &SemiSticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) >= *other.get(tok) + } +} + +impl<T: Ord> Ord for SemiSticky<T> { + #[inline] + fn cmp(&self, other: &SemiSticky<T>) -> cmp::Ordering { + crate::stack_token!(tok); + self.get(tok).cmp(other.get(tok)) + } +} + +impl<T: fmt::Display> fmt::Display for SemiSticky<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + crate::stack_token!(tok); + fmt::Display::fmt(self.get(tok), f) + } +} + +impl<T: fmt::Debug> fmt::Debug for SemiSticky<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + crate::stack_token!(tok); + match self.try_get(tok) { + Ok(value) => f.debug_struct("SemiSticky").field("value", value).finish(), + Err(..) => { + struct InvalidPlaceholder; + impl fmt::Debug for InvalidPlaceholder { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("<invalid thread>") + } + } + + f.debug_struct("SemiSticky") + .field("value", &InvalidPlaceholder) + .finish() + } + } + } +} + +#[test] +fn test_basic() { + use std::thread; + let val = SemiSticky::new(true); + crate::stack_token!(tok); + assert_eq!(val.to_string(), "true"); + assert_eq!(val.get(tok), &true); + assert!(val.try_get(tok).is_ok()); + thread::spawn(move || { + crate::stack_token!(tok); + assert!(val.try_get(tok).is_err()); + }) + .join() + .unwrap(); +} + +#[test] +fn test_mut() { + let mut val = SemiSticky::new(true); + crate::stack_token!(tok); + *val.get_mut(tok) = false; + assert_eq!(val.to_string(), "false"); + assert_eq!(val.get(tok), &false); +} + +#[test] +#[should_panic] +fn test_access_other_thread() { + use std::thread; + let val = SemiSticky::new(true); + thread::spawn(move || { + crate::stack_token!(tok); + val.get(tok); + }) + .join() + .unwrap(); +} + +#[test] +fn test_drop_same_thread() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + let was_called = Arc::new(AtomicBool::new(false)); + struct X(Arc<AtomicBool>); + impl Drop for X { + fn drop(&mut self) { + self.0.store(true, Ordering::SeqCst); + } + } + let val = SemiSticky::new(X(was_called.clone())); + mem::drop(val); + assert!(was_called.load(Ordering::SeqCst)); +} + +#[test] +fn test_noop_drop_elsewhere() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + + let was_called = Arc::new(AtomicBool::new(false)); + + { + let was_called = was_called.clone(); + thread::spawn(move || { + struct X(Arc<AtomicBool>); + impl Drop for X { + fn drop(&mut self) { + self.0.store(true, Ordering::SeqCst); + } + } + + let val = SemiSticky::new(X(was_called.clone())); + assert!(thread::spawn(move || { + // moves it here but do not deallocate + crate::stack_token!(tok); + val.try_get(tok).ok(); + }) + .join() + .is_ok()); + + assert!(!was_called.load(Ordering::SeqCst)); + }) + .join() + .unwrap(); + } + + assert!(was_called.load(Ordering::SeqCst)); +} + +#[test] +fn test_rc_sending() { + use std::rc::Rc; + use std::thread; + let val = SemiSticky::new(Rc::new(true)); + thread::spawn(move || { + crate::stack_token!(tok); + assert!(val.try_get(tok).is_err()); + }) + .join() + .unwrap(); +} diff --git a/src/sticky.rs b/src/sticky.rs new file mode 100644 index 0000000..bc15c40 --- /dev/null +++ b/src/sticky.rs @@ -0,0 +1,423 @@ +#![allow(clippy::unit_arg)] + +use std::cmp; +use std::fmt; +use std::marker::PhantomData; +use std::mem; +use std::num::NonZeroUsize; + +use crate::errors::InvalidThreadAccess; +use crate::registry; +use crate::thread_id; +use crate::StackToken; + +/// A [`Sticky<T>`] keeps a value T stored in a thread. +/// +/// This type works similar in nature to [`Fragile`](crate::Fragile) and exposes a +/// similar interface. The difference is that whereas [`Fragile`](crate::Fragile) has +/// its destructor called in the thread where the value was sent, a +/// [`Sticky`] that is moved to another thread will have the internal +/// destructor called when the originating thread tears down. +/// +/// Because [`Sticky`] allows values to be kept alive for longer than the +/// [`Sticky`] itself, it requires all its contents to be `'static` for +/// soundness. More importantly it also requires the use of [`StackToken`]s. +/// For information about how to use stack tokens and why they are neded, +/// refer to [`stack_token!`](crate::stack_token). +/// +/// As this uses TLS internally the general rules about the platform limitations +/// of destructors for TLS apply. +pub struct Sticky<T: 'static> { + item_id: registry::ItemId, + thread_id: NonZeroUsize, + _marker: PhantomData<*mut T>, +} + +impl<T> Drop for Sticky<T> { + fn drop(&mut self) { + // if the type needs dropping we can only do so on the + // right thread. worst case we leak the value until the + // thread dies. + if mem::needs_drop::<T>() { + unsafe { + if self.is_valid() { + self.unsafe_take_value(); + } + } + + // otherwise we take the liberty to drop the value + // right here and now. We can however only do that if + // we are on the right thread. If we are not, we again + // need to wait for the thread to shut down. + } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) { + unsafe { + (entry.drop)(entry.ptr); + } + } + } +} + +impl<T> Sticky<T> { + /// Creates a new [`Sticky`] wrapping a `value`. + /// + /// The value that is moved into the [`Sticky`] can be non `Send` and + /// will be anchored to the thread that created the object. If the + /// sticky wrapper type ends up being send from thread to thread + /// only the original thread can interact with the value. + pub fn new(value: T) -> Self { + let entry = registry::Entry { + ptr: Box::into_raw(Box::new(value)).cast(), + drop: |ptr| { + let ptr = ptr.cast::<T>(); + // SAFETY: This callback will only be called once, with the + // above pointer. + drop(unsafe { Box::from_raw(ptr) }); + }, + }; + + let thread_id = thread_id::get(); + let item_id = registry::insert(thread_id, entry); + + Sticky { + item_id, + thread_id, + _marker: PhantomData, + } + } + + #[inline(always)] + fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R { + self.assert_thread(); + + registry::with(self.item_id, self.thread_id, |entry| { + f(entry.ptr.cast::<T>()) + }) + } + + /// Returns `true` if the access is valid. + /// + /// This will be `false` if the value was sent to another thread. + #[inline(always)] + pub fn is_valid(&self) -> bool { + thread_id::get() == self.thread_id + } + + #[inline(always)] + fn assert_thread(&self) { + if !self.is_valid() { + panic!("trying to access wrapped value in sticky container from incorrect thread."); + } + } + + /// Consumes the `Sticky`, returning the wrapped value. + /// + /// # Panics + /// + /// Panics if called from a different thread than the one where the + /// original value was created. + pub fn into_inner(mut self) -> T { + self.assert_thread(); + unsafe { + let rv = self.unsafe_take_value(); + mem::forget(self); + rv + } + } + + unsafe fn unsafe_take_value(&mut self) -> T { + let ptr = registry::remove(self.item_id, self.thread_id) + .ptr + .cast::<T>(); + *Box::from_raw(ptr) + } + + /// Consumes the `Sticky`, returning the wrapped value if successful. + /// + /// The wrapped value is returned if this is called from the same thread + /// as the one where the original value was created, otherwise the + /// `Sticky` is returned as `Err(self)`. + pub fn try_into_inner(self) -> Result<T, Self> { + if self.is_valid() { + Ok(self.into_inner()) + } else { + Err(self) + } + } + + /// Immutably borrows the wrapped value. + /// + /// # Panics + /// + /// Panics if the calling thread is not the one that wrapped the value. + /// For a non-panicking variant, use [`try_get`](#method.try_get`). + pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T { + self.with_value(|value| unsafe { &*value }) + } + + /// Mutably borrows the wrapped value. + /// + /// # Panics + /// + /// Panics if the calling thread is not the one that wrapped the value. + /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`). + pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T { + self.with_value(|value| unsafe { &mut *value }) + } + + /// Tries to immutably borrow the wrapped value. + /// + /// Returns `None` if the calling thread is not the one that wrapped the value. + pub fn try_get<'stack>( + &'stack self, + _proof: &'stack StackToken, + ) -> Result<&'stack T, InvalidThreadAccess> { + if self.is_valid() { + Ok(self.with_value(|value| unsafe { &*value })) + } else { + Err(InvalidThreadAccess) + } + } + + /// Tries to mutably borrow the wrapped value. + /// + /// Returns `None` if the calling thread is not the one that wrapped the value. + pub fn try_get_mut<'stack>( + &'stack mut self, + _proof: &'stack StackToken, + ) -> Result<&'stack mut T, InvalidThreadAccess> { + if self.is_valid() { + Ok(self.with_value(|value| unsafe { &mut *value })) + } else { + Err(InvalidThreadAccess) + } + } +} + +impl<T> From<T> for Sticky<T> { + #[inline] + fn from(t: T) -> Sticky<T> { + Sticky::new(t) + } +} + +impl<T: Clone> Clone for Sticky<T> { + #[inline] + fn clone(&self) -> Sticky<T> { + crate::stack_token!(tok); + Sticky::new(self.get(tok).clone()) + } +} + +impl<T: Default> Default for Sticky<T> { + #[inline] + fn default() -> Sticky<T> { + Sticky::new(T::default()) + } +} + +impl<T: PartialEq> PartialEq for Sticky<T> { + #[inline] + fn eq(&self, other: &Sticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) == *other.get(tok) + } +} + +impl<T: Eq> Eq for Sticky<T> {} + +impl<T: PartialOrd> PartialOrd for Sticky<T> { + #[inline] + fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> { + crate::stack_token!(tok); + self.get(tok).partial_cmp(other.get(tok)) + } + + #[inline] + fn lt(&self, other: &Sticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) < *other.get(tok) + } + + #[inline] + fn le(&self, other: &Sticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) <= *other.get(tok) + } + + #[inline] + fn gt(&self, other: &Sticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) > *other.get(tok) + } + + #[inline] + fn ge(&self, other: &Sticky<T>) -> bool { + crate::stack_token!(tok); + *self.get(tok) >= *other.get(tok) + } +} + +impl<T: Ord> Ord for Sticky<T> { + #[inline] + fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering { + crate::stack_token!(tok); + self.get(tok).cmp(other.get(tok)) + } +} + +impl<T: fmt::Display> fmt::Display for Sticky<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + crate::stack_token!(tok); + fmt::Display::fmt(self.get(tok), f) + } +} + +impl<T: fmt::Debug> fmt::Debug for Sticky<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + crate::stack_token!(tok); + match self.try_get(tok) { + Ok(value) => f.debug_struct("Sticky").field("value", value).finish(), + Err(..) => { + struct InvalidPlaceholder; + impl fmt::Debug for InvalidPlaceholder { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("<invalid thread>") + } + } + + f.debug_struct("Sticky") + .field("value", &InvalidPlaceholder) + .finish() + } + } + } +} + +// similar as for fragile ths type is sync because it only accesses TLS data +// which is thread local. There is nothing that needs to be synchronized. +unsafe impl<T> Sync for Sticky<T> {} + +// The entire point of this type is to be Send +unsafe impl<T> Send for Sticky<T> {} + +#[test] +fn test_basic() { + use std::thread; + let val = Sticky::new(true); + crate::stack_token!(tok); + assert_eq!(val.to_string(), "true"); + assert_eq!(val.get(tok), &true); + assert!(val.try_get(tok).is_ok()); + thread::spawn(move || { + crate::stack_token!(tok); + assert!(val.try_get(tok).is_err()); + }) + .join() + .unwrap(); +} + +#[test] +fn test_mut() { + let mut val = Sticky::new(true); + crate::stack_token!(tok); + *val.get_mut(tok) = false; + assert_eq!(val.to_string(), "false"); + assert_eq!(val.get(tok), &false); +} + +#[test] +#[should_panic] +fn test_access_other_thread() { + use std::thread; + let val = Sticky::new(true); + thread::spawn(move || { + crate::stack_token!(tok); + val.get(tok); + }) + .join() + .unwrap(); +} + +#[test] +fn test_drop_same_thread() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + let was_called = Arc::new(AtomicBool::new(false)); + struct X(Arc<AtomicBool>); + impl Drop for X { + fn drop(&mut self) { + self.0.store(true, Ordering::SeqCst); + } + } + let val = Sticky::new(X(was_called.clone())); + mem::drop(val); + assert!(was_called.load(Ordering::SeqCst)); +} + +#[test] +fn test_noop_drop_elsewhere() { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + + let was_called = Arc::new(AtomicBool::new(false)); + + { + let was_called = was_called.clone(); + thread::spawn(move || { + struct X(Arc<AtomicBool>); + impl Drop for X { + fn drop(&mut self) { + self.0.store(true, Ordering::SeqCst); + } + } + + let val = Sticky::new(X(was_called.clone())); + assert!(thread::spawn(move || { + // moves it here but do not deallocate + crate::stack_token!(tok); + val.try_get(tok).ok(); + }) + .join() + .is_ok()); + + assert!(!was_called.load(Ordering::SeqCst)); + }) + .join() + .unwrap(); + } + + assert!(was_called.load(Ordering::SeqCst)); +} + +#[test] +fn test_rc_sending() { + use std::rc::Rc; + use std::thread; + let val = Sticky::new(Rc::new(true)); + thread::spawn(move || { + crate::stack_token!(tok); + assert!(val.try_get(tok).is_err()); + }) + .join() + .unwrap(); +} + +#[test] +fn test_two_stickies() { + struct Wat; + + impl Drop for Wat { + fn drop(&mut self) { + // do nothing + } + } + + let s1 = Sticky::new(Wat); + let s2 = Sticky::new(Wat); + + // make sure all is well + + drop(s1); + drop(s2); +} diff --git a/src/thread_id.rs b/src/thread_id.rs new file mode 100644 index 0000000..00468b2 --- /dev/null +++ b/src/thread_id.rs @@ -0,0 +1,12 @@ +use std::num::NonZeroUsize; +use std::sync::atomic::{AtomicUsize, Ordering}; + +fn next() -> NonZeroUsize { + static COUNTER: AtomicUsize = AtomicUsize::new(1); + NonZeroUsize::new(COUNTER.fetch_add(1, Ordering::SeqCst)).expect("more than usize::MAX threads") +} + +pub(crate) fn get() -> NonZeroUsize { + thread_local!(static THREAD_ID: NonZeroUsize = next()); + THREAD_ID.with(|&x| x) +} |