diff options
Diffstat (limited to 'src/task/join_map.rs')
-rw-r--r-- | src/task/join_map.rs | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/src/task/join_map.rs b/src/task/join_map.rs index c6bf5bc..1fbe274 100644 --- a/src/task/join_map.rs +++ b/src/task/join_map.rs @@ -5,6 +5,7 @@ use std::collections::hash_map::RandomState; use std::fmt; use std::future::Future; use std::hash::{BuildHasher, Hash, Hasher}; +use std::marker::PhantomData; use tokio::runtime::Handle; use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; @@ -316,6 +317,60 @@ where self.insert(key, task); } + /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided + /// key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// Note that blocking tasks cannot be cancelled after execution starts. + /// Replaced blocking tasks will still run to completion if the task has begun + /// to execute when it is replaced. A blocking task which is replaced before + /// it has been scheduled on a blocking worker thread will be cancelled. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_blocking<F>(&mut self, key: K, f: F) + where + F: FnOnce() -> V, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_blocking(f); + self.insert(key, task) + } + + /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// Note that blocking tasks cannot be cancelled after execution starts. + /// Replaced blocking tasks will still run to completion if the task has begun + /// to execute when it is replaced. A blocking task which is replaced before + /// it has been scheduled on a blocking worker thread will be cancelled. + /// + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) + where + F: FnOnce() -> V, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_blocking_on(f, handle); + self.insert(key, task); + } + /// Spawn the provided task on the current [`LocalSet`] and store it in this /// `JoinMap` with the provided key. /// @@ -572,6 +627,19 @@ where } } + /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order. + /// + /// If a task has completed, but its output hasn't yet been consumed by a + /// call to [`join_next`], this method will still return its key. + /// + /// [`join_next`]: fn@Self::join_next + pub fn keys(&self) -> JoinMapKeys<'_, K, V> { + JoinMapKeys { + iter: self.tasks_by_key.keys(), + _value: PhantomData, + } + } + /// Returns `true` if this `JoinMap` contains a task for the provided key. /// /// If the task has completed, but its output hasn't yet been consumed by a @@ -805,3 +873,32 @@ impl<K: PartialEq> PartialEq for Key<K> { } impl<K: Eq> Eq for Key<K> {} + +/// An iterator over the keys of a [`JoinMap`]. +#[derive(Debug, Clone)] +pub struct JoinMapKeys<'a, K, V> { + iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>, + /// To make it easier to change JoinMap in the future, keep V as a generic + /// parameter. + _value: PhantomData<&'a V>, +} + +impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> { + type Item = &'a K; + + fn next(&mut self) -> Option<&'a K> { + self.iter.next().map(|key| &key.key) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> { + fn len(&self) -> usize { + self.iter.len() + } +} + +impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {} |