diff options
Diffstat (limited to 'src/task.rs')
-rw-r--r-- | src/task.rs | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/src/task.rs b/src/task.rs new file mode 100644 index 0000000..fa98bae --- /dev/null +++ b/src/task.rs @@ -0,0 +1,253 @@ +//! Futures task based helpers + +#![allow(clippy::mutex_atomic)] + +use std::future::Future; +use std::mem; +use std::ops; +use std::pin::Pin; +use std::sync::{Arc, Condvar, Mutex}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; + +use tokio_stream::Stream; + +/// TODO: dox +pub fn spawn<T>(task: T) -> Spawn<T> { + Spawn { + task: MockTask::new(), + future: Box::pin(task), + } +} + +/// Future spawned on a mock task +#[derive(Debug)] +pub struct Spawn<T> { + task: MockTask, + future: Pin<Box<T>>, +} + +/// Mock task +/// +/// A mock task is able to intercept and track wake notifications. +#[derive(Debug, Clone)] +struct MockTask { + waker: Arc<ThreadWaker>, +} + +#[derive(Debug)] +struct ThreadWaker { + state: Mutex<usize>, + condvar: Condvar, +} + +const IDLE: usize = 0; +const WAKE: usize = 1; +const SLEEP: usize = 2; + +impl<T> Spawn<T> { + /// Consumes `self` returning the inner value + pub fn into_inner(self) -> T + where + T: Unpin, + { + *Pin::into_inner(self.future) + } + + /// Returns `true` if the inner future has received a wake notification + /// since the last call to `enter`. + pub fn is_woken(&self) -> bool { + self.task.is_woken() + } + + /// Returns the number of references to the task waker + /// + /// The task itself holds a reference. The return value will never be zero. + pub fn waker_ref_count(&self) -> usize { + self.task.waker_ref_count() + } + + /// Enter the task context + pub fn enter<F, R>(&mut self, f: F) -> R + where + F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R, + { + let fut = self.future.as_mut(); + self.task.enter(|cx| f(cx, fut)) + } +} + +impl<T: Unpin> ops::Deref for Spawn<T> { + type Target = T; + + fn deref(&self) -> &T { + &self.future + } +} + +impl<T: Unpin> ops::DerefMut for Spawn<T> { + fn deref_mut(&mut self) -> &mut T { + &mut self.future + } +} + +impl<T: Future> Spawn<T> { + /// Polls a future + pub fn poll(&mut self) -> Poll<T::Output> { + let fut = self.future.as_mut(); + self.task.enter(|cx| fut.poll(cx)) + } +} + +impl<T: Stream> Spawn<T> { + /// Polls a stream + pub fn poll_next(&mut self) -> Poll<Option<T::Item>> { + let stream = self.future.as_mut(); + self.task.enter(|cx| stream.poll_next(cx)) + } +} + +impl<T: Future> Future for Spawn<T> { + type Output = T::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.future.as_mut().poll(cx) + } +} + +impl<T: Stream> Stream for Spawn<T> { + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.future.as_mut().poll_next(cx) + } +} + +impl MockTask { + /// Creates new mock task + fn new() -> Self { + MockTask { + waker: Arc::new(ThreadWaker::new()), + } + } + + /// Runs a closure from the context of the task. + /// + /// Any wake notifications resulting from the execution of the closure are + /// tracked. + fn enter<F, R>(&mut self, f: F) -> R + where + F: FnOnce(&mut Context<'_>) -> R, + { + self.waker.clear(); + let waker = self.waker(); + let mut cx = Context::from_waker(&waker); + + f(&mut cx) + } + + /// Returns `true` if the inner future has received a wake notification + /// since the last call to `enter`. + fn is_woken(&self) -> bool { + self.waker.is_woken() + } + + /// Returns the number of references to the task waker + /// + /// The task itself holds a reference. The return value will never be zero. + fn waker_ref_count(&self) -> usize { + Arc::strong_count(&self.waker) + } + + fn waker(&self) -> Waker { + unsafe { + let raw = to_raw(self.waker.clone()); + Waker::from_raw(raw) + } + } +} + +impl Default for MockTask { + fn default() -> Self { + Self::new() + } +} + +impl ThreadWaker { + fn new() -> Self { + ThreadWaker { + state: Mutex::new(IDLE), + condvar: Condvar::new(), + } + } + + /// Clears any previously received wakes, avoiding potential spurrious + /// wake notifications. This should only be called immediately before running the + /// task. + fn clear(&self) { + *self.state.lock().unwrap() = IDLE; + } + + fn is_woken(&self) -> bool { + match *self.state.lock().unwrap() { + IDLE => false, + WAKE => true, + _ => unreachable!(), + } + } + + fn wake(&self) { + // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. + let mut state = self.state.lock().unwrap(); + let prev = *state; + + if prev == WAKE { + return; + } + + *state = WAKE; + + if prev == IDLE { + return; + } + + // The other half is sleeping, so we wake it up. + assert_eq!(prev, SLEEP); + self.condvar.notify_one(); + } +} + +static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker); + +unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker { + RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE) +} + +unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> { + Arc::from_raw(raw as *const ThreadWaker) +} + +unsafe fn clone(raw: *const ()) -> RawWaker { + let waker = from_raw(raw); + + // Increment the ref count + mem::forget(waker.clone()); + + to_raw(waker) +} + +unsafe fn wake(raw: *const ()) { + let waker = from_raw(raw); + waker.wake(); +} + +unsafe fn wake_by_ref(raw: *const ()) { + let waker = from_raw(raw); + waker.wake(); + + // We don't actually own a reference to the unparker + mem::forget(waker); +} + +unsafe fn drop_waker(raw: *const ()) { + let _ = from_raw(raw); +} |