Diffstat (limited to 'src/task/spawn_pinned.rs')
-rw-r--r-- | src/task/spawn_pinned.rs | 436
1 file changed, 436 insertions, 0 deletions
diff --git a/src/task/spawn_pinned.rs b/src/task/spawn_pinned.rs
new file mode 100644
index 0000000..b4102ec
--- /dev/null
+++ b/src/task/spawn_pinned.rs
@@ -0,0 +1,436 @@
+use futures_util::future::{AbortHandle, Abortable};
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::future::Future;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::runtime::Builder;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::sync::oneshot;
+use tokio::task::{spawn_local, JoinHandle, LocalSet};
+
+/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
+///
+/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
+/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
+/// execute on the same thread) inside the future you supply to the various spawn methods
+/// of `LocalPoolHandle`.
+///
+/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
+/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
+///
+/// # Examples
+///
+/// ```
+/// use std::rc::Rc;
+/// use tokio::{self, task};
+/// use tokio_util::task::LocalPoolHandle;
+///
+/// #[tokio::main(flavor = "current_thread")]
+/// async fn main() {
+///     let pool = LocalPoolHandle::new(5);
+///
+///     let output = pool.spawn_pinned(|| {
+///         // `data` is !Send + !Sync
+///         let data = Rc::new("local data");
+///         let data_clone = data.clone();
+///
+///         async move {
+///             task::spawn_local(async move {
+///                 println!("{}", data_clone);
+///             });
+///
+///             data.to_string()
+///         }
+///     }).await.unwrap();
+///     println!("output: {}", output);
+/// }
+/// ```
+#[derive(Clone)]
+pub struct LocalPoolHandle {
+    pool: Arc<LocalPool>,
+}
+
+impl LocalPoolHandle {
+    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
+    /// pool via [`LocalPoolHandle::spawn_pinned`].
+    ///
+    /// # Panics
+    ///
+    /// Panics if the pool size is less than one.
+    #[track_caller]
+    pub fn new(pool_size: usize) -> LocalPoolHandle {
+        assert!(pool_size > 0);
+
+        let workers = (0..pool_size)
+            .map(|_| LocalWorkerHandle::new_worker())
+            .collect();
+
+        let pool = Arc::new(LocalPool { workers });
+
+        LocalPoolHandle { pool }
+    }
+
+    /// Returns the number of worker threads in the pool.
+    #[inline]
+    pub fn num_threads(&self) -> usize {
+        self.pool.workers.len()
+    }
+
+    /// Returns the number of tasks scheduled on each worker. The indices of the
+    /// worker threads correspond to the indices of the returned `Vec`.
+    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
+        self.pool
+            .workers
+            .iter()
+            .map(|worker| worker.task_count.load(Ordering::SeqCst))
+            .collect::<Vec<_>>()
+    }
+
+    /// Spawn a task onto a worker thread and pin it there so it can't be moved
+    /// off of the thread. Note that the future is not [`Send`], but the
+    /// [`FnOnce`] which creates it is.
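+    ///
+    /// The task is spawned onto the worker with the fewest tasks currently
+    /// scheduled; use [`LocalPoolHandle::spawn_pinned_by_idx`] to target a
+    /// specific worker instead.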
+    ///
+    /// # Examples
+    /// ```
+    /// use std::rc::Rc;
+    /// use tokio_util::task::LocalPoolHandle;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     // Create the local pool
+    ///     let pool = LocalPoolHandle::new(1);
+    ///
+    ///     // Spawn a !Send future onto the pool and await it
+    ///     let output = pool
+    ///         .spawn_pinned(|| {
+    ///             // Rc is !Send + !Sync
+    ///             let local_data = Rc::new("test");
+    ///
+    ///             // This future holds an Rc, so it is !Send
+    ///             async move { local_data.to_string() }
+    ///         })
+    ///         .await
+    ///         .unwrap();
+    ///
+    ///     assert_eq!(output, "test");
+    /// }
+    /// ```
+    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+    where
+        F: FnOnce() -> Fut,
+        F: Send + 'static,
+        Fut: Future + 'static,
+        Fut::Output: Send + 'static,
+    {
+        self.pool
+            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
+    }
+
+    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
+    /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
+    /// number of tasks scheduled.
+    ///
+    /// A worker thread is chosen by index. Indices are 0-based and the largest index
+    /// is given by `num_threads() - 1`.
+    ///
+    /// # Panics
+    ///
+    /// This method panics if the index is out of bounds.
+    ///
+    /// # Examples
+    ///
+    /// This method can be used to spawn a task on all worker threads of the pool:
+    ///
+    /// ```
+    /// use tokio_util::task::LocalPoolHandle;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     const NUM_WORKERS: usize = 3;
+    ///     let pool = LocalPoolHandle::new(NUM_WORKERS);
+    ///     let handles = (0..pool.num_threads())
+    ///         .map(|worker_idx| {
+    ///             pool.spawn_pinned_by_idx(
+    ///                 || {
+    ///                     async {
+    ///                         "test"
+    ///                     }
+    ///                 },
+    ///                 worker_idx,
+    ///             )
+    ///         })
+    ///         .collect::<Vec<_>>();
+    ///
+    ///     for handle in handles {
+    ///         handle.await.unwrap();
+    ///     }
+    /// }
+    /// ```
+    #[track_caller]
+    pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
+    where
+        F: FnOnce() -> Fut,
+        F: Send + 'static,
+        Fut: Future + 'static,
+        Fut::Output: Send + 'static,
+    {
+        self.pool
+            .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
+    }
+}
+
+impl Debug for LocalPoolHandle {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_str("LocalPoolHandle")
+    }
+}
+
+enum WorkerChoice {
+    LeastBurdened,
+    ByIdx(usize),
+}
+
+struct LocalPool {
+    workers: Vec<LocalWorkerHandle>,
+}
+
+impl LocalPool {
+    /// Spawn a `?Send` future onto a worker
+    #[track_caller]
+    fn spawn_pinned<F, Fut>(
+        &self,
+        create_task: F,
+        worker_choice: WorkerChoice,
+    ) -> JoinHandle<Fut::Output>
+    where
+        F: FnOnce() -> Fut,
+        F: Send + 'static,
+        Fut: Future + 'static,
+        Fut::Output: Send + 'static,
+    {
+        let (sender, receiver) = oneshot::channel();
+        let (worker, job_guard) = match worker_choice {
+            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
+            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
+        };
+        let worker_spawner = worker.spawner.clone();
+
+        // Spawn a future onto the worker's runtime so we can immediately return
+        // a join handle.
+        worker.runtime_handle.spawn(async move {
+            // Move the job guard into the task
+            let _job_guard = job_guard;
+
+            // Propagate aborts via Abortable/AbortHandle
+            let (abort_handle, abort_registration) = AbortHandle::new_pair();
+            let _abort_guard = AbortGuard(abort_handle);
+
+            // Inside the future we can't run spawn_local yet because we're not
+            // in the context of a LocalSet.
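+            // (spawn_local panics if called outside of a LocalSet context.)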
+            // We need to send create_task to the LocalSet task for spawning.
+            let spawn_task = Box::new(move || {
+                // Once we're in the LocalSet context we can call spawn_local
+                let join_handle = spawn_local(async move {
+                    Abortable::new(create_task(), abort_registration).await
+                });
+
+                // Send the join handle back to the spawner. If sending fails,
+                // we assume the parent task was canceled, so cancel this task
+                // as well.
+                if let Err(join_handle) = sender.send(join_handle) {
+                    join_handle.abort()
+                }
+            });
+
+            // Send the callback to the LocalSet task
+            if let Err(e) = worker_spawner.send(spawn_task) {
+                // Propagate the error as a panic in the join handle.
+                panic!("Failed to send job to worker: {}", e);
+            }
+
+            // Wait for the task's join handle
+            let join_handle = match receiver.await {
+                Ok(handle) => handle,
+                Err(e) => {
+                    // We sent the task successfully, but failed to get its
+                    // join handle... We assume something happened to the worker
+                    // and the task was not spawned. Propagate the error as a
+                    // panic in the join handle.
+                    panic!("Worker failed to send join handle: {}", e);
+                }
+            };
+
+            // Wait for the task to complete
+            let join_result = join_handle.await;
+
+            match join_result {
+                Ok(Ok(output)) => output,
+                Ok(Err(_)) => {
+                    // The pinned task was aborted, but that only happens if
+                    // this task is aborted, so this branch is unreachable.
+                    unreachable!(
+                        "Reaching this branch means this task was previously \
+                         aborted but it continued running anyways"
+                    )
+                }
+                Err(e) => {
+                    if e.is_panic() {
+                        std::panic::resume_unwind(e.into_panic());
+                    } else if e.is_cancelled() {
+                        // No one else should have the join handle, so this is
+                        // unexpected. Forward this error as a panic in the join
+                        // handle.
+                        panic!("spawn_pinned task was canceled: {}", e);
+                    } else {
+                        // Something unknown happened (not a panic or
+                        // cancellation). Forward this error as a panic in the
+                        // join handle.
+                        panic!("spawn_pinned task failed: {}", e);
+                    }
+                }
+            }
+        })
+    }
+
+    /// Find the worker with the least number of tasks, increment its task
+    /// count, and return its handle. Make sure to actually spawn a task on
+    /// the worker so the task count is kept consistent with load.
+    ///
+    /// A job count guard is also returned to ensure the task count gets
+    /// decremented when the job is done.
+    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
+        loop {
+            let (worker, task_count) = self
+                .workers
+                .iter()
+                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
+                .min_by_key(|&(_, count)| count)
+                .expect("There must be at least one worker");
+
+            // Make sure the task count hasn't changed since we chose this
+            // worker. Otherwise, restart the search.
+            if worker
+                .task_count
+                .compare_exchange(
+                    task_count,
+                    task_count + 1,
+                    Ordering::SeqCst,
+                    Ordering::Relaxed,
+                )
+                .is_ok()
+            {
+                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
+            }
+        }
+    }
+
+    #[track_caller]
+    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
+        let worker = &self.workers[idx];
+        worker.task_count.fetch_add(1, Ordering::SeqCst);
+
+        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
+    }
+}
+
+/// Automatically decrements a worker's job count when a job finishes (when
+/// this gets dropped).
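+///
+/// Because the guard lives inside the runtime task, the count is decremented
+/// even if the task panics or is aborted, keeping the loads reported by
+/// [`LocalPoolHandle::get_task_loads_for_each_worker`] accurate.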
+struct JobCountGuard(Arc<AtomicUsize>);
+
+impl Drop for JobCountGuard {
+    fn drop(&mut self) {
+        // Decrement the job count
+        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
+        debug_assert!(previous_value >= 1);
+    }
+}
+
+/// Calls abort on the handle when dropped.
+struct AbortGuard(AbortHandle);
+
+impl Drop for AbortGuard {
+    fn drop(&mut self) {
+        self.0.abort();
+    }
+}
+
+type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
+
+struct LocalWorkerHandle {
+    runtime_handle: tokio::runtime::Handle,
+    spawner: UnboundedSender<PinnedFutureSpawner>,
+    task_count: Arc<AtomicUsize>,
+}
+
+impl LocalWorkerHandle {
+    /// Create a new worker for executing pinned tasks
+    fn new_worker() -> LocalWorkerHandle {
+        let (sender, receiver) = unbounded_channel();
+        let runtime = Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .expect("Failed to start a pinned worker thread runtime");
+        let runtime_handle = runtime.handle().clone();
+        let task_count = Arc::new(AtomicUsize::new(0));
+        let task_count_clone = Arc::clone(&task_count);
+
+        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
+
+        LocalWorkerHandle {
+            runtime_handle,
+            spawner: sender,
+            task_count,
+        }
+    }
+
+    fn run(
+        runtime: tokio::runtime::Runtime,
+        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
+        task_count: Arc<AtomicUsize>,
+    ) {
+        let local_set = LocalSet::new();
+        local_set.block_on(&runtime, async {
+            while let Some(spawn_task) = task_receiver.recv().await {
+                // Calls spawn_local(future)
+                (spawn_task)();
+            }
+        });
+
+        // If there are any tasks on the runtime associated with a LocalSet task
+        // that has already completed, but whose output has not yet been
+        // reported, let that task complete.
+        //
+        // Since the task_count is decremented when the runtime task exits,
+        // reading that counter lets us know if any such tasks completed during
+        // the call to `block_on`.
+        //
+        // Tasks on the LocalSet can't complete during this loop since they're
+        // stored on the LocalSet and we aren't accessing it.
+        let mut previous_task_count = task_count.load(Ordering::SeqCst);
+        loop {
+            // This call will also run tasks spawned on the runtime.
+            runtime.block_on(tokio::task::yield_now());
+            let new_task_count = task_count.load(Ordering::SeqCst);
+            if new_task_count == previous_task_count {
+                break;
+            } else {
+                previous_task_count = new_task_count;
+            }
+        }
+
+        // It's now no longer possible for a task on the runtime to be
+        // associated with a LocalSet task that has completed. Drop both the
+        // LocalSet and runtime to let tasks on the runtime be cancelled if and
+        // only if they are still on the LocalSet.
+        //
+        // Drop the LocalSet task first so that anyone awaiting the runtime
+        // JoinHandle will see the cancelled error after the LocalSet task
+        // destructor has completed.
+        drop(local_set);
+        drop(runtime);
+    }
+}
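
As a quick end-to-end sketch of the API this commit adds (illustrative only, not part of the diff): it assumes `LocalPoolHandle` is exported at `tokio_util::task` as the doc examples above use it, and that tokio is built with the `macros` and `rt-multi-thread` features.

use std::rc::Rc;
use tokio_util::task::LocalPoolHandle;

#[tokio::main]
async fn main() {
    // Two workers; spawn_pinned picks the least-burdened one for each task.
    let pool = LocalPoolHandle::new(2);

    let handles = (0..4usize)
        .map(|i| {
            pool.spawn_pinned(move || {
                // Created on the worker thread; the Rc makes the future !Send.
                let data = Rc::new(i);
                async move { *data * 2 }
            })
        })
        .collect::<Vec<_>>();

    // Per-worker loads can be inspected while tasks are still scheduled.
    println!("loads: {:?}", pool.get_task_loads_for_each_worker());

    for (i, handle) in handles.into_iter().enumerate() {
        assert_eq!(handle.await.unwrap(), i * 2);
    }
}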