Diffstat (limited to 'src/task.rs')
-rw-r--r--  src/task.rs  253
1 file changed, 253 insertions, 0 deletions
diff --git a/src/task.rs b/src/task.rs
new file mode 100644
index 0000000..fa98bae
--- /dev/null
+++ b/src/task.rs
@@ -0,0 +1,253 @@
+//! Futures task-based helpers
+
+#![allow(clippy::mutex_atomic)]
+
+use std::future::Future;
+use std::mem;
+use std::ops;
+use std::pin::Pin;
+use std::sync::{Arc, Condvar, Mutex};
+use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
+
+use tokio_stream::Stream;
+
+/// Spawns the given future (or stream) on a mock task, allowing it to be
+/// polled manually while tracking wake notifications.
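+///
+/// # Examples
+///
+/// A minimal usage sketch, assuming this module is reachable as `task` from
+/// the crate root:
+///
+/// ```ignore
+/// use std::task::Poll;
+///
+/// // The `task::spawn` path is assumed here.
+/// let mut fut = task::spawn(async { 42 });
+/// assert_eq!(fut.poll(), Poll::Ready(42));
+/// ```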
+pub fn spawn<T>(task: T) -> Spawn<T> {
+ Spawn {
+ task: MockTask::new(),
+ future: Box::pin(task),
+ }
+}
+
+/// Future or stream spawned on a mock task
+#[derive(Debug)]
+pub struct Spawn<T> {
+ task: MockTask,
+ future: Pin<Box<T>>,
+}
+
+/// Mock task
+///
+/// A mock task is able to intercept and track wake notifications.
+#[derive(Debug, Clone)]
+struct MockTask {
+ waker: Arc<ThreadWaker>,
+}
+
+#[derive(Debug)]
+struct ThreadWaker {
+ state: Mutex<usize>,
+ condvar: Condvar,
+}
+
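+// States for `ThreadWaker::state`: IDLE means no wake has been received since
+// the last `clear`, WAKE means a wake notification arrived, and SLEEP marks a
+// thread blocked on `condvar` waiting for a wake (presumably set by a blocking
+// helper that is not part of this change).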
+const IDLE: usize = 0;
+const WAKE: usize = 1;
+const SLEEP: usize = 2;
+
+impl<T> Spawn<T> {
+ /// Consumes `self`, returning the inner value
+ pub fn into_inner(self) -> T
+ where
+ T: Unpin,
+ {
+ *Pin::into_inner(self.future)
+ }
+
+ /// Returns `true` if the inner future has received a wake notification
+ /// since the last call to `enter`.
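+ ///
+ /// A sketch of how a wake is observed (the `task` module path is assumed):
+ ///
+ /// ```ignore
+ /// use std::future::poll_fn;
+ /// use std::task::Poll;
+ ///
+ /// // A future that requests a re-poll the first time it is polled.
+ /// let mut first = true;
+ /// let mut spawn = task::spawn(poll_fn(move |cx| {
+ ///     if first {
+ ///         first = false;
+ ///         cx.waker().wake_by_ref();
+ ///         Poll::Pending
+ ///     } else {
+ ///         Poll::Ready(())
+ ///     }
+ /// }));
+ ///
+ /// assert!(spawn.poll().is_pending());
+ /// assert!(spawn.is_woken()); // the mock task intercepted the wake
+ /// ```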
+ pub fn is_woken(&self) -> bool {
+ self.task.is_woken()
+ }
+
+ /// Returns the number of references to the task waker
+ ///
+ /// The task itself holds a reference. The return value will never be zero.
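+ ///
+ /// For example (the `task` module path is assumed):
+ ///
+ /// ```ignore
+ /// let spawn = task::spawn(async {});
+ /// // Only the mock task's own `Arc` reference exists before polling.
+ /// assert_eq!(spawn.waker_ref_count(), 1);
+ /// ```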
+ pub fn waker_ref_count(&self) -> usize {
+ self.task.waker_ref_count()
+ }
+
+ /// Enter the task context
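+ ///
+ /// A sketch driving the inner future by hand (the `task` module path is
+ /// assumed):
+ ///
+ /// ```ignore
+ /// let mut spawn = task::spawn(async { "done" });
+ /// let res = spawn.enter(|cx, fut| fut.poll(cx));
+ /// assert!(res.is_ready());
+ /// ```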
+ pub fn enter<F, R>(&mut self, f: F) -> R
+ where
+ F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
+ {
+ let fut = self.future.as_mut();
+ self.task.enter(|cx| f(cx, fut))
+ }
+}
+
+impl<T: Unpin> ops::Deref for Spawn<T> {
+ type Target = T;
+
+ fn deref(&self) -> &T {
+ &self.future
+ }
+}
+
+impl<T: Unpin> ops::DerefMut for Spawn<T> {
+ fn deref_mut(&mut self) -> &mut T {
+ &mut self.future
+ }
+}
+
+impl<T: Future> Spawn<T> {
+ /// Polls a future
+ pub fn poll(&mut self) -> Poll<T::Output> {
+ let fut = self.future.as_mut();
+ self.task.enter(|cx| fut.poll(cx))
+ }
+}
+
+impl<T: Stream> Spawn<T> {
+ /// Polls a stream
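+ ///
+ /// For example (the `task` module path is assumed):
+ ///
+ /// ```ignore
+ /// use std::task::Poll;
+ ///
+ /// let mut spawn = task::spawn(tokio_stream::iter(vec![1, 2]));
+ /// assert_eq!(spawn.poll_next(), Poll::Ready(Some(1)));
+ /// assert_eq!(spawn.poll_next(), Poll::Ready(Some(2)));
+ /// assert_eq!(spawn.poll_next(), Poll::Ready(None));
+ /// ```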
+ pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
+ let stream = self.future.as_mut();
+ self.task.enter(|cx| stream.poll_next(cx))
+ }
+}
+
+impl<T: Future> Future for Spawn<T> {
+ type Output = T::Output;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ self.future.as_mut().poll(cx)
+ }
+}
+
+impl<T: Stream> Stream for Spawn<T> {
+ type Item = T::Item;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.future.as_mut().poll_next(cx)
+ }
+}
+
+impl MockTask {
+ /// Creates a new mock task
+ fn new() -> Self {
+ MockTask {
+ waker: Arc::new(ThreadWaker::new()),
+ }
+ }
+
+ /// Runs a closure from the context of the task.
+ ///
+ /// Any wake notifications resulting from the execution of the closure are
+ /// tracked.
+ fn enter<F, R>(&mut self, f: F) -> R
+ where
+ F: FnOnce(&mut Context<'_>) -> R,
+ {
+ self.waker.clear();
+ let waker = self.waker();
+ let mut cx = Context::from_waker(&waker);
+
+ f(&mut cx)
+ }
+
+ /// Returns `true` if the inner future has received a wake notification
+ /// since the last call to `enter`.
+ fn is_woken(&self) -> bool {
+ self.waker.is_woken()
+ }
+
+ /// Returns the number of references to the task waker
+ ///
+ /// The task itself holds a reference. The return value will never be zero.
+ fn waker_ref_count(&self) -> usize {
+ Arc::strong_count(&self.waker)
+ }
+
+ fn waker(&self) -> Waker {
+ unsafe {
+ let raw = to_raw(self.waker.clone());
+ Waker::from_raw(raw)
+ }
+ }
+}
+
+impl Default for MockTask {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl ThreadWaker {
+ fn new() -> Self {
+ ThreadWaker {
+ state: Mutex::new(IDLE),
+ condvar: Condvar::new(),
+ }
+ }
+
+ /// Clears any previously received wakes, avoiding potential spurious
+ /// wake notifications. This should only be called immediately before
+ /// running the task.
+ fn clear(&self) {
+ *self.state.lock().unwrap() = IDLE;
+ }
+
+ fn is_woken(&self) -> bool {
+ match *self.state.lock().unwrap() {
+ IDLE => false,
+ WAKE => true,
+ _ => unreachable!(),
+ }
+ }
+
+ fn wake(&self) {
+ // Transition the state to WAKE, recording the previous state so we know
+ // whether a sleeping thread needs to be notified.
+ let mut state = self.state.lock().unwrap();
+ let prev = *state;
+
+ if prev == WAKE {
+ return;
+ }
+
+ *state = WAKE;
+
+ if prev == IDLE {
+ return;
+ }
+
+ // The other half is sleeping, so we wake it up.
+ assert_eq!(prev, SLEEP);
+ self.condvar.notify_one();
+ }
+}
+
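+// The raw waker's data pointer is an `Arc<ThreadWaker>` converted with
+// `Arc::into_raw`. Each vtable function below re-materializes that `Arc`:
+// `clone` bumps the reference count, `wake` consumes one reference,
+// `wake_by_ref` borrows without consuming, and `drop_waker` releases one
+// reference.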
+static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
+
+unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
+ RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
+}
+
+unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
+ Arc::from_raw(raw as *const ThreadWaker)
+}
+
+unsafe fn clone(raw: *const ()) -> RawWaker {
+ let waker = from_raw(raw);
+
+ // Increment the ref count
+ mem::forget(waker.clone());
+
+ to_raw(waker)
+}
+
+unsafe fn wake(raw: *const ()) {
+ let waker = from_raw(raw);
+ waker.wake();
+}
+
+unsafe fn wake_by_ref(raw: *const ()) {
+ let waker = from_raw(raw);
+ waker.wake();
+
+ // We don't actually own a reference to the `Arc<ThreadWaker>`; forget it
+ // so the reference count is left unchanged.
+ mem::forget(waker);
+}
+
+unsafe fn drop_waker(raw: *const ()) {
+ let _ = from_raw(raw);
+}