diff options
Diffstat (limited to 'src/runtime/task')
-rw-r--r-- | src/runtime/task/abort.rs | 87 | ||||
-rw-r--r-- | src/runtime/task/core.rs | 147 | ||||
-rw-r--r-- | src/runtime/task/error.rs | 33 | ||||
-rw-r--r-- | src/runtime/task/harness.rs | 279 | ||||
-rw-r--r-- | src/runtime/task/join.rs | 119 | ||||
-rw-r--r-- | src/runtime/task/list.rs | 10 | ||||
-rw-r--r-- | src/runtime/task/mod.rs | 211 | ||||
-rw-r--r-- | src/runtime/task/raw.rs | 179 | ||||
-rw-r--r-- | src/runtime/task/waker.rs | 88 |
9 files changed, 891 insertions, 262 deletions
diff --git a/src/runtime/task/abort.rs b/src/runtime/task/abort.rs new file mode 100644 index 0000000..6edca10 --- /dev/null +++ b/src/runtime/task/abort.rs @@ -0,0 +1,87 @@ +use crate::runtime::task::{Header, RawTask}; +use std::fmt; +use std::panic::{RefUnwindSafe, UnwindSafe}; + +/// An owned permission to abort a spawned task, without awaiting its completion. +/// +/// Unlike a [`JoinHandle`], an `AbortHandle` does *not* represent the +/// permission to await the task's completion, only to terminate it. +/// +/// The task may be aborted by calling the [`AbortHandle::abort`] method. +/// Dropping an `AbortHandle` releases the permission to terminate the task +/// --- it does *not* abort the task. +/// +/// [`JoinHandle`]: crate::task::JoinHandle +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +pub struct AbortHandle { + raw: RawTask, +} + +impl AbortHandle { + pub(super) fn new(raw: RawTask) -> Self { + Self { raw } + } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will fail with a [cancelled] `JoinError`. + /// + /// If the task was already cancelled, such as by [`JoinHandle::abort`], + /// this method will do nothing. + /// + /// [cancelled]: method@super::error::JoinError::is_cancelled + /// [`JoinHandle::abort`]: method@super::JoinHandle::abort + pub fn abort(&self) { + self.raw.remote_abort(); + } + + /// Checks if the task associated with this `AbortHandle` has finished. + /// + /// Please note that this method can return `false` even if `abort` has been + /// called on the task. This is because the cancellation process may take + /// some time, and this method does not return `true` until it has + /// completed. + pub fn is_finished(&self) -> bool { + let state = self.raw.state().load(); + state.is_complete() + } + + /// Returns a [task ID] that uniquely identifies this task relative to other + /// currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> super::Id { + // Safety: The header pointer is valid. + unsafe { Header::get_id(self.raw.header_ptr()) } + } +} + +unsafe impl Send for AbortHandle {} +unsafe impl Sync for AbortHandle {} + +impl UnwindSafe for AbortHandle {} +impl RefUnwindSafe for AbortHandle {} + +impl fmt::Debug for AbortHandle { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + // Safety: The header pointer is valid. + let id_ptr = unsafe { Header::get_id_ptr(self.raw.header_ptr()) }; + let id = unsafe { id_ptr.as_ref() }; + fmt.debug_struct("AbortHandle").field("id", id).finish() + } +} + +impl Drop for AbortHandle { + fn drop(&mut self) { + self.raw.drop_abort_handle(); + } +} diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs index 776e834..bcccc69 100644 --- a/src/runtime/task/core.rs +++ b/src/runtime/task/core.rs @@ -11,9 +11,10 @@ use crate::future::Future; use crate::loom::cell::UnsafeCell; +use crate::runtime::context; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::Schedule; +use crate::runtime::task::{Id, Schedule}; use crate::util::linked_list; use std::pin::Pin; @@ -24,6 +25,9 @@ use std::task::{Context, Poll, Waker}; /// /// It is critical for `Header` to be the first field as the task structure will /// be referenced by both *mut Cell and *mut Header. +/// +/// Any changes to the layout of this struct _must_ also be reflected in the +/// const fns in raw.rs. #[repr(C)] pub(super) struct Cell<T: Future, S> { /// Hot task state data @@ -43,10 +47,17 @@ pub(super) struct CoreStage<T: Future> { /// The core of the task. /// /// Holds the future or output, depending on the stage of execution. +/// +/// Any changes to the layout of this struct _must_ also be reflected in the +/// const fns in raw.rs. +#[repr(C)] pub(super) struct Core<T: Future, S> { /// Scheduler used to drive this future. pub(super) scheduler: S, + /// The task's ID, used for populating `JoinError`s. + pub(super) task_id: Id, + /// Either the future or the output. pub(super) stage: CoreStage<T>, } @@ -57,8 +68,6 @@ pub(crate) struct Header { /// Task state. pub(super) state: State, - pub(super) owned: UnsafeCell<linked_list::Pointers<Header>>, - /// Pointer to next task, used with the injection queue. pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>, @@ -80,18 +89,29 @@ pub(crate) struct Header { /// The tracing ID for this instrumented task. #[cfg(all(tokio_unstable, feature = "tracing"))] - pub(super) id: Option<tracing::Id>, + pub(super) tracing_id: Option<tracing::Id>, } unsafe impl Send for Header {} unsafe impl Sync for Header {} -/// Cold data is stored after the future. +/// Cold data is stored after the future. Data is considered cold if it is only +/// used during creation or shutdown of the task. pub(super) struct Trailer { + /// Pointers for the linked list in the `OwnedTasks` that owns this task. + pub(super) owned: linked_list::Pointers<Header>, /// Consumer task waiting on completion of this task. pub(super) waker: UnsafeCell<Option<Waker>>, } +generate_addr_of_methods! { + impl<> Trailer { + pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> { + &self.owned + } + } +} + /// Either the future or the output. pub(super) enum Stage<T: Future> { Running(T), @@ -102,29 +122,48 @@ pub(super) enum Stage<T: Future> { impl<T: Future, S: Schedule> Cell<T, S> { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> { + pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> { #[cfg(all(tokio_unstable, feature = "tracing"))] - let id = future.id(); - Box::new(Cell { + let tracing_id = future.id(); + let result = Box::new(Cell { header: Header { state, - owned: UnsafeCell::new(linked_list::Pointers::new()), queue_next: UnsafeCell::new(None), vtable: raw::vtable::<T, S>(), owner_id: UnsafeCell::new(0), #[cfg(all(tokio_unstable, feature = "tracing"))] - id, + tracing_id, }, core: Core { scheduler, stage: CoreStage { stage: UnsafeCell::new(Stage::Running(future)), }, + task_id, }, trailer: Trailer { waker: UnsafeCell::new(None), + owned: linked_list::Pointers::new(), }, - }) + }); + + #[cfg(debug_assertions)] + { + let trailer_addr = (&result.trailer) as *const Trailer as usize; + let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) }; + assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize); + + let scheduler_addr = (&result.core.scheduler) as *const S as usize; + let scheduler_ptr = + unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) }; + assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize); + + let id_addr = (&result.core.task_id) as *const Id as usize; + let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) }; + assert_eq!(id_addr, id_ptr.as_ptr() as usize); + } + + result } } @@ -132,7 +171,29 @@ impl<T: Future> CoreStage<T> { pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { self.stage.with_mut(f) } +} +/// Set and clear the task id in the context when the future is executed or +/// dropped, or when the output produced by the future is dropped. +pub(crate) struct TaskIdGuard { + parent_task_id: Option<Id>, +} + +impl TaskIdGuard { + fn enter(id: Id) -> Self { + TaskIdGuard { + parent_task_id: context::set_current_task_id(Some(id)), + } + } +} + +impl Drop for TaskIdGuard { + fn drop(&mut self) { + context::set_current_task_id(self.parent_task_id); + } +} + +impl<T: Future, S: Schedule> Core<T, S> { /// Polls the future. /// /// # Safety @@ -148,7 +209,7 @@ impl<T: Future> CoreStage<T> { /// heap. pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> { let res = { - self.stage.with_mut(|ptr| { + self.stage.stage.with_mut(|ptr| { // Safety: The caller ensures mutual exclusion to the field. let future = match unsafe { &mut *ptr } { Stage::Running(future) => future, @@ -158,6 +219,7 @@ impl<T: Future> CoreStage<T> { // Safety: The caller ensures the future is pinned. let future = unsafe { Pin::new_unchecked(future) }; + let _guard = TaskIdGuard::enter(self.task_id); future.poll(&mut cx) }) }; @@ -201,7 +263,7 @@ impl<T: Future> CoreStage<T> { pub(super) fn take_output(&self) -> super::Result<T::Output> { use std::mem; - self.stage.with_mut(|ptr| { + self.stage.stage.with_mut(|ptr| { // Safety:: the caller ensures mutual exclusion to the field. match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { Stage::Finished(output) => output, @@ -211,7 +273,8 @@ impl<T: Future> CoreStage<T> { } unsafe fn set_stage(&self, stage: Stage<T>) { - self.stage.with_mut(|ptr| *ptr = stage) + let _guard = TaskIdGuard::enter(self.task_id); + self.stage.stage.with_mut(|ptr| *ptr = stage) } } @@ -236,6 +299,62 @@ impl Header { // the safety requirements on `set_owner_id`. unsafe { self.owner_id.with(|ptr| *ptr) } } + + /// Gets a pointer to the `Trailer` of the task containing this `Header`. + /// + /// # Safety + /// + /// The provided raw pointer must point at the header of a task. + pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> { + let offset = me.as_ref().vtable.trailer_offset; + let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>(); + NonNull::new_unchecked(trailer) + } + + /// Gets a pointer to the scheduler of the task containing this `Header`. + /// + /// # Safety + /// + /// The provided raw pointer must point at the header of a task. + /// + /// The generic type S must be set to the correct scheduler type for this + /// task. + pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> { + let offset = me.as_ref().vtable.scheduler_offset; + let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>(); + NonNull::new_unchecked(scheduler) + } + + /// Gets a pointer to the id of the task containing this `Header`. + /// + /// # Safety + /// + /// The provided raw pointer must point at the header of a task. + pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> { + let offset = me.as_ref().vtable.id_offset; + let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>(); + NonNull::new_unchecked(id) + } + + /// Gets the id of the task containing this `Header`. + /// + /// # Safety + /// + /// The provided raw pointer must point at the header of a task. + pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id { + let ptr = Header::get_id_ptr(me).as_ptr(); + *ptr + } + + /// Gets the tracing id of the task containing this `Header`. + /// + /// # Safety + /// + /// The provided raw pointer must point at the header of a task. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> { + me.as_ref().tracing_id.as_ref() + } } impl Trailer { diff --git a/src/runtime/task/error.rs b/src/runtime/task/error.rs index 1a8129b..f7ead77 100644 --- a/src/runtime/task/error.rs +++ b/src/runtime/task/error.rs @@ -2,12 +2,13 @@ use std::any::Any; use std::fmt; use std::io; +use super::Id; use crate::util::SyncWrapper; - cfg_rt! { /// Task failed to execute to completion. pub struct JoinError { repr: Repr, + id: Id, } } @@ -17,15 +18,17 @@ enum Repr { } impl JoinError { - pub(crate) fn cancelled() -> JoinError { + pub(crate) fn cancelled(id: Id) -> JoinError { JoinError { repr: Repr::Cancelled, + id, } } - pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { + pub(crate) fn panic(id: Id, err: Box<dyn Any + Send + 'static>) -> JoinError { JoinError { repr: Repr::Panic(SyncWrapper::new(err)), + id, } } @@ -79,6 +82,7 @@ impl JoinError { /// } /// } /// ``` + #[track_caller] pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { self.try_into_panic() .expect("`JoinError` reason is not a panic.") @@ -111,13 +115,28 @@ impl JoinError { _ => Err(self), } } + + /// Returns a [task ID] that identifies the task which errored relative to + /// other currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> Id { + self.id + } } impl fmt::Display for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.repr { - Repr::Cancelled => write!(fmt, "cancelled"), - Repr::Panic(_) => write!(fmt, "panic"), + Repr::Cancelled => write!(fmt, "task {} was cancelled", self.id), + Repr::Panic(_) => write!(fmt, "task {} panicked", self.id), } } } @@ -125,8 +144,8 @@ impl fmt::Display for JoinError { impl fmt::Debug for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.repr { - Repr::Cancelled => write!(fmt, "JoinError::Cancelled"), - Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"), + Repr::Cancelled => write!(fmt, "JoinError::Cancelled({:?})", self.id), + Repr::Panic(_) => write!(fmt, "JoinError::Panic({:?}, ...)", self.id), } } } diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs index 0996e52..c079297 100644 --- a/src/runtime/task/harness.rs +++ b/src/runtime/task/harness.rs @@ -1,8 +1,8 @@ use crate::future::Future; -use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; -use crate::runtime::task::state::Snapshot; +use crate::runtime::task::core::{Cell, Core, Header, Trailer}; +use crate::runtime::task::state::{Snapshot, State}; use crate::runtime::task::waker::waker_ref; -use crate::runtime::task::{JoinError, Notified, Schedule, Task}; +use crate::runtime::task::{JoinError, Notified, RawTask, Schedule, Task}; use std::mem; use std::mem::ManuallyDrop; @@ -26,8 +26,16 @@ where } } + fn header_ptr(&self) -> NonNull<Header> { + self.cell.cast() + } + fn header(&self) -> &Header { - unsafe { &self.cell.as_ref().header } + unsafe { &*self.header_ptr().as_ptr() } + } + + fn state(&self) -> &State { + &self.header().state } fn trailer(&self) -> &Trailer { @@ -39,11 +47,102 @@ where } } +/// Task operations that can be implemented without being generic over the +/// scheduler or task. Only one version of these methods should exist in the +/// final binary. +impl RawTask { + pub(super) fn drop_reference(self) { + if self.state().ref_dec() { + self.dealloc(); + } + } + + /// This call consumes a ref-count and notifies the task. This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. + pub(super) fn wake_by_val(&self) { + use super::state::TransitionToNotifiedByVal; + + match self.state().transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + // The caller has given us a ref-count, and the transition has + // created a new ref-count, so we now hold two. We turn the new + // ref-count Notified and pass it to the call to `schedule`. + // + // The old ref-count is retained for now to ensure that the task + // is not dropped during the call to `schedule` if the call + // drops the task it was given. + self.schedule(); + + // Now that we have completed the call to schedule, we can + // release our ref-count. + self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } + } + + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. + pub(super) fn wake_by_ref(&self) { + use super::state::TransitionToNotifiedByRef; + + match self.state().transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. + self.schedule(); + } + TransitionToNotifiedByRef::DoNothing => {} + } + } + + /// Remotely aborts the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(&self) { + if self.state().transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.schedule(); + } + } + + /// Try to set the waker notified when the task is complete. Returns true if + /// the task has already completed. If this call returns false, then the + /// waker will not be notified. + pub(super) fn try_set_join_waker(&self, waker: &Waker) -> bool { + can_read_output(self.header(), self.trailer(), waker) + } +} + impl<T, S> Harness<T, S> where T: Future, S: Schedule, { + pub(super) fn drop_reference(self) { + if self.state().ref_dec() { + self.dealloc(); + } + } + /// Polls the inner future. A ref-count is consumed. /// /// All necessary state checks and transitions are performed. @@ -91,32 +190,32 @@ where fn poll_inner(&self) -> PollFuture { use super::state::{TransitionToIdle, TransitionToRunning}; - match self.header().state.transition_to_running() { + match self.state().transition_to_running() { TransitionToRunning::Success => { - let waker_ref = waker_ref::<T, S>(self.header()); + let header_ptr = self.header_ptr(); + let waker_ref = waker_ref::<T, S>(&header_ptr); let cx = Context::from_waker(&*waker_ref); - let res = poll_future(&self.core().stage, cx); + let res = poll_future(self.core(), cx); if res == Poll::Ready(()) { // The future completed. Move on to complete the task. return PollFuture::Complete; } - match self.header().state.transition_to_idle() { + match self.state().transition_to_idle() { TransitionToIdle::Ok => PollFuture::Done, TransitionToIdle::OkNotified => PollFuture::Notified, TransitionToIdle::OkDealloc => PollFuture::Dealloc, TransitionToIdle::Cancelled => { // The transition to idle failed because the task was // cancelled during the poll. - - cancel_task(&self.core().stage); + cancel_task(self.core()); PollFuture::Complete } } } TransitionToRunning::Cancelled => { - cancel_task(&self.core().stage); + cancel_task(self.core()); PollFuture::Complete } TransitionToRunning::Failed => PollFuture::Done, @@ -131,7 +230,7 @@ where /// there is nothing further to do. When the task completes running, it will /// notice the `CANCELLED` bit and finalize the task. pub(super) fn shutdown(self) { - if !self.header().state.transition_to_shutdown() { + if !self.state().transition_to_shutdown() { // The task is concurrently running. No further work needed. self.drop_reference(); return; @@ -139,7 +238,7 @@ where // By transitioning the lifecycle to `Running`, we have permission to // drop the future. - cancel_task(&self.core().stage); + cancel_task(self.core()); self.complete(); } @@ -150,6 +249,19 @@ where // Check causality self.core().stage.with_mut(drop); + // Safety: The caller of this method just transitioned our ref-count to + // zero, so it is our responsibility to release the allocation. + // + // We don't hold any references into the allocation at this point, but + // it is possible for another thread to still hold a `&State` into the + // allocation if that other thread has decremented its last ref-count, + // but has not yet returned from the relevant method on `State`. + // + // However, the `State` type consists of just an `AtomicUsize`, and an + // `AtomicUsize` wraps the entirety of its contents in an `UnsafeCell`. + // As explained in the documentation for `UnsafeCell`, such references + // are allowed to be dangling after their last use, even if the + // reference has not yet gone out of scope. unsafe { drop(Box::from_raw(self.cell.as_ptr())); } @@ -160,122 +272,30 @@ where /// Read the task output into `dst`. pub(super) fn try_read_output(self, dst: &mut Poll<super::Result<T::Output>>, waker: &Waker) { if can_read_output(self.header(), self.trailer(), waker) { - *dst = Poll::Ready(self.core().stage.take_output()); + *dst = Poll::Ready(self.core().take_output()); } } pub(super) fn drop_join_handle_slow(self) { - let mut maybe_panic = None; - // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. - if self.header().state.unset_join_interested().is_err() { + if self.state().unset_join_interested().is_err() { // It is our responsibility to drop the output. This is critical as // the task output may not be `Send` and as such must remain with // the scheduler or `JoinHandle`. i.e. if the output remains in the // task structure until the task is deallocated, it may be dropped // by a Waker on any arbitrary thread. - let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { - self.core().stage.drop_future_or_output(); + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().drop_future_or_output(); })); - - if let Err(panic) = panic { - maybe_panic = Some(panic); - } } // Drop the `JoinHandle` reference, possibly deallocating the task self.drop_reference(); - - if let Some(panic) = maybe_panic { - panic::resume_unwind(panic); - } - } - - /// Remotely aborts the task. - /// - /// The caller should hold a ref-count, but we do not consume it. - /// - /// This is similar to `shutdown` except that it asks the runtime to perform - /// the shutdown. This is necessary to avoid the shutdown happening in the - /// wrong thread for non-Send tasks. - pub(super) fn remote_abort(self) { - if self.header().state.transition_to_notified_and_cancel() { - // The transition has created a new ref-count, which we turn into - // a Notified and pass to the task. - // - // Since the caller holds a ref-count, the task cannot be destroyed - // before the call to `schedule` returns even if the call drops the - // `Notified` internally. - self.core() - .scheduler - .schedule(Notified(self.get_new_task())); - } - } - - // ===== waker behavior ===== - - /// This call consumes a ref-count and notifies the task. This will create a - /// new Notified and submit it if necessary. - /// - /// The caller does not need to hold a ref-count besides the one that was - /// passed to this call. - pub(super) fn wake_by_val(self) { - use super::state::TransitionToNotifiedByVal; - - match self.header().state.transition_to_notified_by_val() { - TransitionToNotifiedByVal::Submit => { - // The caller has given us a ref-count, and the transition has - // created a new ref-count, so we now hold two. We turn the new - // ref-count Notified and pass it to the call to `schedule`. - // - // The old ref-count is retained for now to ensure that the task - // is not dropped during the call to `schedule` if the call - // drops the task it was given. - self.core() - .scheduler - .schedule(Notified(self.get_new_task())); - - // Now that we have completed the call to schedule, we can - // release our ref-count. - self.drop_reference(); - } - TransitionToNotifiedByVal::Dealloc => { - self.dealloc(); - } - TransitionToNotifiedByVal::DoNothing => {} - } - } - - /// This call notifies the task. It will not consume any ref-counts, but the - /// caller should hold a ref-count. This will create a new Notified and - /// submit it if necessary. - pub(super) fn wake_by_ref(&self) { - use super::state::TransitionToNotifiedByRef; - - match self.header().state.transition_to_notified_by_ref() { - TransitionToNotifiedByRef::Submit => { - // The transition above incremented the ref-count for a new task - // and the caller also holds a ref-count. The caller's ref-count - // ensures that the task is not destroyed even if the new task - // is dropped before `schedule` returns. - self.core() - .scheduler - .schedule(Notified(self.get_new_task())); - } - TransitionToNotifiedByRef::DoNothing => {} - } - } - - pub(super) fn drop_reference(self) { - if self.header().state.ref_dec() { - self.dealloc(); - } - } - - #[cfg(all(tokio_unstable, feature = "tracing"))] - pub(super) fn id(&self) -> Option<&tracing::Id> { - self.header().id.as_ref() } // ====== internal ====== @@ -285,7 +305,7 @@ where // The future has completed and its output has been written to the task // stage. We transition from running to complete. - let snapshot = self.header().state.transition_to_complete(); + let snapshot = self.state().transition_to_complete(); // We catch panics here in case dropping the future or waking the // JoinHandle panics. @@ -294,7 +314,7 @@ where // The `JoinHandle` is not interested in the output of // this task. It is our responsibility to drop the // output. - self.core().stage.drop_future_or_output(); + self.core().drop_future_or_output(); } else if snapshot.has_join_waker() { // Notify the join handle. The previous transition obtains the // lock on the waker cell. @@ -305,7 +325,7 @@ where // The task has completed execution and will no longer be scheduled. let num_release = self.release(); - if self.header().state.transition_to_terminal(num_release) { + if self.state().transition_to_terminal(num_release) { self.dealloc(); } } @@ -426,31 +446,31 @@ enum PollFuture { } /// Cancels the task and store the appropriate error in the stage field. -fn cancel_task<T: Future>(stage: &CoreStage<T>) { +fn cancel_task<T: Future, S: Schedule>(core: &Core<T, S>) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { - stage.drop_future_or_output(); + core.drop_future_or_output(); })); match res { Ok(()) => { - stage.store_output(Err(JoinError::cancelled())); + core.store_output(Err(JoinError::cancelled(core.task_id))); } Err(panic) => { - stage.store_output(Err(JoinError::panic(panic))); + core.store_output(Err(JoinError::panic(core.task_id, panic))); } } } /// Polls the future. If the future completes, the output is written to the /// stage field. -fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { +fn poll_future<T: Future, S: Schedule>(core: &Core<T, S>, cx: Context<'_>) -> Poll<()> { // Poll the future. let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { - struct Guard<'a, T: Future> { - core: &'a CoreStage<T>, + struct Guard<'a, T: Future, S: Schedule> { + core: &'a Core<T, S>, } - impl<'a, T: Future> Drop for Guard<'a, T> { + impl<'a, T: Future, S: Schedule> Drop for Guard<'a, T, S> { fn drop(&mut self) { // If the future panics on poll, we drop it inside the panic // guard. @@ -467,13 +487,20 @@ fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { let output = match output { Ok(Poll::Pending) => return Poll::Pending, Ok(Poll::Ready(output)) => Ok(output), - Err(panic) => Err(JoinError::panic(panic)), + Err(panic) => { + core.scheduler.unhandled_panic(); + Err(JoinError::panic(core.task_id, panic)) + } }; // Catch and ignore panics if the future panics on drop. - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { core.store_output(output); })); + if res.is_err() { + core.scheduler.unhandled_panic(); + } + Poll::Ready(()) } diff --git a/src/runtime/task/join.rs b/src/runtime/task/join.rs index 0abbff2..5660575 100644 --- a/src/runtime/task/join.rs +++ b/src/runtime/task/join.rs @@ -1,16 +1,19 @@ -use crate::runtime::task::RawTask; +use crate::runtime::task::{Header, RawTask}; use std::fmt; use std::future::Future; use std::marker::PhantomData; +use std::panic::{RefUnwindSafe, UnwindSafe}; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, Waker}; cfg_rt! { /// An owned permission to join on a task (await its termination). /// - /// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for - /// a task rather than a thread. + /// This can be thought of as the equivalent of [`std::thread::JoinHandle`] + /// for a Tokio task rather than a thread. You do not need to `.await` the + /// `JoinHandle` to make the task execute — it will start running in the + /// background immediately. /// /// A `JoinHandle` *detaches* the associated task when it is dropped, which /// means that there is no longer any handle to the task, and no way to `join` @@ -19,6 +22,15 @@ cfg_rt! { /// This `struct` is created by the [`task::spawn`] and [`task::spawn_blocking`] /// functions. /// + /// # Cancel safety + /// + /// The `&mut JoinHandle<T>` type is cancel safe. If it is used as the event + /// in a `tokio::select!` statement and some other branch completes first, + /// then it is guaranteed that the output of the task is not lost. + /// + /// If a `JoinHandle` is dropped, then the task continues running in the + /// background and its return value is lost. + /// /// # Examples /// /// Creation from [`task::spawn`]: @@ -142,7 +154,7 @@ cfg_rt! { /// [`std::thread::JoinHandle`]: std::thread::JoinHandle /// [`JoinError`]: crate::task::JoinError pub struct JoinHandle<T> { - raw: Option<RawTask>, + raw: RawTask, _p: PhantomData<T>, } } @@ -150,10 +162,13 @@ cfg_rt! { unsafe impl<T: Send> Send for JoinHandle<T> {} unsafe impl<T: Send> Sync for JoinHandle<T> {} +impl<T> UnwindSafe for JoinHandle<T> {} +impl<T> RefUnwindSafe for JoinHandle<T> {} + impl<T> JoinHandle<T> { pub(super) fn new(raw: RawTask) -> JoinHandle<T> { JoinHandle { - raw: Some(raw), + raw, _p: PhantomData, } } @@ -192,10 +207,71 @@ impl<T> JoinHandle<T> { /// ``` /// [cancelled]: method@super::error::JoinError::is_cancelled pub fn abort(&self) { - if let Some(raw) = self.raw { - raw.remote_abort(); + self.raw.remote_abort(); + } + + /// Checks if the task associated with this `JoinHandle` has finished. + /// + /// Please note that this method can return `false` even if [`abort`] has been + /// called on the task. This is because the cancellation process may take + /// some time, and this method does not return `true` until it has + /// completed. + /// + /// ```rust + /// use tokio::time; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// # time::pause(); + /// let handle1 = tokio::spawn(async { + /// // do some stuff here + /// }); + /// let handle2 = tokio::spawn(async { + /// // do some other stuff here + /// time::sleep(time::Duration::from_secs(10)).await; + /// }); + /// // Wait for the task to finish + /// handle2.abort(); + /// time::sleep(time::Duration::from_secs(1)).await; + /// assert!(handle1.is_finished()); + /// assert!(handle2.is_finished()); + /// # } + /// ``` + /// [`abort`]: method@JoinHandle::abort + pub fn is_finished(&self) -> bool { + let state = self.raw.header().state.load(); + state.is_complete() + } + + /// Set the waker that is notified when the task completes. + pub(crate) fn set_join_waker(&mut self, waker: &Waker) { + if self.raw.try_set_join_waker(waker) { + // In this case the task has already completed. We wake the waker immediately. + waker.wake_by_ref(); } } + + /// Returns a new `AbortHandle` that can be used to remotely abort this task. + pub(crate) fn abort_handle(&self) -> super::AbortHandle { + self.raw.ref_inc(); + super::AbortHandle::new(self.raw) + } + + /// Returns a [task ID] that uniquely identifies this task relative to other + /// currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> super::Id { + // Safety: The header pointer is valid. + unsafe { Header::get_id(self.raw.header_ptr()) } + } } impl<T> Unpin for JoinHandle<T> {} @@ -207,14 +283,7 @@ impl<T> Future for JoinHandle<T> { let mut ret = Poll::Pending; // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - // Raw should always be set. If it is not, this is due to polling after - // completion - let raw = self - .raw - .as_ref() - .expect("polling after `JoinHandle` already completed"); + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); // Try to read the task output. If the task is not yet complete, the // waker is stored and is notified once the task does complete. @@ -228,7 +297,8 @@ impl<T> Future for JoinHandle<T> { // // The type of `T` must match the task's output type. unsafe { - raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); + self.raw + .try_read_output(&mut ret as *mut _ as *mut (), cx.waker()); } if ret.is_ready() { @@ -241,13 +311,11 @@ impl<T> Future for JoinHandle<T> { impl<T> Drop for JoinHandle<T> { fn drop(&mut self) { - if let Some(raw) = self.raw.take() { - if raw.header().state.drop_join_handle_fast().is_ok() { - return; - } - - raw.drop_join_handle_slow(); + if self.raw.state().drop_join_handle_fast().is_ok() { + return; } + + self.raw.drop_join_handle_slow(); } } @@ -256,6 +324,9 @@ where T: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("JoinHandle").finish() + // Safety: The header pointer is valid. + let id_ptr = unsafe { Header::get_id_ptr(self.raw.header_ptr()) }; + let id = unsafe { id_ptr.as_ref() }; + fmt.debug_struct("JoinHandle").field("id", id).finish() } } diff --git a/src/runtime/task/list.rs b/src/runtime/task/list.rs index 7758f8d..159c13e 100644 --- a/src/runtime/task/list.rs +++ b/src/runtime/task/list.rs @@ -84,13 +84,14 @@ impl<S: 'static> OwnedTasks<S> { &self, task: T, scheduler: S, + id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler); + let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { // safety: We just created the task, so we have exclusive access @@ -163,7 +164,7 @@ impl<S: 'static> OwnedTasks<S> { // safety: We just checked that the provided task is not in some other // linked list. - unsafe { self.inner.lock().list.remove(task.header().into()) } + unsafe { self.inner.lock().list.remove(task.header_ptr()) } } pub(crate) fn is_empty(&self) -> bool { @@ -187,13 +188,14 @@ impl<S: 'static> LocalOwnedTasks<S> { &self, task: T, scheduler: S, + id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler); + let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { // safety: We just created the task, so we have exclusive access @@ -238,7 +240,7 @@ impl<S: 'static> LocalOwnedTasks<S> { self.with_inner(|inner| // safety: We just checked that the provided task is not in some // other linked list. - unsafe { inner.list.remove(task.header().into()) }) + unsafe { inner.list.remove(task.header_ptr()) }) } /// Asserts that the given task is owned by this LocalOwnedTasks and convert diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs index 1f18209..fea6e0f 100644 --- a/src/runtime/task/mod.rs +++ b/src/runtime/task/mod.rs @@ -25,7 +25,7 @@ //! //! The task uses a reference count to keep track of how many active references //! exist. The Unowned reference type takes up two ref-counts. All other -//! reference types take pu a single ref-count. +//! reference types take up a single ref-count. //! //! Besides the waker type, each task has at most one of each reference type. //! @@ -47,7 +47,8 @@ //! //! * JOIN_INTEREST - Is set to one if there exists a JoinHandle. //! -//! * JOIN_WAKER - Is set to one if the JoinHandle has set a waker. +//! * JOIN_WAKER - Acts as an access control bit for the join handle waker. The +//! protocol for its usage is described below. //! //! The rest of the bits are used for the ref-count. //! @@ -71,10 +72,38 @@ //! a lock for the stage field, and it can be accessed only by the thread //! that set RUNNING to one. //! -//! * If JOIN_WAKER is zero, then the JoinHandle has exclusive access to the -//! join handle waker. If JOIN_WAKER and COMPLETE are both one, then the -//! thread that set COMPLETE to one has exclusive access to the join handle -//! waker. +//! * The waker field may be concurrently accessed by different threads: in one +//! thread the runtime may complete a task and *read* the waker field to +//! invoke the waker, and in another thread the task's JoinHandle may be +//! polled, and if the task hasn't yet completed, the JoinHandle may *write* +//! a waker to the waker field. The JOIN_WAKER bit ensures safe access by +//! multiple threads to the waker field using the following rules: +//! +//! 1. JOIN_WAKER is initialized to zero. +//! +//! 2. If JOIN_WAKER is zero, then the JoinHandle has exclusive (mutable) +//! access to the waker field. +//! +//! 3. If JOIN_WAKER is one, then the JoinHandle has shared (read-only) +//! access to the waker field. +//! +//! 4. If JOIN_WAKER is one and COMPLETE is one, then the runtime has shared +//! (read-only) access to the waker field. +//! +//! 5. If the JoinHandle needs to write to the waker field, then the +//! JoinHandle needs to (i) successfully set JOIN_WAKER to zero if it is +//! not already zero to gain exclusive access to the waker field per rule +//! 2, (ii) write a waker, and (iii) successfully set JOIN_WAKER to one. +//! +//! 6. The JoinHandle can change JOIN_WAKER only if COMPLETE is zero (i.e. +//! the task hasn't yet completed). +//! +//! Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a +//! race. If step (i) fails, then the attempt to write a waker is aborted. If +//! step (iii) fails because COMPLETE is set to one by another thread after +//! step (i), then the waker field is cleared. Once COMPLETE is one (i.e. +//! task has completed), the JoinHandle will not modify JOIN_WAKER. After the +//! runtime sets COMPLETE to one, it invokes the waker if there is one. //! //! All other fields are immutable and can be accessed immutably without //! synchronization by anyone. @@ -121,7 +150,7 @@ //! 1. The output is created on the thread that the future was polled on. Since //! only non-Send futures can have non-Send output, the future was polled on //! the thread that the future was spawned from. -//! 2. Since JoinHandle<Output> is not Send if Output is not Send, the +//! 2. Since `JoinHandle<Output>` is not Send if Output is not Send, the //! JoinHandle is also on the thread that the future was spawned from. //! 3. Thus, the JoinHandle will not move the output across threads when it //! takes or drops the output. @@ -135,6 +164,12 @@ //! poll call will notice it when the poll finishes, and the task is cancelled //! at that point. +// Some task infrastructure is here to support `JoinSet`, which is currently +// unstable. This should be removed once `JoinSet` is stabilized. +#![cfg_attr(not(tokio_unstable), allow(dead_code))] + +use crate::runtime::context; + mod core; use self::core::Cell; use self::core::Header; @@ -151,7 +186,14 @@ cfg_rt_multi_thread! { pub(super) use self::inject::Inject; } +#[cfg(feature = "rt")] +mod abort; mod join; + +#[cfg(feature = "rt")] +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::abort::AbortHandle; + #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::join::JoinHandle; @@ -173,6 +215,70 @@ use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; +/// An opaque ID that uniquely identifies a task relative to all other currently +/// running tasks. +/// +/// # Notes +/// +/// - Task IDs are unique relative to other *currently running* tasks. When a +/// task completes, the same ID may be used for another task. +/// - Task IDs are *not* sequential, and do not indicate the order in which +/// tasks are spawned, what runtime a task is spawned on, or any other data. +/// - The task ID of the currently running task can be obtained from inside the +/// task via the [`task::try_id()`](crate::task::try_id()) and +/// [`task::id()`](crate::task::id()) functions and from outside the task via +/// the [`JoinHandle::id()`](crate::task::JoinHandle::id()) function. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] +pub struct Id(u64); + +/// Returns the [`Id`] of the currently running task. +/// +/// # Panics +/// +/// This function panics if called from outside a task. Please note that calls +/// to `block_on` do not have task IDs, so the method will panic if called from +/// within a call to `block_on`. For a version of this function that doesn't +/// panic, see [`task::try_id()`](crate::runtime::task::try_id()). +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [task ID]: crate::task::Id +/// [unstable]: crate#unstable-features +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[track_caller] +pub fn id() -> Id { + context::current_task_id().expect("Can't get a task id when not inside a task") +} + +/// Returns the [`Id`] of the currently running task, or `None` if called outside +/// of a task. +/// +/// This function is similar to [`task::id()`](crate::runtime::task::id()), except +/// that it returns `None` rather than panicking if called outside of a task +/// context. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [task ID]: crate::task::Id +/// [unstable]: crate#unstable-features +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[track_caller] +pub fn try_id() -> Option<Id> { + context::current_task_id() +} + /// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task<S: 'static> { @@ -230,6 +336,11 @@ pub(crate) trait Schedule: Sync + Sized + 'static { fn yield_now(&self, task: Notified<Self>) { self.schedule(task); } + + /// Polling the task resulted in a panic. Should the runtime shutdown? + fn unhandled_panic(&self) { + // By default, do nothing. This maintains the 1.0 behavior. + } } cfg_rt! { @@ -239,14 +350,15 @@ cfg_rt! { /// notification. fn new_task<T, S>( task: T, - scheduler: S + scheduler: S, + id: Id, ) -> (Task<S>, Notified<S>, JoinHandle<T::Output>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let raw = RawTask::new::<T, S>(task, scheduler); + let raw = RawTask::new::<T, S>(task, scheduler, id); let task = Task { raw, _p: PhantomData, @@ -264,13 +376,13 @@ cfg_rt! { /// only when the task is not going to be stored in an `OwnedTasks` list. /// /// Currently only blocking tasks use this method. - pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>) + pub(crate) fn unowned<T, S>(task: T, scheduler: S, id: Id) -> (UnownedTask<S>, JoinHandle<T::Output>) where S: Schedule, T: Send + Future + 'static, T::Output: Send + 'static, { - let (task, notified, join) = new_task(task, scheduler); + let (task, notified, join) = new_task(task, scheduler, id); // This transfers the ref-count of task and notified into an UnownedTask. // This is valid because an UnownedTask holds two ref-counts. @@ -296,6 +408,10 @@ impl<S: 'static> Task<S> { fn header(&self) -> &Header { self.raw.header() } + + fn header_ptr(&self) -> NonNull<Header> { + self.raw.header_ptr() + } } impl<S: 'static> Notified<S> { @@ -313,7 +429,7 @@ cfg_rt_multi_thread! { impl<S: 'static> Task<S> { fn into_raw(self) -> NonNull<Header> { - let ret = self.header().into(); + let ret = self.raw.header_ptr(); mem::forget(self); ret } @@ -327,7 +443,7 @@ cfg_rt_multi_thread! { } impl<S: Schedule> Task<S> { - /// Pre-emptively cancels the task as part of the shutdown process. + /// Preemptively cancels the task as part of the shutdown process. pub(crate) fn shutdown(self) { let raw = self.raw; mem::forget(self); @@ -347,6 +463,7 @@ impl<S: Schedule> LocalNotified<S> { impl<S: Schedule> UnownedTask<S> { // Used in test of the inject queue. #[cfg(test)] + #[cfg_attr(tokio_wasm, allow(dead_code))] pub(super) fn into_notified(self) -> Notified<S> { Notified(self.into_task()) } @@ -426,7 +543,7 @@ unsafe impl<S> linked_list::Link for Task<S> { type Target = Header; fn as_raw(handle: &Task<S>) -> NonNull<Header> { - handle.header().into() + handle.raw.header_ptr() } unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { @@ -434,7 +551,69 @@ unsafe impl<S> linked_list::Link for Task<S> { } unsafe fn pointers(target: NonNull<Header>) -> NonNull<linked_list::Pointers<Header>> { - // Not super great as it avoids some of looms checking... - NonNull::from(target.as_ref().owned.with_mut(|ptr| &mut *ptr)) + self::core::Trailer::addr_of_owned(Header::get_trailer(target)) + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Id { + // When 64-bit atomics are available, use a static `AtomicU64` counter to + // generate task IDs. + // + // Note(eliza): we _could_ just use `crate::loom::AtomicU64`, which switches + // between an atomic and mutex-based implementation here, rather than having + // two separate functions for targets with and without 64-bit atomics. + // However, because we can't use the mutex-based implementation in a static + // initializer directly, the 32-bit impl also has to use a `OnceCell`, and I + // thought it was nicer to avoid the `OnceCell` overhead on 64-bit + // platforms... + cfg_has_atomic_u64! { + pub(crate) fn next() -> Self { + use std::sync::atomic::{AtomicU64, Ordering::Relaxed}; + static NEXT_ID: AtomicU64 = AtomicU64::new(1); + Self(NEXT_ID.fetch_add(1, Relaxed)) + } + } + + cfg_not_has_atomic_u64! { + cfg_has_const_mutex_new! { + pub(crate) fn next() -> Self { + use crate::loom::sync::Mutex; + static NEXT_ID: Mutex<u64> = Mutex::const_new(1); + + let mut lock = NEXT_ID.lock(); + let id = *lock; + *lock += 1; + Self(id) + } + } + + cfg_not_has_const_mutex_new! { + pub(crate) fn next() -> Self { + use crate::util::once_cell::OnceCell; + use crate::loom::sync::Mutex; + + fn init_next_id() -> Mutex<u64> { + Mutex::new(1) + } + + static NEXT_ID: OnceCell<Mutex<u64>> = OnceCell::new(); + + let next_id = NEXT_ID.get(init_next_id); + let mut lock = next_id.lock(); + let id = *lock; + *lock += 1; + Self(id) + } + } + } + + pub(crate) fn as_u64(&self) -> u64 { + self.0 } } diff --git a/src/runtime/task/raw.rs b/src/runtime/task/raw.rs index fbc9574..b9700ae 100644 --- a/src/runtime/task/raw.rs +++ b/src/runtime/task/raw.rs @@ -1,5 +1,6 @@ use crate::future::Future; -use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; +use crate::runtime::task::core::{Core, Trailer}; +use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -13,6 +14,9 @@ pub(super) struct Vtable { /// Polls the future. pub(super) poll: unsafe fn(NonNull<Header>), + /// Schedules the task for execution on the runtime. + pub(super) schedule: unsafe fn(NonNull<Header>), + /// Deallocates the memory. pub(super) dealloc: unsafe fn(NonNull<Header>), @@ -22,32 +26,142 @@ pub(super) struct Vtable { /// The join handle has been dropped. pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>), - /// The task is remotely aborted. - pub(super) remote_abort: unsafe fn(NonNull<Header>), + /// An abort handle has been dropped. + pub(super) drop_abort_handle: unsafe fn(NonNull<Header>), /// Scheduler is being shutdown. pub(super) shutdown: unsafe fn(NonNull<Header>), + + /// The number of bytes that the `trailer` field is offset from the header. + pub(super) trailer_offset: usize, + + /// The number of bytes that the `scheduler` field is offset from the header. + pub(super) scheduler_offset: usize, + + /// The number of bytes that the `id` field is offset from the header. + pub(super) id_offset: usize, } /// Get the vtable for the requested `T` and `S` generics. pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable { &Vtable { poll: poll::<T, S>, + schedule: schedule::<S>, dealloc: dealloc::<T, S>, try_read_output: try_read_output::<T, S>, drop_join_handle_slow: drop_join_handle_slow::<T, S>, - remote_abort: remote_abort::<T, S>, + drop_abort_handle: drop_abort_handle::<T, S>, shutdown: shutdown::<T, S>, + trailer_offset: OffsetHelper::<T, S>::TRAILER_OFFSET, + scheduler_offset: OffsetHelper::<T, S>::SCHEDULER_OFFSET, + id_offset: OffsetHelper::<T, S>::ID_OFFSET, } } +/// Calling `get_trailer_offset` directly in vtable doesn't work because it +/// prevents the vtable from being promoted to a static reference. +/// +/// See this thread for more info: +/// <https://users.rust-lang.org/t/custom-vtables-with-integers/78508> +struct OffsetHelper<T, S>(T, S); +impl<T: Future, S: Schedule> OffsetHelper<T, S> { + // Pass `size_of`/`align_of` as arguments rather than calling them directly + // inside `get_trailer_offset` because trait bounds on generic parameters + // of const fn are unstable on our MSRV. + const TRAILER_OFFSET: usize = get_trailer_offset( + std::mem::size_of::<Header>(), + std::mem::size_of::<Core<T, S>>(), + std::mem::align_of::<Core<T, S>>(), + std::mem::align_of::<Trailer>(), + ); + + // The `scheduler` is the first field of `Core`, so it has the same + // offset as `Core`. + const SCHEDULER_OFFSET: usize = get_core_offset( + std::mem::size_of::<Header>(), + std::mem::align_of::<Core<T, S>>(), + ); + + const ID_OFFSET: usize = get_id_offset( + std::mem::size_of::<Header>(), + std::mem::align_of::<Core<T, S>>(), + std::mem::size_of::<S>(), + std::mem::align_of::<Id>(), + ); +} + +/// Compute the offset of the `Trailer` field in `Cell<T, S>` using the +/// `#[repr(C)]` algorithm. +/// +/// Pseudo-code for the `#[repr(C)]` algorithm can be found here: +/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs> +const fn get_trailer_offset( + header_size: usize, + core_size: usize, + core_align: usize, + trailer_align: usize, +) -> usize { + let mut offset = header_size; + + let core_misalign = offset % core_align; + if core_misalign > 0 { + offset += core_align - core_misalign; + } + offset += core_size; + + let trailer_misalign = offset % trailer_align; + if trailer_misalign > 0 { + offset += trailer_align - trailer_misalign; + } + + offset +} + +/// Compute the offset of the `Core<T, S>` field in `Cell<T, S>` using the +/// `#[repr(C)]` algorithm. +/// +/// Pseudo-code for the `#[repr(C)]` algorithm can be found here: +/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs> +const fn get_core_offset(header_size: usize, core_align: usize) -> usize { + let mut offset = header_size; + + let core_misalign = offset % core_align; + if core_misalign > 0 { + offset += core_align - core_misalign; + } + + offset +} + +/// Compute the offset of the `Id` field in `Cell<T, S>` using the +/// `#[repr(C)]` algorithm. +/// +/// Pseudo-code for the `#[repr(C)]` algorithm can be found here: +/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs> +const fn get_id_offset( + header_size: usize, + core_align: usize, + scheduler_size: usize, + id_align: usize, +) -> usize { + let mut offset = get_core_offset(header_size, core_align); + offset += scheduler_size; + + let id_misalign = offset % id_align; + if id_misalign > 0 { + offset += id_align - id_misalign; + } + + offset +} + impl RawTask { - pub(super) fn new<T, S>(task: T, scheduler: S) -> RawTask + pub(super) fn new<T, S>(task: T, scheduler: S, id: Id) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id)); let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; RawTask { ptr } @@ -57,19 +171,40 @@ impl RawTask { RawTask { ptr } } - /// Returns a reference to the task's meta structure. - /// - /// Safe as `Header` is `Sync`. + pub(super) fn header_ptr(&self) -> NonNull<Header> { + self.ptr + } + + pub(super) fn trailer_ptr(&self) -> NonNull<Trailer> { + unsafe { Header::get_trailer(self.ptr) } + } + + /// Returns a reference to the task's header. pub(super) fn header(&self) -> &Header { unsafe { self.ptr.as_ref() } } + /// Returns a reference to the task's trailer. + pub(super) fn trailer(&self) -> &Trailer { + unsafe { &*self.trailer_ptr().as_ptr() } + } + + /// Returns a reference to the task's state. + pub(super) fn state(&self) -> &State { + &self.header().state + } + /// Safety: mutual exclusion is required to call this function. pub(super) fn poll(self) { let vtable = self.header().vtable; unsafe { (vtable.poll)(self.ptr) } } + pub(super) fn schedule(self) { + let vtable = self.header().vtable; + unsafe { (vtable.schedule)(self.ptr) } + } + pub(super) fn dealloc(self) { let vtable = self.header().vtable; unsafe { @@ -89,14 +224,21 @@ impl RawTask { unsafe { (vtable.drop_join_handle_slow)(self.ptr) } } + pub(super) fn drop_abort_handle(self) { + let vtable = self.header().vtable; + unsafe { (vtable.drop_abort_handle)(self.ptr) } + } + pub(super) fn shutdown(self) { let vtable = self.header().vtable; unsafe { (vtable.shutdown)(self.ptr) } } - pub(super) fn remote_abort(self) { - let vtable = self.header().vtable; - unsafe { (vtable.remote_abort)(self.ptr) } + /// Increment the task's reference count. + /// + /// Currently, this is used only when creating an `AbortHandle`. + pub(super) fn ref_inc(self) { + self.header().state.ref_inc(); } } @@ -113,6 +255,15 @@ unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) { harness.poll(); } +unsafe fn schedule<S: Schedule>(ptr: NonNull<Header>) { + use crate::runtime::task::{Notified, Task}; + + let scheduler = Header::get_scheduler::<S>(ptr); + scheduler + .as_ref() + .schedule(Notified(Task::from_raw(ptr.cast()))); +} + unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) { let harness = Harness::<T, S>::from_raw(ptr); harness.dealloc(); @@ -134,9 +285,9 @@ unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) { harness.drop_join_handle_slow() } -unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) { +unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) { let harness = Harness::<T, S>::from_raw(ptr); - harness.remote_abort() + harness.drop_reference(); } unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) { diff --git a/src/runtime/task/waker.rs b/src/runtime/task/waker.rs index b7313b4..b5f5ace 100644 --- a/src/runtime/task/waker.rs +++ b/src/runtime/task/waker.rs @@ -1,6 +1,5 @@ use crate::future::Future; -use crate::runtime::task::harness::Harness; -use crate::runtime::task::{Header, Schedule}; +use crate::runtime::task::{Header, RawTask, Schedule}; use std::marker::PhantomData; use std::mem::ManuallyDrop; @@ -13,9 +12,9 @@ pub(super) struct WakerRef<'a, S: 'static> { _p: PhantomData<(&'a Header, S)>, } -/// Returns a `WakerRef` which avoids having to pre-emptively increase the +/// Returns a `WakerRef` which avoids having to preemptively increase the /// refcount if there is no need to do so. -pub(super) fn waker_ref<T, S>(header: &Header) -> WakerRef<'_, S> +pub(super) fn waker_ref<T, S>(header: &NonNull<Header>) -> WakerRef<'_, S> where T: Future, S: Schedule, @@ -28,7 +27,7 @@ where // point and not an *owned* waker, we must ensure that `drop` is never // called on this waker instance. This is done by wrapping it with // `ManuallyDrop` and then never calling drop. - let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T, S>(header))) }; + let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker(*header))) }; WakerRef { waker, @@ -46,8 +45,8 @@ impl<S> ops::Deref for WakerRef<'_, S> { cfg_trace! { macro_rules! trace { - ($harness:expr, $op:expr) => { - if let Some(id) = $harness.id() { + ($header:expr, $op:expr) => { + if let Some(id) = Header::get_tracing_id(&$header) { tracing::trace!( target: "tokio::task::waker", op = $op, @@ -60,71 +59,46 @@ cfg_trace! { cfg_not_trace! { macro_rules! trace { - ($harness:expr, $op:expr) => { + ($header:expr, $op:expr) => { // noop - let _ = &$harness; + let _ = &$header; } } } -unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker -where - T: Future, - S: Schedule, -{ - let header = ptr as *const Header; - let ptr = NonNull::new_unchecked(ptr as *mut Header); - let harness = Harness::<T, S>::from_raw(ptr); - trace!(harness, "waker.clone"); - (*header).state.ref_inc(); - raw_waker::<T, S>(header) +unsafe fn clone_waker(ptr: *const ()) -> RawWaker { + let header = NonNull::new_unchecked(ptr as *mut Header); + trace!(header, "waker.clone"); + header.as_ref().state.ref_inc(); + raw_waker(header) } -unsafe fn drop_waker<T, S>(ptr: *const ()) -where - T: Future, - S: Schedule, -{ +unsafe fn drop_waker(ptr: *const ()) { let ptr = NonNull::new_unchecked(ptr as *mut Header); - let harness = Harness::<T, S>::from_raw(ptr); - trace!(harness, "waker.drop"); - harness.drop_reference(); + trace!(ptr, "waker.drop"); + let raw = RawTask::from_raw(ptr); + raw.drop_reference(); } -unsafe fn wake_by_val<T, S>(ptr: *const ()) -where - T: Future, - S: Schedule, -{ +unsafe fn wake_by_val(ptr: *const ()) { let ptr = NonNull::new_unchecked(ptr as *mut Header); - let harness = Harness::<T, S>::from_raw(ptr); - trace!(harness, "waker.wake"); - harness.wake_by_val(); + trace!(ptr, "waker.wake"); + let raw = RawTask::from_raw(ptr); + raw.wake_by_val(); } // Wake without consuming the waker -unsafe fn wake_by_ref<T, S>(ptr: *const ()) -where - T: Future, - S: Schedule, -{ +unsafe fn wake_by_ref(ptr: *const ()) { let ptr = NonNull::new_unchecked(ptr as *mut Header); - let harness = Harness::<T, S>::from_raw(ptr); - trace!(harness, "waker.wake_by_ref"); - harness.wake_by_ref(); + trace!(ptr, "waker.wake_by_ref"); + let raw = RawTask::from_raw(ptr); + raw.wake_by_ref(); } -fn raw_waker<T, S>(header: *const Header) -> RawWaker -where - T: Future, - S: Schedule, -{ - let ptr = header as *const (); - let vtable = &RawWakerVTable::new( - clone_waker::<T, S>, - wake_by_val::<T, S>, - wake_by_ref::<T, S>, - drop_waker::<T, S>, - ); - RawWaker::new(ptr, vtable) +static WAKER_VTABLE: RawWakerVTable = + RawWakerVTable::new(clone_waker, wake_by_val, wake_by_ref, drop_waker); + +fn raw_waker(header: NonNull<Header>) -> RawWaker { + let ptr = header.as_ptr() as *const (); + RawWaker::new(ptr, &WAKER_VTABLE) } |