summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/errors.rs14
-rw-r--r--src/fragile.rs326
-rw-r--r--src/lib.rs157
-rw-r--r--src/registry.rs104
-rw-r--r--src/semisticky.rs339
-rw-r--r--src/sticky.rs423
-rw-r--r--src/thread_id.rs12
7 files changed, 1375 insertions, 0 deletions
diff --git a/src/errors.rs b/src/errors.rs
new file mode 100644
index 0000000..bf3fb4c
--- /dev/null
+++ b/src/errors.rs
@@ -0,0 +1,14 @@
+use std::error;
+use std::fmt;
+
+/// Returned when borrowing fails.
+#[derive(Debug)]
+pub struct InvalidThreadAccess;
+
+impl fmt::Display for InvalidThreadAccess {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "fragile value accessed from foreign thread")
+ }
+}
+
+impl error::Error for InvalidThreadAccess {}
diff --git a/src/fragile.rs b/src/fragile.rs
new file mode 100644
index 0000000..92eb3d3
--- /dev/null
+++ b/src/fragile.rs
@@ -0,0 +1,326 @@
+use std::cmp;
+use std::fmt;
+use std::mem;
+use std::num::NonZeroUsize;
+
+use crate::errors::InvalidThreadAccess;
+use crate::thread_id;
+use std::mem::ManuallyDrop;
+
+/// A [`Fragile<T>`] wraps a non sendable `T` to be safely send to other threads.
+///
+/// Once the value has been wrapped it can be sent to other threads but access
+/// to the value on those threads will fail.
+///
+/// If the value needs destruction and the fragile wrapper is on another thread
+/// the destructor will panic. Alternatively you can use
+/// [`Sticky`](crate::Sticky) which is not going to panic but might temporarily
+/// leak the value.
+pub struct Fragile<T> {
+ // ManuallyDrop is necessary because we need to move out of here without running the
+ // Drop code in functions like `into_inner`.
+ value: ManuallyDrop<T>,
+ thread_id: NonZeroUsize,
+}
+
+impl<T> Fragile<T> {
+ /// Creates a new [`Fragile`] wrapping a `value`.
+ ///
+ /// The value that is moved into the [`Fragile`] can be non `Send` and
+ /// will be anchored to the thread that created the object. If the
+ /// fragile wrapper type ends up being send from thread to thread
+ /// only the original thread can interact with the value.
+ pub fn new(value: T) -> Self {
+ Fragile {
+ value: ManuallyDrop::new(value),
+ thread_id: thread_id::get(),
+ }
+ }
+
+ /// Returns `true` if the access is valid.
+ ///
+ /// This will be `false` if the value was sent to another thread.
+ pub fn is_valid(&self) -> bool {
+ thread_id::get() == self.thread_id
+ }
+
+ #[inline(always)]
+ fn assert_thread(&self) {
+ if !self.is_valid() {
+ panic!("trying to access wrapped value in fragile container from incorrect thread.");
+ }
+ }
+
+ /// Consumes the `Fragile`, returning the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if called from a different thread than the one where the
+ /// original value was created.
+ pub fn into_inner(self) -> T {
+ self.assert_thread();
+
+ let mut this = ManuallyDrop::new(self);
+
+ // SAFETY: `this` is not accessed beyond this point, and because it's in a ManuallyDrop its
+ // destructor is not run.
+ unsafe { ManuallyDrop::take(&mut this.value) }
+ }
+
+ /// Consumes the `Fragile`, returning the wrapped value if successful.
+ ///
+ /// The wrapped value is returned if this is called from the same thread
+ /// as the one where the original value was created, otherwise the
+ /// [`Fragile`] is returned as `Err(self)`.
+ pub fn try_into_inner(self) -> Result<T, Self> {
+ if thread_id::get() == self.thread_id {
+ Ok(self.into_inner())
+ } else {
+ Err(self)
+ }
+ }
+
+ /// Immutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the calling thread is not the one that wrapped the value.
+ /// For a non-panicking variant, use [`try_get`](Self::try_get).
+ pub fn get(&self) -> &T {
+ self.assert_thread();
+ &*self.value
+ }
+
+ /// Mutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the calling thread is not the one that wrapped the value.
+ /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut).
+ pub fn get_mut(&mut self) -> &mut T {
+ self.assert_thread();
+ &mut *self.value
+ }
+
+ /// Tries to immutably borrow the wrapped value.
+ ///
+ /// Returns `None` if the calling thread is not the one that wrapped the value.
+ pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
+ if thread_id::get() == self.thread_id {
+ Ok(&*self.value)
+ } else {
+ Err(InvalidThreadAccess)
+ }
+ }
+
+ /// Tries to mutably borrow the wrapped value.
+ ///
+ /// Returns `None` if the calling thread is not the one that wrapped the value.
+ pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
+ if thread_id::get() == self.thread_id {
+ Ok(&mut *self.value)
+ } else {
+ Err(InvalidThreadAccess)
+ }
+ }
+}
+
+impl<T> Drop for Fragile<T> {
+ fn drop(&mut self) {
+ if mem::needs_drop::<T>() {
+ if thread_id::get() == self.thread_id {
+ // SAFETY: `ManuallyDrop::drop` cannot be called after this point.
+ unsafe { ManuallyDrop::drop(&mut self.value) };
+ } else {
+ panic!("destructor of fragile object ran on wrong thread");
+ }
+ }
+ }
+}
+
+impl<T> From<T> for Fragile<T> {
+ #[inline]
+ fn from(t: T) -> Fragile<T> {
+ Fragile::new(t)
+ }
+}
+
+impl<T: Clone> Clone for Fragile<T> {
+ #[inline]
+ fn clone(&self) -> Fragile<T> {
+ Fragile::new(self.get().clone())
+ }
+}
+
+impl<T: Default> Default for Fragile<T> {
+ #[inline]
+ fn default() -> Fragile<T> {
+ Fragile::new(T::default())
+ }
+}
+
+impl<T: PartialEq> PartialEq for Fragile<T> {
+ #[inline]
+ fn eq(&self, other: &Fragile<T>) -> bool {
+ *self.get() == *other.get()
+ }
+}
+
+impl<T: Eq> Eq for Fragile<T> {}
+
+impl<T: PartialOrd> PartialOrd for Fragile<T> {
+ #[inline]
+ fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
+ self.get().partial_cmp(other.get())
+ }
+
+ #[inline]
+ fn lt(&self, other: &Fragile<T>) -> bool {
+ *self.get() < *other.get()
+ }
+
+ #[inline]
+ fn le(&self, other: &Fragile<T>) -> bool {
+ *self.get() <= *other.get()
+ }
+
+ #[inline]
+ fn gt(&self, other: &Fragile<T>) -> bool {
+ *self.get() > *other.get()
+ }
+
+ #[inline]
+ fn ge(&self, other: &Fragile<T>) -> bool {
+ *self.get() >= *other.get()
+ }
+}
+
+impl<T: Ord> Ord for Fragile<T> {
+ #[inline]
+ fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
+ self.get().cmp(other.get())
+ }
+}
+
+impl<T: fmt::Display> fmt::Display for Fragile<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+ fmt::Display::fmt(self.get(), f)
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+ match self.try_get() {
+ Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
+ Err(..) => {
+ struct InvalidPlaceholder;
+ impl fmt::Debug for InvalidPlaceholder {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("<invalid thread>")
+ }
+ }
+
+ f.debug_struct("Fragile")
+ .field("value", &InvalidPlaceholder)
+ .finish()
+ }
+ }
+ }
+}
+
+// this type is sync because access can only ever happy from the same thread
+// that created it originally. All other threads will be able to safely
+// call some basic operations on the reference and they will fail.
+unsafe impl<T> Sync for Fragile<T> {}
+
+// The entire point of this type is to be Send
+#[allow(clippy::non_send_fields_in_send_ty)]
+unsafe impl<T> Send for Fragile<T> {}
+
+#[test]
+fn test_basic() {
+ use std::thread;
+ let val = Fragile::new(true);
+ assert_eq!(val.to_string(), "true");
+ assert_eq!(val.get(), &true);
+ assert!(val.try_get().is_ok());
+ thread::spawn(move || {
+ assert!(val.try_get().is_err());
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_mut() {
+ let mut val = Fragile::new(true);
+ *val.get_mut() = false;
+ assert_eq!(val.to_string(), "false");
+ assert_eq!(val.get(), &false);
+}
+
+#[test]
+#[should_panic]
+fn test_access_other_thread() {
+ use std::thread;
+ let val = Fragile::new(true);
+ thread::spawn(move || {
+ val.get();
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_noop_drop_elsewhere() {
+ use std::thread;
+ let val = Fragile::new(true);
+ thread::spawn(move || {
+ // force the move
+ val.try_get().ok();
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_panic_on_drop_elsewhere() {
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::Arc;
+ use std::thread;
+ let was_called = Arc::new(AtomicBool::new(false));
+ struct X(Arc<AtomicBool>);
+ impl Drop for X {
+ fn drop(&mut self) {
+ self.0.store(true, Ordering::SeqCst);
+ }
+ }
+ let val = Fragile::new(X(was_called.clone()));
+ assert!(thread::spawn(move || {
+ val.try_get().ok();
+ })
+ .join()
+ .is_err());
+ assert!(!was_called.load(Ordering::SeqCst));
+}
+
+#[test]
+fn test_rc_sending() {
+ use std::rc::Rc;
+ use std::sync::mpsc::channel;
+ use std::thread;
+
+ let val = Fragile::new(Rc::new(true));
+ let (tx, rx) = channel();
+
+ let thread = thread::spawn(move || {
+ assert!(val.try_get().is_err());
+ let here = val;
+ tx.send(here).unwrap();
+ });
+
+ let rv = rx.recv().unwrap();
+ assert!(**rv.get());
+
+ thread.join().unwrap();
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..16edc9d
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,157 @@
+//! This library provides wrapper types that permit sending non `Send` types to
+//! other threads and use runtime checks to ensure safety.
+//!
+//! It provides three types: [`Fragile`] and [`Sticky`] which are similar in nature
+//! but have different behaviors with regards to how destructors are executed and
+//! the extra [`SemiSticky`] type which uses [`Sticky`] if the value has a
+//! destructor and [`Fragile`] if it does not.
+//!
+//! All three types wrap a value and provide a `Send` bound. Neither of the types permit
+//! access to the enclosed value unless the thread that wrapped the value is attempting
+//! to access it. The difference between the types starts playing a role once
+//! destructors are involved.
+//!
+//! A [`Fragile`] will actually send the `T` from thread to thread but will only
+//! permit the original thread to invoke the destructor. If the value gets dropped
+//! in a different thread, the destructor will panic.
+//!
+//! A [`Sticky`] on the other hand does not actually send the `T` around but keeps
+//! it stored in the original thread's thread local storage. If it gets dropped
+//! in the originating thread it gets cleaned up immediately, otherwise it leaks
+//! until the thread shuts down naturally. [`Sticky`] because it borrows into the
+//! TLS also requires you to "prove" that you are not doing any funny business with
+//! the borrowed value that lives for longer than the current stack frame which
+//! results in a slightly more complex API.
+//!
+//! There is a third typed called [`SemiSticky`] which shares the API with [`Sticky`]
+//! but internally uses a boxed [`Fragile`] if the type does not actually need a dtor
+//! in which case [`Fragile`] is preferred.
+//!
+//! # Fragile Usage
+//!
+//! [`Fragile`] is the easiest type to use. It works almost like a cell.
+//!
+//! ```
+//! use std::thread;
+//! use fragile::Fragile;
+//!
+//! // creating and using a fragile object in the same thread works
+//! let val = Fragile::new(true);
+//! assert_eq!(*val.get(), true);
+//! assert!(val.try_get().is_ok());
+//!
+//! // once send to another thread it stops working
+//! thread::spawn(move || {
+//! assert!(val.try_get().is_err());
+//! }).join()
+//! .unwrap();
+//! ```
+//!
+//! # Sticky Usage
+//!
+//! [`Sticky`] is similar to [`Fragile`] but because it places the value in the
+//! thread local storage it comes with some extra restrictions to make it sound.
+//! The advantage is it can be dropped from any thread but it comes with extra
+//! restrictions. In particular it requires that values placed in it are `'static`
+//! and that [`StackToken`]s are used to restrict lifetimes.
+//!
+//! ```
+//! use std::thread;
+//! use fragile::Sticky;
+//!
+//! // creating and using a fragile object in the same thread works
+//! fragile::stack_token!(tok);
+//! let val = Sticky::new(true);
+//! assert_eq!(*val.get(tok), true);
+//! assert!(val.try_get(tok).is_ok());
+//!
+//! // once send to another thread it stops working
+//! thread::spawn(move || {
+//! fragile::stack_token!(tok);
+//! assert!(val.try_get(tok).is_err());
+//! }).join()
+//! .unwrap();
+//! ```
+//!
+//! # Why?
+//!
+//! Most of the time trying to use this crate is going to indicate some code smell. But
+//! there are situations where this is useful. For instance you might have a bunch of
+//! non `Send` types but want to work with a `Send` error type. In that case the non
+//! sendable extra information can be contained within the error and in cases where the
+//! error did not cross a thread boundary yet extra information can be obtained.
+//!
+//! # Drop / Cleanup Behavior
+//!
+//! All types will try to eagerly drop a value if they are dropped on the right thread.
+//! [`Sticky`] and [`SemiSticky`] will however temporarily leak memory until a thread
+//! shuts down if the value is dropped on the wrong thread. The benefit however is that
+//! if you have that type of situation, and you can live with the consequences, the
+//! type is not panicking. A [`Fragile`] dropped in the wrong thread will not just panic,
+//! it will effectively also tear down the process because panicking in destructors is
+//! non recoverable.
+//!
+//! # Features
+//!
+//! By default the crate has no dependencies. Optionally the `slab` feature can
+//! be enabled which optimizes the internal storage of the [`Sticky`] type to
+//! make it use a [`slab`](https://docs.rs/slab/latest/slab/) instead.
+mod errors;
+mod fragile;
+mod registry;
+mod semisticky;
+mod sticky;
+mod thread_id;
+
+use std::marker::PhantomData;
+
+pub use crate::errors::InvalidThreadAccess;
+pub use crate::fragile::Fragile;
+pub use crate::semisticky::SemiSticky;
+pub use crate::sticky::Sticky;
+
+/// A token that is placed to the stack to constrain lifetimes.
+///
+/// For more information about how these work see the documentation of
+/// [`stack_token!`] which is the only way to create this token.
+pub struct StackToken(PhantomData<*const ()>);
+
+impl StackToken {
+ /// Stack tokens must only be created on the stack.
+ #[doc(hidden)]
+ pub unsafe fn __private_new() -> StackToken {
+ // we place a const pointer in there to get a type
+ // that is neither Send nor Sync.
+ StackToken(PhantomData)
+ }
+}
+
+/// Crates a token on the stack with a certain name for semi-sticky.
+///
+/// The argument to the macro is the target name of a local variable
+/// which holds a reference to a stack token. Because this is the
+/// only way to create such a token, it acts as a proof to [`Sticky`]
+/// or [`SemiSticky`] that can be used to constrain the lifetime of the
+/// return values to the stack frame.
+///
+/// This is necessary as otherwise a [`Sticky`] placed in a [`Box`] and
+/// leaked with [`Box::leak`] (which creates a static lifetime) would
+/// otherwise create a reference with `'static` lifetime. This is incorrect
+/// as the actual lifetime is constrained to the lifetime of the thread.
+/// For more information see [`issue 26`](https://github.com/mitsuhiko/fragile/issues/26).
+///
+/// ```rust
+/// let sticky = fragile::Sticky::new(true);
+///
+/// // this places a token on the stack.
+/// fragile::stack_token!(my_token);
+///
+/// // the token needs to be passed to `get` and others.
+/// let _ = sticky.get(my_token);
+/// ```
+#[macro_export]
+macro_rules! stack_token {
+ ($name:ident) => {
+ let $name = &unsafe { $crate::StackToken::__private_new() };
+ };
+}
diff --git a/src/registry.rs b/src/registry.rs
new file mode 100644
index 0000000..1ee070d
--- /dev/null
+++ b/src/registry.rs
@@ -0,0 +1,104 @@
+pub struct Entry {
+ /// The pointer to the object stored in the registry. This is a type-erased
+ /// `Box<T>`.
+ pub ptr: *mut (),
+ /// The function that can be called on the above pointer to drop the object
+ /// and free its allocation.
+ pub drop: unsafe fn(*mut ()),
+}
+
+#[cfg(feature = "slab")]
+mod slab_impl {
+ use std::cell::UnsafeCell;
+ use std::num::NonZeroUsize;
+
+ use super::Entry;
+
+ pub struct Registry(pub slab::Slab<Entry>);
+
+ thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(slab::Slab::new())));
+
+ pub use usize as ItemId;
+
+ pub fn insert(thread_id: NonZeroUsize, entry: Entry) -> ItemId {
+ let _ = thread_id;
+ REGISTRY.with(|registry| unsafe { (*registry.get()).0.insert(entry) })
+ }
+
+ pub fn with<R, F: FnOnce(&Entry) -> R>(item_id: ItemId, thread_id: NonZeroUsize, f: F) -> R {
+ let _ = thread_id;
+ REGISTRY.with(|registry| f(unsafe { &*registry.get() }.0.get(item_id).unwrap()))
+ }
+
+ pub fn remove(item_id: ItemId, thread_id: NonZeroUsize) -> Entry {
+ let _ = thread_id;
+ REGISTRY.with(|registry| unsafe { (*registry.get()).0.remove(item_id) })
+ }
+
+ pub fn try_remove(item_id: ItemId, thread_id: NonZeroUsize) -> Option<Entry> {
+ let _ = thread_id;
+ REGISTRY.with(|registry| unsafe { (*registry.get()).0.try_remove(item_id) })
+ }
+}
+
+#[cfg(not(feature = "slab"))]
+mod map_impl {
+ use std::cell::UnsafeCell;
+ use std::num::NonZeroUsize;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+
+ use super::Entry;
+
+ pub struct Registry(pub std::collections::HashMap<(NonZeroUsize, NonZeroUsize), Entry>);
+
+ thread_local!(static REGISTRY: UnsafeCell<Registry> = UnsafeCell::new(Registry(Default::default())));
+
+ pub type ItemId = NonZeroUsize;
+
+ fn next_item_id() -> NonZeroUsize {
+ static COUNTER: AtomicUsize = AtomicUsize::new(1);
+ NonZeroUsize::new(COUNTER.fetch_add(1, Ordering::SeqCst))
+ .expect("more than usize::MAX items")
+ }
+
+ pub fn insert(thread_id: NonZeroUsize, entry: Entry) -> ItemId {
+ let item_id = next_item_id();
+ REGISTRY
+ .with(|registry| unsafe { (*registry.get()).0.insert((thread_id, item_id), entry) });
+ item_id
+ }
+
+ pub fn with<R, F: FnOnce(&Entry) -> R>(item_id: ItemId, thread_id: NonZeroUsize, f: F) -> R {
+ REGISTRY.with(|registry| {
+ f(unsafe { &*registry.get() }
+ .0
+ .get(&(thread_id, item_id))
+ .unwrap())
+ })
+ }
+
+ pub fn remove(item_id: ItemId, thread_id: NonZeroUsize) -> Entry {
+ REGISTRY
+ .with(|registry| unsafe { (*registry.get()).0.remove(&(thread_id, item_id)).unwrap() })
+ }
+
+ pub fn try_remove(item_id: ItemId, thread_id: NonZeroUsize) -> Option<Entry> {
+ REGISTRY.with(|registry| unsafe { (*registry.get()).0.remove(&(thread_id, item_id)) })
+ }
+}
+
+#[cfg(feature = "slab")]
+pub use self::slab_impl::*;
+
+#[cfg(not(feature = "slab"))]
+pub use self::map_impl::*;
+
+impl Drop for Registry {
+ fn drop(&mut self) {
+ for (_, value) in self.0.iter() {
+ // SAFETY: This function is only called once, and is called with the
+ // pointer it was created with.
+ unsafe { (value.drop)(value.ptr) };
+ }
+ }
+}
diff --git a/src/semisticky.rs b/src/semisticky.rs
new file mode 100644
index 0000000..2b6c0f4
--- /dev/null
+++ b/src/semisticky.rs
@@ -0,0 +1,339 @@
+use std::cmp;
+use std::fmt;
+use std::mem;
+
+use crate::errors::InvalidThreadAccess;
+use crate::fragile::Fragile;
+use crate::sticky::Sticky;
+use crate::StackToken;
+
+enum SemiStickyImpl<T: 'static> {
+ Fragile(Box<Fragile<T>>),
+ Sticky(Sticky<T>),
+}
+
+/// A [`SemiSticky<T>`] keeps a value T stored in a thread if it has a drop.
+///
+/// This is a combined version of [`Fragile`] and [`Sticky`]. If the type
+/// does not have a drop it will effectively be a [`Fragile`], otherwise it
+/// will be internally behave like a [`Sticky`].
+///
+/// This type requires `T: 'static` for the same reasons as [`Sticky`] and
+/// also uses [`StackToken`]s.
+pub struct SemiSticky<T: 'static> {
+ inner: SemiStickyImpl<T>,
+}
+
+impl<T> SemiSticky<T> {
+ /// Creates a new [`SemiSticky`] wrapping a `value`.
+ ///
+ /// The value that is moved into the `SemiSticky` can be non `Send` and
+ /// will be anchored to the thread that created the object. If the
+ /// sticky wrapper type ends up being send from thread to thread
+ /// only the original thread can interact with the value. In case the
+ /// value does not have `Drop` it will be stored in the [`SemiSticky`]
+ /// instead.
+ pub fn new(value: T) -> Self {
+ SemiSticky {
+ inner: if mem::needs_drop::<T>() {
+ SemiStickyImpl::Sticky(Sticky::new(value))
+ } else {
+ SemiStickyImpl::Fragile(Box::new(Fragile::new(value)))
+ },
+ }
+ }
+
+ /// Returns `true` if the access is valid.
+ ///
+ /// This will be `false` if the value was sent to another thread.
+ pub fn is_valid(&self) -> bool {
+ match self.inner {
+ SemiStickyImpl::Fragile(ref inner) => inner.is_valid(),
+ SemiStickyImpl::Sticky(ref inner) => inner.is_valid(),
+ }
+ }
+
+ /// Consumes the [`SemiSticky`], returning the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if called from a different thread than the one where the
+ /// original value was created.
+ pub fn into_inner(self) -> T {
+ match self.inner {
+ SemiStickyImpl::Fragile(inner) => inner.into_inner(),
+ SemiStickyImpl::Sticky(inner) => inner.into_inner(),
+ }
+ }
+
+ /// Consumes the [`SemiSticky`], returning the wrapped value if successful.
+ ///
+ /// The wrapped value is returned if this is called from the same thread
+ /// as the one where the original value was created, otherwise the
+ /// [`SemiSticky`] is returned as `Err(self)`.
+ pub fn try_into_inner(self) -> Result<T, Self> {
+ match self.inner {
+ SemiStickyImpl::Fragile(inner) => inner.try_into_inner().map_err(|inner| SemiSticky {
+ inner: SemiStickyImpl::Fragile(Box::new(inner)),
+ }),
+ SemiStickyImpl::Sticky(inner) => inner.try_into_inner().map_err(|inner| SemiSticky {
+ inner: SemiStickyImpl::Sticky(inner),
+ }),
+ }
+ }
+
+ /// Immutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the calling thread is not the one that wrapped the value.
+ /// For a non-panicking variant, use [`try_get`](Self::try_get).
+ pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
+ match self.inner {
+ SemiStickyImpl::Fragile(ref inner) => inner.get(),
+ SemiStickyImpl::Sticky(ref inner) => inner.get(_proof),
+ }
+ }
+
+ /// Mutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the calling thread is not the one that wrapped the value.
+ /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut).
+ pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
+ match self.inner {
+ SemiStickyImpl::Fragile(ref mut inner) => inner.get_mut(),
+ SemiStickyImpl::Sticky(ref mut inner) => inner.get_mut(_proof),
+ }
+ }
+
+ /// Tries to immutably borrow the wrapped value.
+ ///
+ /// Returns `None` if the calling thread is not the one that wrapped the value.
+ pub fn try_get<'stack>(
+ &'stack self,
+ _proof: &'stack StackToken,
+ ) -> Result<&'stack T, InvalidThreadAccess> {
+ match self.inner {
+ SemiStickyImpl::Fragile(ref inner) => inner.try_get(),
+ SemiStickyImpl::Sticky(ref inner) => inner.try_get(_proof),
+ }
+ }
+
+ /// Tries to mutably borrow the wrapped value.
+ ///
+ /// Returns `None` if the calling thread is not the one that wrapped the value.
+ pub fn try_get_mut<'stack>(
+ &'stack mut self,
+ _proof: &'stack StackToken,
+ ) -> Result<&'stack mut T, InvalidThreadAccess> {
+ match self.inner {
+ SemiStickyImpl::Fragile(ref mut inner) => inner.try_get_mut(),
+ SemiStickyImpl::Sticky(ref mut inner) => inner.try_get_mut(_proof),
+ }
+ }
+}
+
+impl<T> From<T> for SemiSticky<T> {
+ #[inline]
+ fn from(t: T) -> SemiSticky<T> {
+ SemiSticky::new(t)
+ }
+}
+
+impl<T: Clone> Clone for SemiSticky<T> {
+ #[inline]
+ fn clone(&self) -> SemiSticky<T> {
+ crate::stack_token!(tok);
+ SemiSticky::new(self.get(tok).clone())
+ }
+}
+
+impl<T: Default> Default for SemiSticky<T> {
+ #[inline]
+ fn default() -> SemiSticky<T> {
+ SemiSticky::new(T::default())
+ }
+}
+
+impl<T: PartialEq> PartialEq for SemiSticky<T> {
+ #[inline]
+ fn eq(&self, other: &SemiSticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) == *other.get(tok)
+ }
+}
+
+impl<T: Eq> Eq for SemiSticky<T> {}
+
+impl<T: PartialOrd> PartialOrd for SemiSticky<T> {
+ #[inline]
+ fn partial_cmp(&self, other: &SemiSticky<T>) -> Option<cmp::Ordering> {
+ crate::stack_token!(tok);
+ self.get(tok).partial_cmp(other.get(tok))
+ }
+
+ #[inline]
+ fn lt(&self, other: &SemiSticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) < *other.get(tok)
+ }
+
+ #[inline]
+ fn le(&self, other: &SemiSticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) <= *other.get(tok)
+ }
+
+ #[inline]
+ fn gt(&self, other: &SemiSticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) > *other.get(tok)
+ }
+
+ #[inline]
+ fn ge(&self, other: &SemiSticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) >= *other.get(tok)
+ }
+}
+
+impl<T: Ord> Ord for SemiSticky<T> {
+ #[inline]
+ fn cmp(&self, other: &SemiSticky<T>) -> cmp::Ordering {
+ crate::stack_token!(tok);
+ self.get(tok).cmp(other.get(tok))
+ }
+}
+
+impl<T: fmt::Display> fmt::Display for SemiSticky<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+ crate::stack_token!(tok);
+ fmt::Display::fmt(self.get(tok), f)
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for SemiSticky<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+ crate::stack_token!(tok);
+ match self.try_get(tok) {
+ Ok(value) => f.debug_struct("SemiSticky").field("value", value).finish(),
+ Err(..) => {
+ struct InvalidPlaceholder;
+ impl fmt::Debug for InvalidPlaceholder {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("<invalid thread>")
+ }
+ }
+
+ f.debug_struct("SemiSticky")
+ .field("value", &InvalidPlaceholder)
+ .finish()
+ }
+ }
+ }
+}
+
+#[test]
+fn test_basic() {
+ use std::thread;
+ let val = SemiSticky::new(true);
+ crate::stack_token!(tok);
+ assert_eq!(val.to_string(), "true");
+ assert_eq!(val.get(tok), &true);
+ assert!(val.try_get(tok).is_ok());
+ thread::spawn(move || {
+ crate::stack_token!(tok);
+ assert!(val.try_get(tok).is_err());
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_mut() {
+ let mut val = SemiSticky::new(true);
+ crate::stack_token!(tok);
+ *val.get_mut(tok) = false;
+ assert_eq!(val.to_string(), "false");
+ assert_eq!(val.get(tok), &false);
+}
+
+#[test]
+#[should_panic]
+fn test_access_other_thread() {
+ use std::thread;
+ let val = SemiSticky::new(true);
+ thread::spawn(move || {
+ crate::stack_token!(tok);
+ val.get(tok);
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_drop_same_thread() {
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::Arc;
+ let was_called = Arc::new(AtomicBool::new(false));
+ struct X(Arc<AtomicBool>);
+ impl Drop for X {
+ fn drop(&mut self) {
+ self.0.store(true, Ordering::SeqCst);
+ }
+ }
+ let val = SemiSticky::new(X(was_called.clone()));
+ mem::drop(val);
+ assert!(was_called.load(Ordering::SeqCst));
+}
+
+#[test]
+fn test_noop_drop_elsewhere() {
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::Arc;
+ use std::thread;
+
+ let was_called = Arc::new(AtomicBool::new(false));
+
+ {
+ let was_called = was_called.clone();
+ thread::spawn(move || {
+ struct X(Arc<AtomicBool>);
+ impl Drop for X {
+ fn drop(&mut self) {
+ self.0.store(true, Ordering::SeqCst);
+ }
+ }
+
+ let val = SemiSticky::new(X(was_called.clone()));
+ assert!(thread::spawn(move || {
+ // moves it here but do not deallocate
+ crate::stack_token!(tok);
+ val.try_get(tok).ok();
+ })
+ .join()
+ .is_ok());
+
+ assert!(!was_called.load(Ordering::SeqCst));
+ })
+ .join()
+ .unwrap();
+ }
+
+ assert!(was_called.load(Ordering::SeqCst));
+}
+
+#[test]
+fn test_rc_sending() {
+ use std::rc::Rc;
+ use std::thread;
+ let val = SemiSticky::new(Rc::new(true));
+ thread::spawn(move || {
+ crate::stack_token!(tok);
+ assert!(val.try_get(tok).is_err());
+ })
+ .join()
+ .unwrap();
+}
diff --git a/src/sticky.rs b/src/sticky.rs
new file mode 100644
index 0000000..bc15c40
--- /dev/null
+++ b/src/sticky.rs
@@ -0,0 +1,423 @@
+#![allow(clippy::unit_arg)]
+
+use std::cmp;
+use std::fmt;
+use std::marker::PhantomData;
+use std::mem;
+use std::num::NonZeroUsize;
+
+use crate::errors::InvalidThreadAccess;
+use crate::registry;
+use crate::thread_id;
+use crate::StackToken;
+
+/// A [`Sticky<T>`] keeps a value T stored in a thread.
+///
+/// This type works similar in nature to [`Fragile`](crate::Fragile) and exposes a
+/// similar interface. The difference is that whereas [`Fragile`](crate::Fragile) has
+/// its destructor called in the thread where the value was sent, a
+/// [`Sticky`] that is moved to another thread will have the internal
+/// destructor called when the originating thread tears down.
+///
+/// Because [`Sticky`] allows values to be kept alive for longer than the
+/// [`Sticky`] itself, it requires all its contents to be `'static` for
+/// soundness. More importantly it also requires the use of [`StackToken`]s.
+/// For information about how to use stack tokens and why they are neded,
+/// refer to [`stack_token!`](crate::stack_token).
+///
+/// As this uses TLS internally the general rules about the platform limitations
+/// of destructors for TLS apply.
+pub struct Sticky<T: 'static> {
+ item_id: registry::ItemId,
+ thread_id: NonZeroUsize,
+ _marker: PhantomData<*mut T>,
+}
+
+impl<T> Drop for Sticky<T> {
+ fn drop(&mut self) {
+ // if the type needs dropping we can only do so on the
+ // right thread. worst case we leak the value until the
+ // thread dies.
+ if mem::needs_drop::<T>() {
+ unsafe {
+ if self.is_valid() {
+ self.unsafe_take_value();
+ }
+ }
+
+ // otherwise we take the liberty to drop the value
+ // right here and now. We can however only do that if
+ // we are on the right thread. If we are not, we again
+ // need to wait for the thread to shut down.
+ } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) {
+ unsafe {
+ (entry.drop)(entry.ptr);
+ }
+ }
+ }
+}
+
+impl<T> Sticky<T> {
+ /// Creates a new [`Sticky`] wrapping a `value`.
+ ///
+ /// The value that is moved into the [`Sticky`] can be non `Send` and
+ /// will be anchored to the thread that created the object. If the
+ /// sticky wrapper type ends up being send from thread to thread
+ /// only the original thread can interact with the value.
+ pub fn new(value: T) -> Self {
+ let entry = registry::Entry {
+ ptr: Box::into_raw(Box::new(value)).cast(),
+ drop: |ptr| {
+ let ptr = ptr.cast::<T>();
+ // SAFETY: This callback will only be called once, with the
+ // above pointer.
+ drop(unsafe { Box::from_raw(ptr) });
+ },
+ };
+
+ let thread_id = thread_id::get();
+ let item_id = registry::insert(thread_id, entry);
+
+ Sticky {
+ item_id,
+ thread_id,
+ _marker: PhantomData,
+ }
+ }
+
+ #[inline(always)]
+ fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R {
+ self.assert_thread();
+
+ registry::with(self.item_id, self.thread_id, |entry| {
+ f(entry.ptr.cast::<T>())
+ })
+ }
+
+ /// Returns `true` if the access is valid.
+ ///
+ /// This will be `false` if the value was sent to another thread.
+ #[inline(always)]
+ pub fn is_valid(&self) -> bool {
+ thread_id::get() == self.thread_id
+ }
+
+ #[inline(always)]
+ fn assert_thread(&self) {
+ if !self.is_valid() {
+ panic!("trying to access wrapped value in sticky container from incorrect thread.");
+ }
+ }
+
+ /// Consumes the `Sticky`, returning the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if called from a different thread than the one where the
+ /// original value was created.
+ pub fn into_inner(mut self) -> T {
+ self.assert_thread();
+ unsafe {
+ let rv = self.unsafe_take_value();
+ mem::forget(self);
+ rv
+ }
+ }
+
+ unsafe fn unsafe_take_value(&mut self) -> T {
+ let ptr = registry::remove(self.item_id, self.thread_id)
+ .ptr
+ .cast::<T>();
+ *Box::from_raw(ptr)
+ }
+
+ /// Consumes the `Sticky`, returning the wrapped value if successful.
+ ///
+ /// The wrapped value is returned if this is called from the same thread
+ /// as the one where the original value was created, otherwise the
+ /// `Sticky` is returned as `Err(self)`.
+ pub fn try_into_inner(self) -> Result<T, Self> {
+ if self.is_valid() {
+ Ok(self.into_inner())
+ } else {
+ Err(self)
+ }
+ }
+
+ /// Immutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the calling thread is not the one that wrapped the value.
+ /// For a non-panicking variant, use [`try_get`](#method.try_get`).
+ pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
+ self.with_value(|value| unsafe { &*value })
+ }
+
+ /// Mutably borrows the wrapped value.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the calling thread is not the one that wrapped the value.
+ /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
+ pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
+ self.with_value(|value| unsafe { &mut *value })
+ }
+
+ /// Tries to immutably borrow the wrapped value.
+ ///
+ /// Returns `None` if the calling thread is not the one that wrapped the value.
+ pub fn try_get<'stack>(
+ &'stack self,
+ _proof: &'stack StackToken,
+ ) -> Result<&'stack T, InvalidThreadAccess> {
+ if self.is_valid() {
+ Ok(self.with_value(|value| unsafe { &*value }))
+ } else {
+ Err(InvalidThreadAccess)
+ }
+ }
+
+ /// Tries to mutably borrow the wrapped value.
+ ///
+ /// Returns `None` if the calling thread is not the one that wrapped the value.
+ pub fn try_get_mut<'stack>(
+ &'stack mut self,
+ _proof: &'stack StackToken,
+ ) -> Result<&'stack mut T, InvalidThreadAccess> {
+ if self.is_valid() {
+ Ok(self.with_value(|value| unsafe { &mut *value }))
+ } else {
+ Err(InvalidThreadAccess)
+ }
+ }
+}
+
+impl<T> From<T> for Sticky<T> {
+ #[inline]
+ fn from(t: T) -> Sticky<T> {
+ Sticky::new(t)
+ }
+}
+
+impl<T: Clone> Clone for Sticky<T> {
+ #[inline]
+ fn clone(&self) -> Sticky<T> {
+ crate::stack_token!(tok);
+ Sticky::new(self.get(tok).clone())
+ }
+}
+
+impl<T: Default> Default for Sticky<T> {
+ #[inline]
+ fn default() -> Sticky<T> {
+ Sticky::new(T::default())
+ }
+}
+
+impl<T: PartialEq> PartialEq for Sticky<T> {
+ #[inline]
+ fn eq(&self, other: &Sticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) == *other.get(tok)
+ }
+}
+
+impl<T: Eq> Eq for Sticky<T> {}
+
+impl<T: PartialOrd> PartialOrd for Sticky<T> {
+ #[inline]
+ fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
+ crate::stack_token!(tok);
+ self.get(tok).partial_cmp(other.get(tok))
+ }
+
+ #[inline]
+ fn lt(&self, other: &Sticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) < *other.get(tok)
+ }
+
+ #[inline]
+ fn le(&self, other: &Sticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) <= *other.get(tok)
+ }
+
+ #[inline]
+ fn gt(&self, other: &Sticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) > *other.get(tok)
+ }
+
+ #[inline]
+ fn ge(&self, other: &Sticky<T>) -> bool {
+ crate::stack_token!(tok);
+ *self.get(tok) >= *other.get(tok)
+ }
+}
+
+impl<T: Ord> Ord for Sticky<T> {
+ #[inline]
+ fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
+ crate::stack_token!(tok);
+ self.get(tok).cmp(other.get(tok))
+ }
+}
+
+impl<T: fmt::Display> fmt::Display for Sticky<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+ crate::stack_token!(tok);
+ fmt::Display::fmt(self.get(tok), f)
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+ crate::stack_token!(tok);
+ match self.try_get(tok) {
+ Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
+ Err(..) => {
+ struct InvalidPlaceholder;
+ impl fmt::Debug for InvalidPlaceholder {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("<invalid thread>")
+ }
+ }
+
+ f.debug_struct("Sticky")
+ .field("value", &InvalidPlaceholder)
+ .finish()
+ }
+ }
+ }
+}
+
+// similar as for fragile ths type is sync because it only accesses TLS data
+// which is thread local. There is nothing that needs to be synchronized.
+unsafe impl<T> Sync for Sticky<T> {}
+
+// The entire point of this type is to be Send
+unsafe impl<T> Send for Sticky<T> {}
+
+#[test]
+fn test_basic() {
+ use std::thread;
+ let val = Sticky::new(true);
+ crate::stack_token!(tok);
+ assert_eq!(val.to_string(), "true");
+ assert_eq!(val.get(tok), &true);
+ assert!(val.try_get(tok).is_ok());
+ thread::spawn(move || {
+ crate::stack_token!(tok);
+ assert!(val.try_get(tok).is_err());
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_mut() {
+ let mut val = Sticky::new(true);
+ crate::stack_token!(tok);
+ *val.get_mut(tok) = false;
+ assert_eq!(val.to_string(), "false");
+ assert_eq!(val.get(tok), &false);
+}
+
+#[test]
+#[should_panic]
+fn test_access_other_thread() {
+ use std::thread;
+ let val = Sticky::new(true);
+ thread::spawn(move || {
+ crate::stack_token!(tok);
+ val.get(tok);
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_drop_same_thread() {
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::Arc;
+ let was_called = Arc::new(AtomicBool::new(false));
+ struct X(Arc<AtomicBool>);
+ impl Drop for X {
+ fn drop(&mut self) {
+ self.0.store(true, Ordering::SeqCst);
+ }
+ }
+ let val = Sticky::new(X(was_called.clone()));
+ mem::drop(val);
+ assert!(was_called.load(Ordering::SeqCst));
+}
+
+#[test]
+fn test_noop_drop_elsewhere() {
+ use std::sync::atomic::{AtomicBool, Ordering};
+ use std::sync::Arc;
+ use std::thread;
+
+ let was_called = Arc::new(AtomicBool::new(false));
+
+ {
+ let was_called = was_called.clone();
+ thread::spawn(move || {
+ struct X(Arc<AtomicBool>);
+ impl Drop for X {
+ fn drop(&mut self) {
+ self.0.store(true, Ordering::SeqCst);
+ }
+ }
+
+ let val = Sticky::new(X(was_called.clone()));
+ assert!(thread::spawn(move || {
+ // moves it here but do not deallocate
+ crate::stack_token!(tok);
+ val.try_get(tok).ok();
+ })
+ .join()
+ .is_ok());
+
+ assert!(!was_called.load(Ordering::SeqCst));
+ })
+ .join()
+ .unwrap();
+ }
+
+ assert!(was_called.load(Ordering::SeqCst));
+}
+
+#[test]
+fn test_rc_sending() {
+ use std::rc::Rc;
+ use std::thread;
+ let val = Sticky::new(Rc::new(true));
+ thread::spawn(move || {
+ crate::stack_token!(tok);
+ assert!(val.try_get(tok).is_err());
+ })
+ .join()
+ .unwrap();
+}
+
+#[test]
+fn test_two_stickies() {
+ struct Wat;
+
+ impl Drop for Wat {
+ fn drop(&mut self) {
+ // do nothing
+ }
+ }
+
+ let s1 = Sticky::new(Wat);
+ let s2 = Sticky::new(Wat);
+
+ // make sure all is well
+
+ drop(s1);
+ drop(s2);
+}
diff --git a/src/thread_id.rs b/src/thread_id.rs
new file mode 100644
index 0000000..00468b2
--- /dev/null
+++ b/src/thread_id.rs
@@ -0,0 +1,12 @@
+use std::num::NonZeroUsize;
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+fn next() -> NonZeroUsize {
+ static COUNTER: AtomicUsize = AtomicUsize::new(1);
+ NonZeroUsize::new(COUNTER.fetch_add(1, Ordering::SeqCst)).expect("more than usize::MAX threads")
+}
+
+pub(crate) fn get() -> NonZeroUsize {
+ thread_local!(static THREAD_ID: NonZeroUsize = next());
+ THREAD_ID.with(|&x| x)
+}