diff options
author | Joel Galenson <jgalenson@google.com> | 2020-10-23 09:39:31 -0700 |
---|---|---|
committer | Joel Galenson <jgalenson@google.com> | 2020-10-23 09:52:09 -0700 |
commit | d5495b03381a3ebe0805db353d198b285b535b5c (patch) | |
tree | 778b8524d15fca8b73db0253ee0e1919d0848bb6 /src | |
parent | ba45c5bedf31df8562364c61d3dfb5262f10642e (diff) | |
download | tokio-d5495b03381a3ebe0805db353d198b285b535b5c.tar.gz |
Update to tokio-0.3.1 and add new features
Test: Build
Change-Id: I5b5b9b386a21982a019653d0cf0bd3afc505cfac
Diffstat (limited to 'src')
228 files changed, 11322 insertions, 13218 deletions
diff --git a/src/blocking.rs b/src/blocking.rs new file mode 100644 index 0000000..f88b1db --- /dev/null +++ b/src/blocking.rs @@ -0,0 +1,48 @@ +cfg_rt! { + pub(crate) use crate::runtime::spawn_blocking; + pub(crate) use crate::task::JoinHandle; +} + +cfg_not_rt! { + use std::fmt; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + pub(crate) fn spawn_blocking<F, R>(_f: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + assert_send_sync::<JoinHandle<std::cell::Cell<()>>>(); + panic!("requires the `rt` Tokio feature flag") + + } + + pub(crate) struct JoinHandle<R> { + _p: std::marker::PhantomData<R>, + } + + unsafe impl<T: Send> Send for JoinHandle<T> {} + unsafe impl<T: Send> Sync for JoinHandle<T> {} + + impl<R> Future for JoinHandle<R> { + type Output = Result<R, std::io::Error>; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { + unreachable!() + } + } + + impl<T> fmt::Debug for JoinHandle<T> + where + T: fmt::Debug, + { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("JoinHandle").finish() + } + } + + fn assert_send_sync<T: Send + Sync>() { + } +} diff --git a/src/coop.rs b/src/coop.rs index 27e969c..980cdf8 100644 --- a/src/coop.rs +++ b/src/coop.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + //! Opt-in yield points for improved cooperative scheduling. //! //! A single call to [`poll`] on a top-level task may potentially do a lot of @@ -81,7 +83,7 @@ impl Budget { } } -cfg_rt_threaded! { +cfg_rt_multi_thread! { impl Budget { fn has_remaining(self) -> bool { self.0.map(|budget| budget > 0).unwrap_or(true) @@ -96,14 +98,6 @@ pub(crate) fn budget<R>(f: impl FnOnce() -> R) -> R { with_budget(Budget::initial(), f) } -cfg_rt_threaded! { - /// Set the current task's budget - #[cfg(feature = "blocking")] - pub(crate) fn set(budget: Budget) { - CURRENT.with(|cell| cell.set(budget)) - } -} - #[inline(always)] fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R { struct ResetGuard<'a> { @@ -128,14 +122,19 @@ fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R { }) } -cfg_rt_threaded! { +cfg_rt_multi_thread! { + /// Set the current task's budget + pub(crate) fn set(budget: Budget) { + CURRENT.with(|cell| cell.set(budget)) + } + #[inline(always)] pub(crate) fn has_budget_remaining() -> bool { CURRENT.with(|cell| cell.get().has_remaining()) } } -cfg_blocking_impl! { +cfg_rt! { /// Forcibly remove the budgeting constraints early. /// /// Returns the remaining budget diff --git a/src/fs/file.rs b/src/fs/file.rs index f3bc985..7c71f48 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -5,7 +5,8 @@ use self::State::*; use crate::fs::{asyncify, sys}; use crate::io::blocking::Buf; -use crate::io::{AsyncRead, AsyncSeek, AsyncWrite}; +use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; +use crate::sync::Mutex; use std::fmt; use std::fs::{Metadata, Permissions}; @@ -80,12 +81,18 @@ use std::task::Poll::*; /// ``` pub struct File { std: Arc<sys::File>, + inner: Mutex<Inner>, +} + +struct Inner { state: State, /// Errors from writes/flushes are returned in write/flush calls. If a write /// error is observed while performing a read, it is saved until the next /// write / flush call. last_write_err: Option<io::ErrorKind>, + + pos: u64, } #[derive(Debug)] @@ -197,70 +204,11 @@ impl File { pub fn from_std(std: sys::File) -> File { File { std: Arc::new(std), - state: State::Idle(Some(Buf::with_capacity(0))), - last_write_err: None, - } - } - - /// Seeks to an offset, in bytes, in a stream. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::fs::File; - /// use tokio::prelude::*; - /// - /// use std::io::SeekFrom; - /// - /// # async fn dox() -> std::io::Result<()> { - /// let mut file = File::open("foo.txt").await?; - /// file.seek(SeekFrom::Start(6)).await?; - /// - /// let mut contents = vec![0u8; 10]; - /// file.read_exact(&mut contents).await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// The [`read_exact`] method is defined on the [`AsyncReadExt`] trait. - /// - /// [`read_exact`]: fn@crate::io::AsyncReadExt::read_exact - /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt - pub async fn seek(&mut self, mut pos: SeekFrom) -> io::Result<u64> { - self.complete_inflight().await; - - let mut buf = match self.state { - Idle(ref mut buf_cell) => buf_cell.take().unwrap(), - _ => unreachable!(), - }; - - // Factor in any unread data from the buf - if !buf.is_empty() { - let n = buf.discard_read(); - - if let SeekFrom::Current(ref mut offset) = pos { - *offset += n; - } - } - - let std = self.std.clone(); - - // Start the operation - self.state = Busy(sys::run(move || { - let res = (&*std).seek(pos); - (Operation::Seek(res), buf) - })); - - let (op, buf) = match self.state { - Idle(_) => unreachable!(), - Busy(ref mut rx) => rx.await.unwrap(), - }; - - self.state = Idle(Some(buf)); - - match op { - Operation::Seek(res) => res, - _ => unreachable!(), + inner: Mutex::new(Inner { + state: State::Idle(Some(Buf::with_capacity(0))), + last_write_err: None, + pos: 0, + }), } } @@ -287,8 +235,9 @@ impl File { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - pub async fn sync_all(&mut self) -> io::Result<()> { - self.complete_inflight().await; + pub async fn sync_all(&self) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; let std = self.std.clone(); asyncify(move || std.sync_all()).await @@ -321,8 +270,9 @@ impl File { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - pub async fn sync_data(&mut self) -> io::Result<()> { - self.complete_inflight().await; + pub async fn sync_data(&self) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; let std = self.std.clone(); asyncify(move || std.sync_data()).await @@ -358,10 +308,11 @@ impl File { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - pub async fn set_len(&mut self, size: u64) -> io::Result<()> { - self.complete_inflight().await; + pub async fn set_len(&self, size: u64) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; - let mut buf = match self.state { + let mut buf = match inner.state { Idle(ref mut buf_cell) => buf_cell.take().unwrap(), _ => unreachable!(), }; @@ -374,7 +325,7 @@ impl File { let std = self.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| std.set_len(size)) } else { @@ -386,15 +337,17 @@ impl File { (Operation::Seek(res), buf) })); - let (op, buf) = match self.state { + let (op, buf) = match inner.state { Idle(_) => unreachable!(), Busy(ref mut rx) => rx.await?, }; - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); match op { - Operation::Seek(res) => res.map(|_| ()), + Operation::Seek(res) => res.map(|pos| { + inner.pos = pos; + }), _ => unreachable!(), } } @@ -459,7 +412,7 @@ impl File { /// # } /// ``` pub async fn into_std(mut self) -> sys::File { - self.complete_inflight().await; + self.inner.get_mut().complete_inflight().await; Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed") } @@ -526,42 +479,32 @@ impl File { let std = self.std.clone(); asyncify(move || std.set_permissions(perm)).await } - - async fn complete_inflight(&mut self) { - use crate::future::poll_fn; - - if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await { - self.last_write_err = Some(e.kind()); - } - } } impl AsyncRead for File { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668 - false - } - fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, - dst: &mut [u8], - ) -> Poll<io::Result<usize>> { + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + loop { - match self.state { + match inner.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); if !buf.is_empty() { - let n = buf.copy_to(dst); + buf.copy_to(dst); *buf_cell = Some(buf); - return Ready(Ok(n)); + return Ready(Ok(())); } buf.ensure_capacity_for(dst); - let std = self.std.clone(); + let std = me.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = buf.read_from(&mut &*std); (Operation::Read(res), buf) })); @@ -571,29 +514,32 @@ impl AsyncRead for File { match op { Operation::Read(Ok(_)) => { - let n = buf.copy_to(dst); - self.state = Idle(Some(buf)); - return Ready(Ok(n)); + buf.copy_to(dst); + inner.state = Idle(Some(buf)); + return Ready(Ok(())); } Operation::Read(Err(e)) => { assert!(buf.is_empty()); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); return Ready(Err(e)); } Operation::Write(Ok(_)) => { assert!(buf.is_empty()); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); continue; } Operation::Write(Err(e)) => { - assert!(self.last_write_err.is_none()); - self.last_write_err = Some(e.kind()); - self.state = Idle(Some(buf)); + assert!(inner.last_write_err.is_none()); + inner.last_write_err = Some(e.kind()); + inner.state = Idle(Some(buf)); } - Operation::Seek(_) => { + Operation::Seek(result) => { assert!(buf.is_empty()); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); + if let Ok(pos) = result { + inner.pos = pos; + } continue; } } @@ -604,13 +550,13 @@ impl AsyncRead for File { } impl AsyncSeek for File { - fn start_seek( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut pos: SeekFrom, - ) -> Poll<io::Result<()>> { + fn start_seek(self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + loop { - match self.state { + match inner.state { + Busy(_) => panic!("must wait for poll_complete before calling start_seek"), Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); @@ -623,49 +569,41 @@ impl AsyncSeek for File { } } - let std = self.std.clone(); + let std = me.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = (&*std).seek(pos); (Operation::Seek(res), buf) })); - - return Ready(Ok(())); - } - Busy(ref mut rx) => { - let (op, buf) = ready!(Pin::new(rx).poll(cx))?; - self.state = Idle(Some(buf)); - - match op { - Operation::Read(_) => {} - Operation::Write(Err(e)) => { - assert!(self.last_write_err.is_none()); - self.last_write_err = Some(e.kind()); - } - Operation::Write(_) => {} - Operation::Seek(_) => {} - } + return Ok(()); } } } } fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let inner = self.inner.get_mut(); + loop { - match self.state { - Idle(_) => panic!("must call start_seek before calling poll_complete"), + match inner.state { + Idle(_) => return Poll::Ready(Ok(inner.pos)), Busy(ref mut rx) => { let (op, buf) = ready!(Pin::new(rx).poll(cx))?; - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); match op { Operation::Read(_) => {} Operation::Write(Err(e)) => { - assert!(self.last_write_err.is_none()); - self.last_write_err = Some(e.kind()); + assert!(inner.last_write_err.is_none()); + inner.last_write_err = Some(e.kind()); } Operation::Write(_) => {} - Operation::Seek(res) => return Ready(res), + Operation::Seek(res) => { + if let Ok(pos) = res { + inner.pos = pos; + } + return Ready(res); + } } } } @@ -675,16 +613,19 @@ impl AsyncSeek for File { impl AsyncWrite for File { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, src: &[u8], ) -> Poll<io::Result<usize>> { - if let Some(e) = self.last_write_err.take() { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + + if let Some(e) = inner.last_write_err.take() { return Ready(Err(e.into())); } loop { - match self.state { + match inner.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); @@ -695,9 +636,9 @@ impl AsyncWrite for File { }; let n = buf.copy_from(src); - let std = self.std.clone(); + let std = me.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) } else { @@ -711,7 +652,7 @@ impl AsyncWrite for File { } Busy(ref mut rx) => { let (op, buf) = ready!(Pin::new(rx).poll(cx))?; - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); match op { Operation::Read(_) => { @@ -737,27 +678,12 @@ impl AsyncWrite for File { } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { - if let Some(e) = self.last_write_err.take() { - return Ready(Err(e.into())); - } - - let (op, buf) = match self.state { - Idle(_) => return Ready(Ok(())), - Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?, - }; - - // The buffer is not used here - self.state = Idle(Some(buf)); - - match op { - Operation::Read(_) => Ready(Ok(())), - Operation::Write(res) => Ready(res), - Operation::Seek(_) => Ready(Ok(())), - } + let inner = self.inner.get_mut(); + inner.poll_flush(cx) } - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { - Poll::Ready(Ok(())) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + self.poll_flush(cx) } } @@ -782,9 +708,53 @@ impl std::os::unix::io::AsRawFd for File { } } +#[cfg(unix)] +impl std::os::unix::io::FromRawFd for File { + unsafe fn from_raw_fd(fd: std::os::unix::io::RawFd) -> Self { + sys::File::from_raw_fd(fd).into() + } +} + #[cfg(windows)] impl std::os::windows::io::AsRawHandle for File { fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { self.std.as_raw_handle() } } + +#[cfg(windows)] +impl std::os::windows::io::FromRawHandle for File { + unsafe fn from_raw_handle(handle: std::os::windows::io::RawHandle) -> Self { + sys::File::from_raw_handle(handle).into() + } +} + +impl Inner { + async fn complete_inflight(&mut self) { + use crate::future::poll_fn; + + if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await { + self.last_write_err = Some(e.kind()); + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + if let Some(e) = self.last_write_err.take() { + return Ready(Err(e.into())); + } + + let (op, buf) = match self.state { + Idle(_) => return Ready(Ok(())), + Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?, + }; + + // The buffer is not used here + self.state = Idle(Some(buf)); + + match op { + Operation::Read(_) => Ready(Ok(())), + Operation::Write(res) => Ready(res), + Operation::Seek(_) => Ready(Ok(())), + } + } +} diff --git a/src/fs/mod.rs b/src/fs/mod.rs index a2b062b..d2757a5 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -107,6 +107,6 @@ mod sys { pub(crate) use std::fs::File; // TODO: don't rename - pub(crate) use crate::runtime::spawn_blocking as run; - pub(crate) use crate::task::JoinHandle as Blocking; + pub(crate) use crate::blocking::spawn_blocking as run; + pub(crate) use crate::blocking::JoinHandle as Blocking; } diff --git a/src/fs/open_options.rs b/src/fs/open_options.rs index ba3d9a6..acd99a1 100644 --- a/src/fs/open_options.rs +++ b/src/fs/open_options.rs @@ -383,8 +383,7 @@ impl OpenOptions { Ok(File::from_std(std)) } - /// Returns a mutable reference to the the underlying std::fs::OpenOptions - #[cfg(unix)] + /// Returns a mutable reference to the underlying `std::fs::OpenOptions` pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions { &mut self.0 } diff --git a/src/fs/os/unix/dir_builder_ext.rs b/src/fs/os/unix/dir_builder_ext.rs index e9a25b9..ccdc552 100644 --- a/src/fs/os/unix/dir_builder_ext.rs +++ b/src/fs/os/unix/dir_builder_ext.rs @@ -3,7 +3,7 @@ use crate::fs::dir_builder::DirBuilder; /// Unix-specific extensions to [`DirBuilder`]. /// /// [`DirBuilder`]: crate::fs::DirBuilder -pub trait DirBuilderExt { +pub trait DirBuilderExt: sealed::Sealed { /// Sets the mode to create new directories with. /// /// This option defaults to 0o777. @@ -27,3 +27,10 @@ impl DirBuilderExt for DirBuilder { self } } + +impl sealed::Sealed for DirBuilder {} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait Sealed {} +} diff --git a/src/fs/os/unix/dir_entry_ext.rs b/src/fs/os/unix/dir_entry_ext.rs new file mode 100644 index 0000000..2ac56da --- /dev/null +++ b/src/fs/os/unix/dir_entry_ext.rs @@ -0,0 +1,44 @@ +use crate::fs::DirEntry; +use std::os::unix::fs::DirEntryExt as _; + +/// Unix-specific extension methods for [`fs::DirEntry`]. +/// +/// This mirrors the definition of [`std::os::unix::fs::DirEntryExt`]. +/// +/// [`fs::DirEntry`]: crate::fs::DirEntry +/// [`std::os::unix::fs::DirEntryExt`]: std::os::unix::fs::DirEntryExt +pub trait DirEntryExt: sealed::Sealed { + /// Returns the underlying `d_ino` field in the contained `dirent` + /// structure. + /// + /// # Examples + /// + /// ``` + /// use tokio::fs; + /// use tokio::fs::os::unix::DirEntryExt; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// while let Some(entry) = entries.next_entry().await? { + /// // Here, `entry` is a `DirEntry`. + /// println!("{:?}: {}", entry.file_name(), entry.ino()); + /// } + /// # Ok(()) + /// # } + /// ``` + fn ino(&self) -> u64; +} + +impl DirEntryExt for DirEntry { + fn ino(&self) -> u64 { + self.as_inner().ino() + } +} + +impl sealed::Sealed for DirEntry {} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait Sealed {} +} diff --git a/src/fs/os/unix/mod.rs b/src/fs/os/unix/mod.rs index 826222e..a0ae751 100644 --- a/src/fs/os/unix/mod.rs +++ b/src/fs/os/unix/mod.rs @@ -8,3 +8,6 @@ pub use self::open_options_ext::OpenOptionsExt; mod dir_builder_ext; pub use self::dir_builder_ext::DirBuilderExt; + +mod dir_entry_ext; +pub use self::dir_entry_ext::DirEntryExt; diff --git a/src/fs/os/unix/open_options_ext.rs b/src/fs/os/unix/open_options_ext.rs index ff89275..6e0fd2b 100644 --- a/src/fs/os/unix/open_options_ext.rs +++ b/src/fs/os/unix/open_options_ext.rs @@ -1,14 +1,13 @@ use crate::fs::open_options::OpenOptions; -use std::os::unix::fs::OpenOptionsExt as StdOpenOptionsExt; +use std::os::unix::fs::OpenOptionsExt as _; /// Unix-specific extensions to [`fs::OpenOptions`]. /// /// This mirrors the definition of [`std::os::unix::fs::OpenOptionsExt`]. /// -/// /// [`fs::OpenOptions`]: crate::fs::OpenOptions /// [`std::os::unix::fs::OpenOptionsExt`]: std::os::unix::fs::OpenOptionsExt -pub trait OpenOptionsExt { +pub trait OpenOptionsExt: sealed::Sealed { /// Sets the mode bits that a new file will be created with. /// /// If a new file is created as part of an `OpenOptions::open` call then this @@ -77,3 +76,10 @@ impl OpenOptionsExt for OpenOptions { self } } + +impl sealed::Sealed for OpenOptions {} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait Sealed {} +} diff --git a/src/fs/os/windows/mod.rs b/src/fs/os/windows/mod.rs index 42eb7bd..ab98c13 100644 --- a/src/fs/os/windows/mod.rs +++ b/src/fs/os/windows/mod.rs @@ -5,3 +5,6 @@ pub use self::symlink_dir::symlink_dir; mod symlink_file; pub use self::symlink_file::symlink_file; + +mod open_options_ext; +pub use self::open_options_ext::OpenOptionsExt; diff --git a/src/fs/os/windows/open_options_ext.rs b/src/fs/os/windows/open_options_ext.rs new file mode 100644 index 0000000..ce86fba --- /dev/null +++ b/src/fs/os/windows/open_options_ext.rs @@ -0,0 +1,214 @@ +use crate::fs::open_options::OpenOptions; +use std::os::windows::fs::OpenOptionsExt as _; + +/// Unix-specific extensions to [`fs::OpenOptions`]. +/// +/// This mirrors the definition of [`std::os::windows::fs::OpenOptionsExt`]. +/// +/// [`fs::OpenOptions`]: crate::fs::OpenOptions +/// [`std::os::windows::fs::OpenOptionsExt`]: std::os::windows::fs::OpenOptionsExt +pub trait OpenOptionsExt: sealed::Sealed { + /// Overrides the `dwDesiredAccess` argument to the call to [`CreateFile`] + /// with the specified value. + /// + /// This will override the `read`, `write`, and `append` flags on the + /// `OpenOptions` structure. This method provides fine-grained control over + /// the permissions to read, write and append data, attributes (like hidden + /// and system), and extended attributes. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::windows::OpenOptionsExt; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// // Open without read and write permission, for example if you only need + /// // to call `stat` on the file + /// let file = OpenOptions::new().access_mode(0).open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + fn access_mode(&mut self, access: u32) -> &mut Self; + + /// Overrides the `dwShareMode` argument to the call to [`CreateFile`] with + /// the specified value. + /// + /// By default `share_mode` is set to + /// `FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE`. This allows + /// other processes to read, write, and delete/rename the same file + /// while it is open. Removing any of the flags will prevent other + /// processes from performing the corresponding operation until the file + /// handle is closed. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::windows::OpenOptionsExt; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// // Do not allow others to read or modify this file while we have it open + /// // for writing. + /// let file = OpenOptions::new() + /// .write(true) + /// .share_mode(0) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + fn share_mode(&mut self, val: u32) -> &mut Self; + + /// Sets extra flags for the `dwFileFlags` argument to the call to + /// [`CreateFile2`] to the specified value (or combines it with + /// `attributes` and `security_qos_flags` to set the `dwFlagsAndAttributes` + /// for [`CreateFile`]). + /// + /// Custom flags can only set flags, not remove flags set by Rust's options. + /// This option overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winbase::FILE_FLAG_DELETE_ON_CLOSE; + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::windows::OpenOptionsExt; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .create(true) + /// .write(true) + /// .custom_flags(FILE_FLAG_DELETE_ON_CLOSE) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + fn custom_flags(&mut self, flags: u32) -> &mut Self; + + /// Sets the `dwFileAttributes` argument to the call to [`CreateFile2`] to + /// the specified value (or combines it with `custom_flags` and + /// `security_qos_flags` to set the `dwFlagsAndAttributes` for + /// [`CreateFile`]). + /// + /// If a _new_ file is created because it does not yet exist and + /// `.create(true)` or `.create_new(true)` are specified, the new file is + /// given the attributes declared with `.attributes()`. + /// + /// If an _existing_ file is opened with `.create(true).truncate(true)`, its + /// existing attributes are preserved and combined with the ones declared + /// with `.attributes()`. + /// + /// In all other cases the attributes get ignored. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winnt::FILE_ATTRIBUTE_HIDDEN; + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::windows::OpenOptionsExt; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// .attributes(FILE_ATTRIBUTE_HIDDEN) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + fn attributes(&mut self, val: u32) -> &mut Self; + + /// Sets the `dwSecurityQosFlags` argument to the call to [`CreateFile2`] to + /// the specified value (or combines it with `custom_flags` and `attributes` + /// to set the `dwFlagsAndAttributes` for [`CreateFile`]). + /// + /// By default `security_qos_flags` is not set. It should be specified when + /// opening a named pipe, to control to which degree a server process can + /// act on behalf of a client process (security impersonation level). + /// + /// When `security_qos_flags` is not set, a malicious program can gain the + /// elevated privileges of a privileged Rust process when it allows opening + /// user-specified paths, by tricking it into opening a named pipe. So + /// arguably `security_qos_flags` should also be set when opening arbitrary + /// paths. However the bits can then conflict with other flags, specifically + /// `FILE_FLAG_OPEN_NO_RECALL`. + /// + /// For information about possible values, see [Impersonation Levels] on the + /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set + /// automatically when using this method. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winbase::SECURITY_IDENTIFICATION; + /// use tokio::fs::OpenOptions; + /// use tokio::fs::os::windows::OpenOptionsExt; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// + /// // Sets the flag value to `SecurityIdentification`. + /// .security_qos_flags(SECURITY_IDENTIFICATION) + /// + /// .open(r"\\.\pipe\MyPipe").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + /// [Impersonation Levels]: + /// https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level + fn security_qos_flags(&mut self, flags: u32) -> &mut Self; +} + +impl OpenOptionsExt for OpenOptions { + fn access_mode(&mut self, access: u32) -> &mut OpenOptions { + self.as_inner_mut().access_mode(access); + self + } + + fn share_mode(&mut self, share: u32) -> &mut OpenOptions { + self.as_inner_mut().share_mode(share); + self + } + + fn custom_flags(&mut self, flags: u32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } + + fn attributes(&mut self, attributes: u32) -> &mut OpenOptions { + self.as_inner_mut().attributes(attributes); + self + } + + fn security_qos_flags(&mut self, flags: u32) -> &mut OpenOptions { + self.as_inner_mut().security_qos_flags(flags); + self + } +} + +impl sealed::Sealed for OpenOptions {} + +pub(crate) mod sealed { + #[doc(hidden)] + pub trait Sealed {} +} diff --git a/src/fs/read_dir.rs b/src/fs/read_dir.rs index f9b16c6..8ca583b 100644 --- a/src/fs/read_dir.rs +++ b/src/fs/read_dir.rs @@ -4,8 +4,6 @@ use std::ffi::OsString; use std::fs::{FileType, Metadata}; use std::future::Future; use std::io; -#[cfg(unix)] -use std::os::unix::fs::DirEntryExt; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; @@ -55,8 +53,7 @@ impl ReadDir { poll_fn(|cx| self.poll_next_entry(cx)).await } - #[doc(hidden)] - pub fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> { + fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> { loop { match self.0 { State::Idle(ref mut std) => { @@ -234,11 +231,10 @@ impl DirEntry { let std = self.0.clone(); asyncify(move || std.file_type()).await } -} -#[cfg(unix)] -impl DirEntryExt for DirEntry { - fn ino(&self) -> u64 { - self.0.ino() + /// Returns a reference to the underlying `std::fs::DirEntry` + #[cfg(unix)] + pub(super) fn as_inner(&self) -> &std::fs::DirEntry { + &self.0 } } diff --git a/src/future/block_on.rs b/src/future/block_on.rs new file mode 100644 index 0000000..91f9cc0 --- /dev/null +++ b/src/future/block_on.rs @@ -0,0 +1,15 @@ +use std::future::Future; + +cfg_rt! { + pub(crate) fn block_on<F: Future>(f: F) -> F::Output { + let mut e = crate::runtime::enter::enter(false); + e.block_on(f).unwrap() + } +} + +cfg_not_rt! { + pub(crate) fn block_on<F: Future>(f: F) -> F::Output { + let mut park = crate::park::thread::CachedParkThread::new(); + park.block_on(f).unwrap() + } +} diff --git a/src/future/mod.rs b/src/future/mod.rs index 770753f..f7d93c9 100644 --- a/src/future/mod.rs +++ b/src/future/mod.rs @@ -1,15 +1,24 @@ -#![allow(unused_imports, dead_code)] +#![cfg_attr(not(feature = "macros"), allow(unreachable_pub))] //! Asynchronous values. -mod maybe_done; -pub use maybe_done::{maybe_done, MaybeDone}; +#[cfg(any(feature = "macros", feature = "process"))] +pub(crate) mod maybe_done; mod poll_fn; pub use poll_fn::poll_fn; -mod ready; -pub(crate) use ready::{ok, Ready}; +cfg_not_loom! { + mod ready; + pub(crate) use ready::{ok, Ready}; +} -mod try_join; -pub(crate) use try_join::try_join3; +cfg_process! { + mod try_join; + pub(crate) use try_join::try_join3; +} + +cfg_sync! { + mod block_on; + pub(crate) use block_on::block_on; +} diff --git a/src/future/pending.rs b/src/future/pending.rs deleted file mode 100644 index 287e836..0000000 --- a/src/future/pending.rs +++ /dev/null @@ -1,44 +0,0 @@ -use sdt::pin::Pin; -use std::future::Future; -use std::marker; -use std::task::{Context, Poll}; - -/// Future for the [`pending()`] function. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -struct Pending<T> { - _data: marker::PhantomData<T>, -} - -/// Creates a future which never resolves, representing a computation that never -/// finishes. -/// -/// The returned future will forever return [`Poll::Pending`]. -/// -/// # Examples -/// -/// ```no_run -/// use tokio::future; -/// -/// #[tokio::main] -/// async fn main { -/// future::pending().await; -/// unreachable!(); -/// } -/// ``` -pub async fn pending() -> ! { - Pending { - _data: marker::PhantomData, - } - .await -} - -impl<T> Future for Pending<T> { - type Output = !; - - fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<T> { - Poll::Pending - } -} - -impl<T> Unpin for Pending<T> {} diff --git a/src/future/poll_fn.rs b/src/future/poll_fn.rs index 9b3d137..0169bd5 100644 --- a/src/future/poll_fn.rs +++ b/src/future/poll_fn.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + //! Definition of the `PollFn` adapter combinator use std::fmt; diff --git a/src/future/try_join.rs b/src/future/try_join.rs index 5bd80dc..8943f61 100644 --- a/src/future/try_join.rs +++ b/src/future/try_join.rs @@ -1,4 +1,4 @@ -use crate::future::{maybe_done, MaybeDone}; +use crate::future::maybe_done::{maybe_done, MaybeDone}; use pin_project_lite::pin_project; use std::future::Future; diff --git a/src/io/async_read.rs b/src/io/async_read.rs index 1aef415..d075443 100644 --- a/src/io/async_read.rs +++ b/src/io/async_read.rs @@ -1,6 +1,5 @@ -use bytes::BufMut; +use super::ReadBuf; use std::io; -use std::mem::MaybeUninit; use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -16,9 +15,10 @@ use std::task::{Context, Poll}; /// Specifically, this means that the `poll_read` function will return one of /// the following: /// -/// * `Poll::Ready(Ok(n))` means that `n` bytes of data was immediately read -/// and placed into the output buffer, where `n` == 0 implies that EOF has -/// been reached. +/// * `Poll::Ready(Ok(()))` means that data was immediately read and placed into +/// the output buffer. The amount of data read can be determined by the +/// increase in the length of the slice returned by `ReadBuf::filled`. If the +/// difference is 0, EOF has been reached. /// /// * `Poll::Pending` means that no data was read into the buffer /// provided. The I/O object is not currently readable but may become readable @@ -41,110 +41,29 @@ use std::task::{Context, Poll}; /// [`Read::read`]: std::io::Read::read /// [`AsyncReadExt`]: crate::io::AsyncReadExt pub trait AsyncRead { - /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns - /// `true` if the supplied buffer was zeroed out. - /// - /// While it would be highly unusual, implementations of [`io::Read`] are - /// able to read data from the buffer passed as an argument. Because of - /// this, the buffer passed to [`io::Read`] must be initialized memory. In - /// situations where large numbers of buffers are used, constantly having to - /// zero out buffers can be expensive. - /// - /// This function does any necessary work to prepare an uninitialized buffer - /// to be safe to pass to `read`. If `read` guarantees to never attempt to - /// read data out of the supplied buffer, then `prepare_uninitialized_buffer` - /// doesn't need to do any work. - /// - /// If this function returns `true`, then the memory has been zeroed out. - /// This allows implementations of `AsyncRead` which are composed of - /// multiple subimplementations to efficiently implement - /// `prepare_uninitialized_buffer`. - /// - /// This function isn't actually `unsafe` to call but `unsafe` to implement. - /// The implementer must ensure that either the whole `buf` has been zeroed - /// or `poll_read_buf()` overwrites the buffer without reading it and returns - /// correct value. - /// - /// This function is called from [`poll_read_buf`]. - /// - /// # Safety - /// - /// Implementations that return `false` must never read from data slices - /// that they did not write to. - /// - /// [`io::Read`]: std::io::Read - /// [`poll_read_buf`]: method@Self::poll_read_buf - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - for x in buf { - *x = MaybeUninit::new(0); - } - - true - } - /// Attempts to read from the `AsyncRead` into `buf`. /// - /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. + /// On success, returns `Poll::Ready(Ok(()))` and fills `buf` with data + /// read. If no data was read (`buf.filled().is_empty()`) it implies that + /// EOF has been reached. /// - /// If no data is available for reading, the method returns - /// `Poll::Pending` and arranges for the current task (via - /// `cx.waker()`) to receive a notification when the object becomes - /// readable or is closed. + /// If no data is available for reading, the method returns `Poll::Pending` + /// and arranges for the current task (via `cx.waker()`) to receive a + /// notification when the object becomes readable or is closed. fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>>; - - /// Pulls some bytes from this source into the specified `BufMut`, returning - /// how many bytes were read. - /// - /// The `buf` provided will have bytes read into it and the internal cursor - /// will be advanced if any bytes were read. Note that this method typically - /// will not reallocate the buffer provided. - fn poll_read_buf<B: BufMut>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> - where - Self: Sized, - { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - unsafe { - let n = { - let b = buf.bytes_mut(); - - self.prepare_uninitialized_buffer(b); - - // Convert to `&mut [u8]` - let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]); - - let n = ready!(self.poll_read(cx, b))?; - assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold"); - n - }; - - buf.advance_mut(n); - Poll::Ready(Ok(n)) - } - } + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>>; } macro_rules! deref_async_read { () => { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - (**self).prepare_uninitialized_buffer(buf) - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { Pin::new(&mut **self).poll_read(cx, buf) } }; @@ -163,43 +82,50 @@ where P: DerefMut + Unpin, P::Target: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - (**self).prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.get_mut().as_mut().poll_read(cx, buf) } } impl AsyncRead for &[u8] { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(io::Read::read(self.get_mut(), buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let amt = std::cmp::min(self.len(), buf.remaining()); + let (a, b) = self.split_at(amt); + buf.put_slice(a); + *self = b; + Poll::Ready(Ok(())) } } impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(io::Read::read(self.get_mut(), buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let pos = self.position(); + let slice: &[u8] = (*self).get_ref().as_ref(); + + // The position could technically be out of bounds, so don't panic... + if pos > slice.len() as u64 { + return Poll::Ready(Ok(())); + } + + let start = pos as usize; + let amt = std::cmp::min(slice.len() - start, buf.remaining()); + // Add won't overflow because of pos check above. + let end = start + amt; + buf.put_slice(&slice[start..end]); + self.set_position(end as u64); + + Poll::Ready(Ok(())) } } diff --git a/src/io/async_seek.rs b/src/io/async_seek.rs index 32ed0a2..bd7a992 100644 --- a/src/io/async_seek.rs +++ b/src/io/async_seek.rs @@ -23,36 +23,33 @@ pub trait AsyncSeek { /// /// If this function returns successfully, then the job has been submitted. /// To find out when it completes, call `poll_complete`. - fn start_seek( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - position: SeekFrom, - ) -> Poll<io::Result<()>>; + /// + /// # Errors + /// + /// This function can return [`io::ErrorKind::Other`] in case there is + /// another seek in progress. To avoid this, it is advisable that any call + /// to `start_seek` is preceded by a call to `poll_complete` to ensure all + /// pending seeks have completed. + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()>; /// Waits for a seek operation to complete. /// /// If the seek operation completed successfully, /// this method returns the new position from the start of the stream. - /// That position can be used later with [`SeekFrom::Start`]. + /// That position can be used later with [`SeekFrom::Start`]. Repeatedly + /// calling this function without calling `start_seek` might return the + /// same result. /// /// # Errors /// /// Seeking to a negative offset is considered an error. - /// - /// # Panics - /// - /// Calling this method without calling `start_seek` first is an error. fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>; } macro_rules! deref_async_seek { () => { - fn start_seek( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - pos: SeekFrom, - ) -> Poll<io::Result<()>> { - Pin::new(&mut **self).start_seek(cx, pos) + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + Pin::new(&mut **self).start_seek(pos) } fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { @@ -74,12 +71,8 @@ where P: DerefMut + Unpin, P::Target: AsyncSeek, { - fn start_seek( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - pos: SeekFrom, - ) -> Poll<io::Result<()>> { - self.get_mut().as_mut().start_seek(cx, pos) + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.get_mut().as_mut().start_seek(pos) } fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { @@ -88,12 +81,8 @@ where } impl<T: AsRef<[u8]> + Unpin> AsyncSeek for io::Cursor<T> { - fn start_seek( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - pos: SeekFrom, - ) -> Poll<io::Result<()>> { - Poll::Ready(io::Seek::seek(&mut *self, pos).map(drop)) + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + io::Seek::seek(&mut *self, pos).map(drop) } fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> { Poll::Ready(Ok(self.get_mut().position())) diff --git a/src/io/async_write.rs b/src/io/async_write.rs index ecf7575..66ba4bf 100644 --- a/src/io/async_write.rs +++ b/src/io/async_write.rs @@ -1,4 +1,3 @@ -use bytes::Buf; use std::io; use std::ops::DerefMut; use std::pin::Pin; @@ -128,27 +127,6 @@ pub trait AsyncWrite { /// This function will panic if not called within the context of a future's /// task. fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>; - - /// Writes a `Buf` into this value, returning how many bytes were written. - /// - /// Note that this method will advance the `buf` provided automatically by - /// the number of bytes written. - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<Result<usize, io::Error>> - where - Self: Sized, - { - if !buf.has_remaining() { - return Poll::Ready(Ok(0)); - } - - let n = ready!(self.poll_write(cx, buf.bytes()))?; - buf.advance(n); - Poll::Ready(Ok(n)) - } } macro_rules! deref_async_write { diff --git a/src/io/blocking.rs b/src/io/blocking.rs index 2491039..430801e 100644 --- a/src/io/blocking.rs +++ b/src/io/blocking.rs @@ -1,5 +1,5 @@ use crate::io::sys; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::cmp; use std::future::Future; @@ -53,17 +53,17 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - dst: &mut [u8], - ) -> Poll<io::Result<usize>> { + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { loop { match self.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); if !buf.is_empty() { - let n = buf.copy_to(dst); + buf.copy_to(dst); *buf_cell = Some(buf); - return Ready(Ok(n)); + return Ready(Ok(())); } buf.ensure_capacity_for(dst); @@ -80,9 +80,9 @@ where match res { Ok(_) => { - let n = buf.copy_to(dst); + buf.copy_to(dst); self.state = Idle(Some(buf)); - return Ready(Ok(n)); + return Ready(Ok(())); } Err(e) => { assert!(buf.is_empty()); @@ -203,9 +203,9 @@ impl Buf { self.buf.len() - self.pos } - pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize { - let n = cmp::min(self.len(), dst.len()); - dst[..n].copy_from_slice(&self.bytes()[..n]); + pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize { + let n = cmp::min(self.len(), dst.remaining()); + dst.put_slice(&self.bytes()[..n]); self.pos += n; if self.pos == self.buf.len() { @@ -229,10 +229,10 @@ impl Buf { &self.buf[self.pos..] } - pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) { + pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) { assert!(self.is_empty()); - let len = cmp::min(bytes.len(), MAX_BUF); + let len = cmp::min(bytes.remaining(), MAX_BUF); if self.buf.len() < len { self.buf.reserve(len - self.buf.len()); diff --git a/src/io/driver/mod.rs b/src/io/driver/mod.rs index dbfb6e1..cd82b26 100644 --- a/src/io/driver/mod.rs +++ b/src/io/driver/mod.rs @@ -1,30 +1,38 @@ -pub(crate) mod platform; +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +mod ready; +use ready::Ready; mod scheduled_io; pub(crate) use scheduled_io::ScheduledIo; // pub(crate) for tests -use crate::loom::sync::atomic::AtomicUsize; use crate::park::{Park, Unpark}; -use crate::runtime::context; -use crate::util::slab::{Address, Slab}; +use crate::util::bit; +use crate::util::slab::{self, Slab}; -use mio::event::Evented; use std::fmt; use std::io; -use std::sync::atomic::Ordering::SeqCst; use std::sync::{Arc, Weak}; -use std::task::Waker; use std::time::Duration; /// I/O driver, backed by Mio pub(crate) struct Driver { + /// Tracks the number of times `turn` is called. It is safe for this to wrap + /// as it is mostly used to determine when to call `compact()` + tick: u8, + /// Reuse the `mio::Events` value across calls to poll. - events: mio::Events, + events: Option<mio::Events>, + + /// Primary slab handle containing the state for each resource registered + /// with this driver. + resources: Slab<ScheduledIo>, + + /// The system event queue + poll: mio::Poll, /// State shared between the reactor and the handles. inner: Arc<Inner>, - - _wakeup_registration: mio::Registration, } /// A reference to an I/O driver @@ -33,18 +41,20 @@ pub(crate) struct Handle { inner: Weak<Inner>, } -pub(super) struct Inner { - /// The underlying system event queue. - io: mio::Poll, +pub(crate) struct ReadyEvent { + tick: u8, + ready: Ready, +} - /// Dispatch slabs for I/O and futures events - pub(super) io_dispatch: Slab<ScheduledIo>, +pub(super) struct Inner { + /// Registers I/O resources + registry: mio::Registry, - /// The number of sources in `io_dispatch`. - n_sources: AtomicUsize, + /// Allocates `ScheduledIo` handles when creating new resources. + pub(super) io_dispatch: slab::Allocator<ScheduledIo>, /// Used to wake up the reactor from a call to `turn` - wakeup: mio::SetReadiness, + waker: mio::Waker, } #[derive(Debug, Eq, PartialEq, Clone, Copy)] @@ -53,7 +63,24 @@ pub(super) enum Direction { Write, } -const TOKEN_WAKEUP: mio::Token = mio::Token(Address::NULL); +enum Tick { + Set(u8), + Clear(u8), +} + +// TODO: Don't use a fake token. Instead, reserve a slot entry for the wakeup +// token. +const TOKEN_WAKEUP: mio::Token = mio::Token(1 << 31); + +const ADDRESS: bit::Pack = bit::Pack::least_significant(24); + +// Packs the generation value in the `readiness` field. +// +// The generation prevents a race condition where a slab slot is reused for a +// new socket while the I/O driver is about to apply a readiness event. The +// generaton value is checked when setting new readiness. If the generation do +// not match, then the readiness event is discarded. +const GENERATION: bit::Pack = ADDRESS.then(7); fn _assert_kinds() { fn _assert<T: Send + Sync>() {} @@ -67,24 +94,22 @@ impl Driver { /// Creates a new event loop, returning any error that happened during the /// creation. pub(crate) fn new() -> io::Result<Driver> { - let io = mio::Poll::new()?; - let wakeup_pair = mio::Registration::new2(); + let poll = mio::Poll::new()?; + let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?; + let registry = poll.registry().try_clone()?; - io.register( - &wakeup_pair.0, - TOKEN_WAKEUP, - mio::Ready::readable(), - mio::PollOpt::level(), - )?; + let slab = Slab::new(); + let allocator = slab.allocator(); Ok(Driver { - events: mio::Events::with_capacity(1024), - _wakeup_registration: wakeup_pair.0, + tick: 0, + events: Some(mio::Events::with_capacity(1024)), + resources: slab, + poll, inner: Arc::new(Inner { - io, - io_dispatch: Slab::new(), - n_sources: AtomicUsize::new(0), - wakeup: wakeup_pair.1, + registry, + io_dispatch: allocator, + waker, }), }) } @@ -102,65 +127,66 @@ impl Driver { } fn turn(&mut self, max_wait: Option<Duration>) -> io::Result<()> { + // How often to call `compact()` on the resource slab + const COMPACT_INTERVAL: u8 = 255; + + self.tick = self.tick.wrapping_add(1); + + if self.tick == COMPACT_INTERVAL { + self.resources.compact(); + } + + let mut events = self.events.take().expect("i/o driver event store missing"); + // Block waiting for an event to happen, peeling out how many events // happened. - match self.inner.io.poll(&mut self.events, max_wait) { + match self.poll.poll(&mut events, max_wait) { Ok(_) => {} + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} Err(e) => return Err(e), } // Process all the events that came in, dispatching appropriately - - for event in self.events.iter() { + for event in events.iter() { let token = event.token(); - if token == TOKEN_WAKEUP { - self.inner - .wakeup - .set_readiness(mio::Ready::empty()) - .unwrap(); - } else { - self.dispatch(token, event.readiness()); + if token != TOKEN_WAKEUP { + self.dispatch(token, Ready::from_mio(event)); } } + self.events = Some(events); + Ok(()) } - fn dispatch(&self, token: mio::Token, ready: mio::Ready) { - let mut rd = None; - let mut wr = None; + fn dispatch(&mut self, token: mio::Token, ready: Ready) { + let addr = slab::Address::from_usize(ADDRESS.unpack(token.0)); - let address = Address::from_usize(token.0); - - let io = match self.inner.io_dispatch.get(address) { + let io = match self.resources.get(addr) { Some(io) => io, None => return, }; - if io - .set_readiness(address, |curr| curr | ready.as_usize()) - .is_err() - { + let res = io.set_readiness(Some(token.0), Tick::Set(self.tick), |curr| curr | ready); + + if res.is_err() { // token no longer valid! return; } - if ready.is_writable() || platform::is_hup(ready) || platform::is_error(ready) { - wr = io.writer.take_waker(); - } - - if !(ready & (!mio::Ready::writable())).is_empty() { - rd = io.reader.take_waker(); - } - - if let Some(w) = rd { - w.wake(); - } + io.wake(ready); + } +} - if let Some(w) = wr { - w.wake(); - } +impl Drop for Driver { + fn drop(&mut self) { + self.resources.for_each(|io| { + // If a task is waiting on the I/O resource, notify it. The task + // will then attempt to use the I/O resource and fail due to the + // driver being shutdown. + io.wake(Ready::ALL); + }) } } @@ -181,6 +207,8 @@ impl Park for Driver { self.turn(Some(duration))?; Ok(()) } + + fn shutdown(&mut self) {} } impl fmt::Debug for Driver { @@ -191,17 +219,36 @@ impl fmt::Debug for Driver { // ===== impl Handle ===== -impl Handle { - /// Returns a handle to the current reactor - /// - /// # Panics - /// - /// This function panics if there is no current reactor set. - pub(super) fn current() -> Self { - context::io_handle() - .expect("there is no reactor running, must be called from the context of Tokio runtime") +cfg_rt! { + impl Handle { + /// Returns a handle to the current reactor + /// + /// # Panics + /// + /// This function panics if there is no current reactor set and `rt` feature + /// flag is not enabled. + pub(super) fn current() -> Self { + crate::runtime::context::io_handle() + .expect("there is no reactor running, must be called from the context of Tokio runtime") + } + } +} + +cfg_not_rt! { + impl Handle { + /// Returns a handle to the current reactor + /// + /// # Panics + /// + /// This function panics if there is no current reactor set, or if the `rt` + /// feature flag is not enabled. + pub(super) fn current() -> Self { + panic!("there is no reactor running, must be called from the context of Tokio runtime with `rt` enabled.") + } } +} +impl Handle { /// Forces a reactor blocked in a call to `turn` to wakeup, or otherwise /// makes the next call to `turn` return immediately. /// @@ -213,7 +260,7 @@ impl Handle { /// return immediately. fn wakeup(&self) { if let Some(inner) = self.inner() { - inner.wakeup.set_readiness(mio::Ready::readable()).unwrap(); + inner.waker.wake().expect("failed to wake I/O driver"); } } @@ -242,159 +289,35 @@ impl Inner { /// The registration token is returned. pub(super) fn add_source( &self, - source: &dyn Evented, - ready: mio::Ready, - ) -> io::Result<Address> { - let address = self.io_dispatch.alloc().ok_or_else(|| { + source: &mut impl mio::event::Source, + interest: mio::Interest, + ) -> io::Result<slab::Ref<ScheduledIo>> { + let (address, shared) = self.io_dispatch.allocate().ok_or_else(|| { io::Error::new( io::ErrorKind::Other, "reactor at max registered I/O resources", ) })?; - self.n_sources.fetch_add(1, SeqCst); + let token = GENERATION.pack(shared.generation(), ADDRESS.pack(address.as_usize(), 0)); - self.io.register( - source, - mio::Token(address.to_usize()), - ready, - mio::PollOpt::edge(), - )?; + self.registry + .register(source, mio::Token(token), interest)?; - Ok(address) + Ok(shared) } /// Deregisters an I/O resource from the reactor. - pub(super) fn deregister_source(&self, source: &dyn Evented) -> io::Result<()> { - self.io.deregister(source) - } - - pub(super) fn drop_source(&self, address: Address) { - self.io_dispatch.remove(address); - self.n_sources.fetch_sub(1, SeqCst); - } - - /// Registers interest in the I/O resource associated with `token`. - pub(super) fn register(&self, token: Address, dir: Direction, w: Waker) { - let sched = self - .io_dispatch - .get(token) - .unwrap_or_else(|| panic!("IO resource for token {:?} does not exist!", token)); - - let waker = match dir { - Direction::Read => &sched.reader, - Direction::Write => &sched.writer, - }; - - waker.register(w); + pub(super) fn deregister_source(&self, source: &mut impl mio::event::Source) -> io::Result<()> { + self.registry.deregister(source) } } impl Direction { - pub(super) fn mask(self) -> mio::Ready { + pub(super) fn mask(self) -> Ready { match self { - Direction::Read => { - // Everything except writable is signaled through read. - mio::Ready::all() - mio::Ready::writable() - } - Direction::Write => mio::Ready::writable() | platform::hup() | platform::error(), - } - } -} - -#[cfg(all(test, loom))] -mod tests { - use super::*; - use loom::thread; - - // No-op `Evented` impl just so we can have something to pass to `add_source`. - struct NotEvented; - - impl Evented for NotEvented { - fn register( - &self, - _: &mio::Poll, - _: mio::Token, - _: mio::Ready, - _: mio::PollOpt, - ) -> io::Result<()> { - Ok(()) - } - - fn reregister( - &self, - _: &mio::Poll, - _: mio::Token, - _: mio::Ready, - _: mio::PollOpt, - ) -> io::Result<()> { - Ok(()) - } - - fn deregister(&self, _: &mio::Poll) -> io::Result<()> { - Ok(()) + Direction::Read => Ready::READABLE | Ready::READ_CLOSED, + Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED, } } - - #[test] - fn tokens_unique_when_dropped() { - loom::model(|| { - let reactor = Driver::new().unwrap(); - let inner = reactor.inner; - let inner2 = inner.clone(); - - let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - let thread = thread::spawn(move || { - inner2.drop_source(token_1); - }); - - let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - thread.join().unwrap(); - - assert!(token_1 != token_2); - }) - } - - #[test] - fn tokens_unique_when_dropped_on_full_page() { - loom::model(|| { - let reactor = Driver::new().unwrap(); - let inner = reactor.inner; - let inner2 = inner.clone(); - // add sources to fill up the first page so that the dropped index - // may be reused. - for _ in 0..31 { - inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - } - - let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - let thread = thread::spawn(move || { - inner2.drop_source(token_1); - }); - - let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - thread.join().unwrap(); - - assert!(token_1 != token_2); - }) - } - - #[test] - fn tokens_unique_concurrent_add() { - loom::model(|| { - let reactor = Driver::new().unwrap(); - let inner = reactor.inner; - let inner2 = inner.clone(); - - let thread = thread::spawn(move || { - let token_2 = inner2.add_source(&NotEvented, mio::Ready::all()).unwrap(); - token_2 - }); - - let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - let token_2 = thread.join().unwrap(); - - assert!(token_1 != token_2); - }) - } } diff --git a/src/io/driver/ready.rs b/src/io/driver/ready.rs new file mode 100644 index 0000000..8b556e9 --- /dev/null +++ b/src/io/driver/ready.rs @@ -0,0 +1,187 @@ +use std::fmt; +use std::ops; + +const READABLE: usize = 0b0_01; +const WRITABLE: usize = 0b0_10; +const READ_CLOSED: usize = 0b0_0100; +const WRITE_CLOSED: usize = 0b0_1000; + +/// A set of readiness event kinds. +/// +/// `Ready` is set of operation descriptors indicating which kind of an +/// operation is ready to be performed. +/// +/// This struct only represents portable event kinds. Portable events are +/// events that can be raised on any platform while guaranteeing no false +/// positives. +#[derive(Clone, Copy, PartialEq, PartialOrd)] +pub(crate) struct Ready(usize); + +impl Ready { + /// Returns the empty `Ready` set. + pub(crate) const EMPTY: Ready = Ready(0); + + /// Returns a `Ready` representing readable readiness. + pub(crate) const READABLE: Ready = Ready(READABLE); + + /// Returns a `Ready` representing writable readiness. + pub(crate) const WRITABLE: Ready = Ready(WRITABLE); + + /// Returns a `Ready` representing read closed readiness. + pub(crate) const READ_CLOSED: Ready = Ready(READ_CLOSED); + + /// Returns a `Ready` representing write closed readiness. + pub(crate) const WRITE_CLOSED: Ready = Ready(WRITE_CLOSED); + + /// Returns a `Ready` representing readiness for all operations. + pub(crate) const ALL: Ready = Ready(READABLE | WRITABLE | READ_CLOSED | WRITE_CLOSED); + + pub(crate) fn from_mio(event: &mio::event::Event) -> Ready { + let mut ready = Ready::EMPTY; + + if event.is_readable() { + ready |= Ready::READABLE; + } + + if event.is_writable() { + ready |= Ready::WRITABLE; + } + + if event.is_read_closed() { + ready |= Ready::READ_CLOSED; + } + + if event.is_write_closed() { + ready |= Ready::WRITE_CLOSED; + } + + ready + } + + /// Returns true if `Ready` is the empty set + pub(crate) fn is_empty(self) -> bool { + self == Ready::EMPTY + } + + /// Returns true if the value includes readable readiness + pub(crate) fn is_readable(self) -> bool { + self.contains(Ready::READABLE) || self.is_read_closed() + } + + /// Returns true if the value includes writable readiness + pub(crate) fn is_writable(self) -> bool { + self.contains(Ready::WRITABLE) || self.is_write_closed() + } + + /// Returns true if the value includes read closed readiness + pub(crate) fn is_read_closed(self) -> bool { + self.contains(Ready::READ_CLOSED) + } + + /// Returns true if the value includes write closed readiness + pub(crate) fn is_write_closed(self) -> bool { + self.contains(Ready::WRITE_CLOSED) + } + + /// Returns true if `self` is a superset of `other`. + /// + /// `other` may represent more than one readiness operations, in which case + /// the function only returns true if `self` contains all readiness + /// specified in `other`. + pub(crate) fn contains<T: Into<Self>>(self, other: T) -> bool { + let other = other.into(); + (self & other) == other + } + + /// Create a `Ready` instance using the given `usize` representation. + /// + /// The `usize` representation must have been obtained from a call to + /// `Readiness::as_usize`. + /// + /// This function is mainly provided to allow the caller to get a + /// readiness value from an `AtomicUsize`. + pub(crate) fn from_usize(val: usize) -> Ready { + Ready(val & Ready::ALL.as_usize()) + } + + /// Returns a `usize` representation of the `Ready` value. + /// + /// This function is mainly provided to allow the caller to store a + /// readiness value in an `AtomicUsize`. + pub(crate) fn as_usize(self) -> usize { + self.0 + } +} + +cfg_io_readiness! { + impl Ready { + pub(crate) fn from_interest(interest: mio::Interest) -> Ready { + let mut ready = Ready::EMPTY; + + if interest.is_readable() { + ready |= Ready::READABLE; + ready |= Ready::READ_CLOSED; + } + + if interest.is_writable() { + ready |= Ready::WRITABLE; + ready |= Ready::WRITE_CLOSED; + } + + ready + } + + pub(crate) fn intersection(self, interest: mio::Interest) -> Ready { + Ready(self.0 & Ready::from_interest(interest).0) + } + + pub(crate) fn satisfies(self, interest: mio::Interest) -> bool { + self.0 & Ready::from_interest(interest).0 != 0 + } + } +} + +impl<T: Into<Ready>> ops::BitOr<T> for Ready { + type Output = Ready; + + #[inline] + fn bitor(self, other: T) -> Ready { + Ready(self.0 | other.into().0) + } +} + +impl<T: Into<Ready>> ops::BitOrAssign<T> for Ready { + #[inline] + fn bitor_assign(&mut self, other: T) { + self.0 |= other.into().0; + } +} + +impl<T: Into<Ready>> ops::BitAnd<T> for Ready { + type Output = Ready; + + #[inline] + fn bitand(self, other: T) -> Ready { + Ready(self.0 & other.into().0) + } +} + +impl<T: Into<Ready>> ops::Sub<T> for Ready { + type Output = Ready; + + #[inline] + fn sub(self, other: T) -> Ready { + Ready(self.0 & !other.into().0) + } +} + +impl fmt::Debug for Ready { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Ready") + .field("is_readable", &self.is_readable()) + .field("is_writable", &self.is_writable()) + .field("is_read_closed", &self.is_read_closed()) + .field("is_write_closed", &self.is_write_closed()) + .finish() + } +} diff --git a/src/io/driver/scheduled_io.rs b/src/io/driver/scheduled_io.rs index 7f6446e..b1354a0 100644 --- a/src/io/driver/scheduled_io.rs +++ b/src/io/driver/scheduled_io.rs @@ -1,47 +1,109 @@ -use crate::loom::future::AtomicWaker; +use super::{Direction, Ready, ReadyEvent, Tick}; use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; use crate::util::bit; -use crate::util::slab::{Address, Entry, Generation}; +use crate::util::slab::Entry; -use std::sync::atomic::Ordering::{AcqRel, Acquire, SeqCst}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::task::{Context, Poll, Waker}; +cfg_io_readiness! { + use crate::util::linked_list::{self, LinkedList}; + + use std::cell::UnsafeCell; + use std::future::Future; + use std::marker::PhantomPinned; + use std::pin::Pin; + use std::ptr::NonNull; +} + +/// Stored in the I/O driver resource slab. #[derive(Debug)] pub(crate) struct ScheduledIo { + /// Packs the resource's readiness with the resource's generation. readiness: AtomicUsize, - pub(crate) reader: AtomicWaker, - pub(crate) writer: AtomicWaker, + + waiters: Mutex<Waiters>, } -const PACK: bit::Pack = bit::Pack::most_significant(Generation::WIDTH); +cfg_io_readiness! { + type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; +} -impl Entry for ScheduledIo { - fn generation(&self) -> Generation { - unpack_generation(self.readiness.load(SeqCst)) +#[derive(Debug, Default)] +struct Waiters { + #[cfg(feature = "net")] + /// List of all current waiters + list: WaitList, + + /// Waker used for AsyncRead + reader: Option<Waker>, + + /// Waker used for AsyncWrite + writer: Option<Waker>, +} + +cfg_io_readiness! { + #[derive(Debug)] + struct Waiter { + pointers: linked_list::Pointers<Waiter>, + + /// The waker for this task + waker: Option<Waker>, + + /// The interest this waiter is waiting on + interest: mio::Interest, + + is_ready: bool, + + /// Should never be `!Unpin` + _p: PhantomPinned, } - fn reset(&self, generation: Generation) -> bool { - let mut current = self.readiness.load(Acquire); + /// Future returned by `readiness()` + struct Readiness<'a> { + scheduled_io: &'a ScheduledIo, - loop { - if unpack_generation(current) != generation { - return false; - } + state: State, - let next = PACK.pack(generation.next().to_usize(), 0); + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, + } - match self - .readiness - .compare_exchange(current, next, AcqRel, Acquire) - { - Ok(_) => break, - Err(actual) => current = actual, - } - } + enum State { + Init, + Waiting, + Done, + } +} + +// The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness. +// +// | reserved | generation | driver tick | readinesss | +// |----------+------------+--------------+------------| +// | 1 bit | 7 bits + 8 bits + 16 bits | + +const READINESS: bit::Pack = bit::Pack::least_significant(16); + +const TICK: bit::Pack = READINESS.then(8); - drop(self.reader.take_waker()); - drop(self.writer.take_waker()); +const GENERATION: bit::Pack = TICK.then(7); - true +#[test] +fn test_generations_assert_same() { + assert_eq!(super::GENERATION, GENERATION); +} + +// ===== impl ScheduledIo ===== + +impl Entry for ScheduledIo { + fn reset(&self) { + let state = self.readiness.load(Acquire); + + let generation = GENERATION.unpack(state); + let next = GENERATION.pack_lossy(generation + 1, 0); + + self.readiness.store(next, Release); } } @@ -49,31 +111,14 @@ impl Default for ScheduledIo { fn default() -> ScheduledIo { ScheduledIo { readiness: AtomicUsize::new(0), - reader: AtomicWaker::new(), - writer: AtomicWaker::new(), + waiters: Mutex::new(Default::default()), } } } impl ScheduledIo { - #[cfg(all(test, loom))] - /// Returns the current readiness value of this `ScheduledIo`, if the - /// provided `token` is still a valid access. - /// - /// # Returns - /// - /// If the given token's generation no longer matches the `ScheduledIo`'s - /// generation, then the corresponding IO resource has been removed and - /// replaced with a new resource. In that case, this method returns `None`. - /// Otherwise, this returns the current readiness. - pub(crate) fn get_readiness(&self, address: Address) -> Option<usize> { - let ready = self.readiness.load(Acquire); - - if unpack_generation(ready) != address.generation() { - return None; - } - - Some(ready & !PACK.mask()) + pub(crate) fn generation(&self) -> usize { + GENERATION.unpack(self.readiness.load(Acquire)) } /// Sets the readiness on this `ScheduledIo` by invoking the given closure on @@ -81,6 +126,8 @@ impl ScheduledIo { /// /// # Arguments /// - `token`: the token for this `ScheduledIo`. + /// - `tick`: whether setting the tick or trying to clear readiness for a + /// specific tick. /// - `f`: a closure returning a new readiness value given the previous /// readiness. /// @@ -90,52 +137,354 @@ impl ScheduledIo { /// generation, then the corresponding IO resource has been removed and /// replaced with a new resource. In that case, this method returns `Err`. /// Otherwise, this returns the previous readiness. - pub(crate) fn set_readiness( + pub(super) fn set_readiness( &self, - address: Address, - f: impl Fn(usize) -> usize, - ) -> Result<usize, ()> { - let generation = address.generation(); - + token: Option<usize>, + tick: Tick, + f: impl Fn(Ready) -> Ready, + ) -> Result<(), ()> { let mut current = self.readiness.load(Acquire); loop { - // Check that the generation for this access is still the current - // one. - if unpack_generation(current) != generation { - return Err(()); + let current_generation = GENERATION.unpack(current); + + if let Some(token) = token { + // Check that the generation for this access is still the + // current one. + if GENERATION.unpack(token) != current_generation { + return Err(()); + } } - // Mask out the generation bits so that the modifying function - // doesn't see them. - let current_readiness = current & mio::Ready::all().as_usize(); + + // Mask out the tick/generation bits so that the modifying + // function doesn't see them. + let current_readiness = Ready::from_usize(current); let new = f(current_readiness); - debug_assert!( - new <= !PACK.max_value(), - "new readiness value would overwrite generation bits!" - ); - - match self.readiness.compare_exchange( - current, - PACK.pack(generation.to_usize(), new), - AcqRel, - Acquire, - ) { - Ok(_) => return Ok(current), + let packed = match tick { + Tick::Set(t) => TICK.pack(t as usize, new.as_usize()), + Tick::Clear(t) => { + if TICK.unpack(current) as u8 != t { + // Trying to clear readiness with an old event! + return Err(()); + } + + TICK.pack(t as usize, new.as_usize()) + } + }; + + let next = GENERATION.pack(current_generation, packed); + + match self + .readiness + .compare_exchange(current, next, AcqRel, Acquire) + { + Ok(_) => return Ok(()), // we lost the race, retry! Err(actual) => current = actual, } } } + + /// Notifies all pending waiters that have registered interest in `ready`. + /// + /// There may be many waiters to notify. Waking the pending task **must** be + /// done from outside of the lock otherwise there is a potential for a + /// deadlock. + /// + /// A stack array of wakers is created and filled with wakers to notify, the + /// lock is released, and the wakers are notified. Because there may be more + /// than 32 wakers to notify, if the stack array fills up, the lock is + /// released, the array is cleared, and the iteration continues. + pub(super) fn wake(&self, ready: Ready) { + const NUM_WAKERS: usize = 32; + + let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default(); + let mut curr = 0; + + let mut waiters = self.waiters.lock(); + + // check for AsyncRead slot + if ready.is_readable() { + if let Some(waker) = waiters.reader.take() { + wakers[curr] = Some(waker); + curr += 1; + } + } + + // check for AsyncWrite slot + if ready.is_writable() { + if let Some(waker) = waiters.writer.take() { + wakers[curr] = Some(waker); + curr += 1; + } + } + + #[cfg(feature = "net")] + 'outer: loop { + let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest)); + + while curr < NUM_WAKERS { + match iter.next() { + Some(waiter) => { + let waiter = unsafe { &mut *waiter.as_ptr() }; + + if let Some(waker) = waiter.waker.take() { + waiter.is_ready = true; + wakers[curr] = Some(waker); + curr += 1; + } + } + None => { + break 'outer; + } + } + } + + drop(waiters); + + for waker in wakers.iter_mut().take(curr) { + waker.take().unwrap().wake(); + } + + curr = 0; + + // Acquire the lock again. + waiters = self.waiters.lock(); + } + + // Release the lock before notifying + drop(waiters); + + for waker in wakers.iter_mut().take(curr) { + waker.take().unwrap().wake(); + } + } + + /// Poll version of checking readiness for a certain direction. + /// + /// These are to support `AsyncRead` and `AsyncWrite` polling methods, + /// which cannot use the `async fn` version. This uses reserved reader + /// and writer slots. + pub(in crate::io) fn poll_readiness( + &self, + cx: &mut Context<'_>, + direction: Direction, + ) -> Poll<ReadyEvent> { + let curr = self.readiness.load(Acquire); + + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + + if ready.is_empty() { + // Update the task info + let mut waiters = self.waiters.lock(); + let slot = match direction { + Direction::Read => &mut waiters.reader, + Direction::Write => &mut waiters.writer, + }; + *slot = Some(cx.waker().clone()); + + // Try again, in case the readiness was changed while we were + // taking the waiters lock + let curr = self.readiness.load(Acquire); + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + if ready.is_empty() { + Poll::Pending + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } + } + + pub(crate) fn clear_readiness(&self, event: ReadyEvent) { + // This consumes the current readiness state **except** for closed + // states. Closed states are excluded because they are final states. + let mask_no_closed = event.ready - Ready::READ_CLOSED - Ready::WRITE_CLOSED; + + // result isn't important + let _ = self.set_readiness(None, Tick::Clear(event.tick), |curr| curr - mask_no_closed); + } } impl Drop for ScheduledIo { fn drop(&mut self) { - self.writer.wake(); - self.reader.wake(); + self.wake(Ready::ALL); } } -fn unpack_generation(src: usize) -> Generation { - Generation::new(PACK.unpack(src)) +unsafe impl Send for ScheduledIo {} +unsafe impl Sync for ScheduledIo {} + +cfg_io_readiness! { + impl ScheduledIo { + /// An async version of `poll_readiness` which uses a linked list of wakers + pub(crate) async fn readiness(&self, interest: mio::Interest) -> ReadyEvent { + self.readiness_fut(interest).await + } + + // This is in a separate function so that the borrow checker doesn't think + // we are borrowing the `UnsafeCell` possibly over await boundaries. + // + // Go figure. + fn readiness_fut(&self, interest: mio::Interest) -> Readiness<'_> { + Readiness { + scheduled_io: self, + state: State::Init, + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + is_ready: false, + interest, + _p: PhantomPinned, + }), + } + } + } + + unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } + } + + // ===== impl Readiness ===== + + impl Future for Readiness<'_> { + type Output = ReadyEvent; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + use std::sync::atomic::Ordering::SeqCst; + + let (scheduled_io, state, waiter) = unsafe { + let me = self.get_unchecked_mut(); + (&me.scheduled_io, &mut me.state, &me.waiter) + }; + + loop { + match *state { + State::Init => { + // Optimistically check existing readiness + let curr = scheduled_io.readiness.load(SeqCst); + let ready = Ready::from_usize(READINESS.unpack(curr)); + + // Safety: `waiter.interest` never changes + let interest = unsafe { (*waiter.get()).interest }; + let ready = ready.intersection(interest); + + if !ready.is_empty() { + // Currently ready! + let tick = TICK.unpack(curr) as u8; + *state = State::Done; + return Poll::Ready(ReadyEvent { ready, tick }); + } + + // Wasn't ready, take the lock (and check again while locked). + let mut waiters = scheduled_io.waiters.lock(); + + let curr = scheduled_io.readiness.load(SeqCst); + let ready = Ready::from_usize(READINESS.unpack(curr)); + let ready = ready.intersection(interest); + + if !ready.is_empty() { + // Currently ready! + let tick = TICK.unpack(curr) as u8; + *state = State::Done; + return Poll::Ready(ReadyEvent { ready, tick }); + } + + // Not ready even after locked, insert into list... + + // Safety: called while locked + unsafe { + (*waiter.get()).waker = Some(cx.waker().clone()); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters + .list + .push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + *state = State::Waiting; + } + State::Waiting => { + // Currently in the "Waiting" state, implying the caller has + // a waiter stored in the waiter list (guarded by + // `notify.waiters`). In order to access the waker fields, + // we must hold the lock. + + let waiters = scheduled_io.waiters.lock(); + + // Safety: called while locked + let w = unsafe { &mut *waiter.get() }; + + if w.is_ready { + // Our waker has been notified. + *state = State::Done; + } else { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + + return Poll::Pending; + } + + // Explicit drop of the lock to indicate the scope that the + // lock is held. Because holding the lock is required to + // ensure safe access to fields not held within the lock, it + // is helpful to visualize the scope of the critical + // section. + drop(waiters); + } + State::Done => { + let tick = TICK.unpack(scheduled_io.readiness.load(Acquire)) as u8; + + // Safety: State::Done means it is no longer shared + let w = unsafe { &mut *waiter.get() }; + + return Poll::Ready(ReadyEvent { + tick, + ready: Ready::from_interest(w.interest), + }); + } + } + } + } + } + + impl Drop for Readiness<'_> { + fn drop(&mut self) { + let mut waiters = self.scheduled_io.waiters.lock(); + + // Safety: `waiter` is only ever stored in `waiters` + unsafe { + waiters + .list + .remove(NonNull::new_unchecked(self.waiter.get())) + }; + } + } + + unsafe impl Send for Readiness<'_> {} + unsafe impl Sync for Readiness<'_> {} } diff --git a/src/io/mod.rs b/src/io/mod.rs index 7b00556..9191bbc 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -162,8 +162,8 @@ //! //! # `std` re-exports //! -//! Additionally, [`Error`], [`ErrorKind`], and [`Result`] are re-exported -//! from `std::io` for ease of use. +//! Additionally, [`Error`], [`ErrorKind`], [`Result`], and [`SeekFrom`] are +//! re-exported from `std::io` for ease of use. //! //! [`AsyncRead`]: trait@AsyncRead //! [`AsyncWrite`]: trait@AsyncWrite @@ -176,6 +176,7 @@ //! [`ErrorKind`]: enum@ErrorKind //! [`Result`]: type@Result //! [`Read`]: std::io::Read +//! [`SeekFrom`]: enum@SeekFrom //! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html //! [`Stream`]: crate::stream::Stream //! [`Write`]: std::io::Write @@ -187,7 +188,6 @@ mod async_buf_read; pub use self::async_buf_read::AsyncBufRead; mod async_read; - pub use self::async_read::AsyncRead; mod async_seek; @@ -196,17 +196,27 @@ pub use self::async_seek::AsyncSeek; mod async_write; pub use self::async_write::AsyncWrite; +mod read_buf; +pub use self::read_buf::ReadBuf; + +// Re-export some types from `std::io` so that users don't have to deal +// with conflicts when `use`ing `tokio::io` and `std::io`. +#[doc(no_inline)] +pub use std::io::{Error, ErrorKind, Result, SeekFrom}; + cfg_io_driver! { pub(crate) mod driver; mod poll_evented; - pub use poll_evented::PollEvented; + #[cfg(not(loom))] + pub(crate) use poll_evented::PollEvented; mod registration; - pub use registration::Registration; } cfg_io_std! { + mod stdio_common; + mod stderr; pub use stderr::{stderr, Stderr}; @@ -222,21 +232,11 @@ cfg_io_util! { pub use split::{split, ReadHalf, WriteHalf}; pub(crate) mod seek; - pub use self::seek::Seek; - pub(crate) mod util; pub use util::{ - copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, - BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take, + copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take, }; - - cfg_stream! { - pub use util::{stream_reader, StreamReader}; - } - - // Re-export io::Error so that users don't have to deal with conflicts when - // `use`ing `tokio::io` and `std::io`. - pub use std::io::{Error, ErrorKind, Result}; } cfg_not_io_util! { @@ -249,7 +249,7 @@ cfg_io_blocking! { /// Types in this module can be mocked out in tests. mod sys { // TODO: don't rename - pub(crate) use crate::runtime::spawn_blocking as run; - pub(crate) use crate::task::JoinHandle as Blocking; + pub(crate) use crate::blocking::spawn_blocking as run; + pub(crate) use crate::blocking::JoinHandle as Blocking; } } diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs index 5295bd7..66a2634 100644 --- a/src/io/poll_evented.rs +++ b/src/io/poll_evented.rs @@ -1,13 +1,12 @@ -use crate::io::driver::platform; -use crate::io::{AsyncRead, AsyncWrite, Registration}; +use crate::io::driver::{Direction, Handle, ReadyEvent}; +use crate::io::registration::Registration; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; -use mio::event::Evented; +use mio::event::Source; use std::fmt; use std::io::{self, Read, Write}; use std::marker::Unpin; use std::pin::Pin; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::Relaxed; use std::task::{Context, Poll}; cfg_io_driver! { @@ -53,37 +52,6 @@ cfg_io_driver! { /// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and /// [`clear_read_ready`]. /// - /// ```rust - /// use tokio::io::PollEvented; - /// - /// use futures::ready; - /// use mio::Ready; - /// use mio::net::{TcpStream, TcpListener}; - /// use std::io; - /// use std::task::{Context, Poll}; - /// - /// struct MyListener { - /// poll_evented: PollEvented<TcpListener>, - /// } - /// - /// impl MyListener { - /// pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<TcpStream, io::Error>> { - /// let ready = Ready::readable(); - /// - /// ready!(self.poll_evented.poll_read_ready(cx, ready))?; - /// - /// match self.poll_evented.get_ref().accept() { - /// Ok((socket, _)) => Poll::Ready(Ok(socket)), - /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - /// self.poll_evented.clear_read_ready(cx, ready)?; - /// Poll::Pending - /// } - /// Err(e) => Poll::Ready(Err(e)), - /// } - /// } - /// } - /// ``` - /// /// ## Platform-specific events /// /// `PollEvented` also allows receiving platform-specific `mio::Ready` events. @@ -101,70 +69,15 @@ cfg_io_driver! { /// [`clear_write_ready`]: method@Self::clear_write_ready /// [`poll_read_ready`]: method@Self::poll_read_ready /// [`poll_write_ready`]: method@Self::poll_write_ready - pub struct PollEvented<E: Evented> { + pub(crate) struct PollEvented<E: Source> { io: Option<E>, - inner: Inner, + registration: Registration, } } -struct Inner { - registration: Registration, - - /// Currently visible read readiness - read_readiness: AtomicUsize, - - /// Currently visible write readiness - write_readiness: AtomicUsize, -} - // ===== impl PollEvented ===== -macro_rules! poll_ready { - ($me:expr, $mask:expr, $cache:ident, $take:ident, $poll:expr) => {{ - // Load cached & encoded readiness. - let mut cached = $me.inner.$cache.load(Relaxed); - let mask = $mask | platform::hup() | platform::error(); - - // See if the current readiness matches any bits. - let mut ret = mio::Ready::from_usize(cached) & $mask; - - if ret.is_empty() { - // Readiness does not match, consume the registration's readiness - // stream. This happens in a loop to ensure that the stream gets - // drained. - loop { - let ready = match $poll? { - Poll::Ready(v) => v, - Poll::Pending => return Poll::Pending, - }; - cached |= ready.as_usize(); - - // Update the cache store - $me.inner.$cache.store(cached, Relaxed); - - ret |= ready & mask; - - if !ret.is_empty() { - return Poll::Ready(Ok(ret)); - } - } - } else { - // Check what's new with the registration stream. This will not - // request to be notified - if let Some(ready) = $me.inner.registration.$take()? { - cached |= ready.as_usize(); - $me.inner.$cache.store(cached, Relaxed); - } - - Poll::Ready(Ok(mio::Ready::from_usize(cached))) - } - }}; -} - -impl<E> PollEvented<E> -where - E: Evented, -{ +impl<E: Source> PollEvented<E> { /// Creates a new `PollEvented` associated with the default reactor. /// /// # Panics @@ -173,71 +86,57 @@ where /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new(io: E) -> io::Result<Self> { - PollEvented::new_with_ready(io, mio::Ready::all()) + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + #[cfg_attr(feature = "signal", allow(unused))] + pub(crate) fn new(io: E) -> io::Result<Self> { + PollEvented::new_with_interest(io, mio::Interest::READABLE | mio::Interest::WRITABLE) } - /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Ready` - /// state. `new_with_ready` should be used over `new` when you need control over the readiness + /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Interest` + /// state. `new_with_interest` should be used over `new` when you need control over the readiness /// state, such as when a file descriptor only allows reads. This does not add `hup` or `error` /// so if you are interested in those states, you will need to add them to the readiness state /// passed to this function. /// - /// An example to listen to read only - /// - /// ```rust - /// ##[cfg(unix)] - /// mio::Ready::from_usize( - /// mio::Ready::readable().as_usize() - /// | mio::unix::UnixReady::error().as_usize() - /// | mio::unix::UnixReady::hup().as_usize() - /// ); - /// ``` - /// /// # Panics /// /// This function panics if thread-local runtime is not set. /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new_with_ready(io: E, ready: mio::Ready) -> io::Result<Self> { - let registration = Registration::new_with_ready(&io, ready)?; + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + #[cfg_attr(feature = "signal", allow(unused))] + pub(crate) fn new_with_interest(io: E, interest: mio::Interest) -> io::Result<Self> { + Self::new_with_interest_and_handle(io, interest, Handle::current()) + } + + pub(crate) fn new_with_interest_and_handle( + mut io: E, + interest: mio::Interest, + handle: Handle, + ) -> io::Result<Self> { + let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?; Ok(Self { io: Some(io), - inner: Inner { - registration, - read_readiness: AtomicUsize::new(0), - write_readiness: AtomicUsize::new(0), - }, + registration, }) } /// Returns a shared reference to the underlying I/O object this readiness /// stream is wrapping. - pub fn get_ref(&self) -> &E { + #[cfg(any(feature = "net", feature = "process", feature = "signal"))] + pub(crate) fn get_ref(&self) -> &E { self.io.as_ref().unwrap() } /// Returns a mutable reference to the underlying I/O object this readiness /// stream is wrapping. - pub fn get_mut(&mut self) -> &mut E { + pub(crate) fn get_mut(&mut self) -> &mut E { self.io.as_mut().unwrap() } - /// Consumes self, returning the inner I/O object - /// - /// This function will deregister the I/O resource from the reactor before - /// returning. If the deregistration operation fails, an error is returned. - /// - /// Note that deregistering does not guarantee that the I/O resource can be - /// registered with a different reactor. Some I/O resource types can only be - /// associated with a single reactor instance for their lifetime. - pub fn into_inner(mut self) -> io::Result<E> { - let io = self.io.take().unwrap(); - self.inner.registration.deregister(&io)?; - Ok(io) + pub(crate) fn clear_readiness(&self, event: ReadyEvent) { + self.registration.clear_readiness(event); } /// Checks the I/O resource's read readiness state. @@ -266,51 +165,8 @@ where /// /// This method may not be called concurrently. It takes `&self` to allow /// calling it concurrently with `poll_write_ready`. - pub fn poll_read_ready( - &self, - cx: &mut Context<'_>, - mask: mio::Ready, - ) -> Poll<io::Result<mio::Ready>> { - assert!(!mask.is_writable(), "cannot poll for write readiness"); - poll_ready!( - self, - mask, - read_readiness, - take_read_ready, - self.inner.registration.poll_read_ready(cx) - ) - } - - /// Clears the I/O resource's read readiness state and registers the current - /// task to be notified once a read readiness event is received. - /// - /// After calling this function, `poll_read_ready` will return - /// `Poll::Pending` until a new read readiness event has been received. - /// - /// The `mask` argument specifies the readiness bits to clear. This may not - /// include `writable` or `hup`. - /// - /// # Panics - /// - /// This function panics if: - /// - /// * `ready` includes writable or HUP - /// * called from outside of a task context. - pub fn clear_read_ready(&self, cx: &mut Context<'_>, ready: mio::Ready) -> io::Result<()> { - // Cannot clear write readiness - assert!(!ready.is_writable(), "cannot clear write readiness"); - assert!(!platform::is_hup(ready), "cannot clear HUP readiness"); - - self.inner - .read_readiness - .fetch_and(!ready.as_usize(), Relaxed); - - if self.poll_read_ready(cx, ready)?.is_ready() { - // Notify the current task - cx.waker().wake_by_ref(); - } - - Ok(()) + pub(crate) fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> { + self.registration.poll_readiness(cx, Direction::Read) } /// Checks the I/O resource's write readiness state. @@ -337,100 +193,95 @@ where /// /// This method may not be called concurrently. It takes `&self` to allow /// calling it concurrently with `poll_read_ready`. - pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> { - poll_ready!( - self, - mio::Ready::writable(), - write_readiness, - take_write_ready, - self.inner.registration.poll_write_ready(cx) - ) + pub(crate) fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> { + self.registration.poll_readiness(cx, Direction::Write) } +} - /// Resets the I/O resource's write readiness state and registers the current - /// task to be notified once a write readiness event is received. - /// - /// This only clears writable readiness. HUP (on platforms that support HUP) - /// cannot be cleared as it is a final state. - /// - /// After calling this function, `poll_write_ready(Ready::writable())` will - /// return `NotReady` until a new write readiness event has been received. - /// - /// # Panics - /// - /// This function will panic if called from outside of a task context. - pub fn clear_write_ready(&self, cx: &mut Context<'_>) -> io::Result<()> { - let ready = mio::Ready::writable(); +cfg_io_readiness! { + impl<E: Source> PollEvented<E> { + pub(crate) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> { + self.registration.readiness(interest).await + } - self.inner - .write_readiness - .fetch_and(!ready.as_usize(), Relaxed); + pub(crate) async fn async_io<F, R>(&self, interest: mio::Interest, mut op: F) -> io::Result<R> + where + F: FnMut(&E) -> io::Result<R>, + { + loop { + let event = self.readiness(interest).await?; - if self.poll_write_ready(cx)?.is_ready() { - // Notify the current task - cx.waker().wake_by_ref(); + match op(self.get_ref()) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.clear_readiness(event); + } + x => return x, + } + } } - - Ok(()) } } // ===== Read / Write impls ===== -impl<E> AsyncRead for PollEvented<E> -where - E: Evented + Read + Unpin, -{ +impl<E: Source + Read + Unpin> AsyncRead for PollEvented<E> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - ready!(self.poll_read_ready(cx, mio::Ready::readable()))?; - - let r = (*self).get_mut().read(buf); + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + loop { + let ev = ready!(self.poll_read_ready(cx))?; + + // We can't assume the `Read` won't look at the read buffer, + // so we have to force initialization here. + let r = (*self).get_mut().read(buf.initialize_unfilled()); + + if is_wouldblock(&r) { + self.clear_readiness(ev); + continue; + } - if is_wouldblock(&r) { - self.clear_read_ready(cx, mio::Ready::readable())?; - return Poll::Pending; + return Poll::Ready(r.map(|n| { + buf.advance(n); + })); } - - Poll::Ready(r) } } -impl<E> AsyncWrite for PollEvented<E> -where - E: Evented + Write + Unpin, -{ +impl<E: Source + Write + Unpin> AsyncWrite for PollEvented<E> { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - ready!(self.poll_write_ready(cx))?; + loop { + let ev = ready!(self.poll_write_ready(cx))?; - let r = (*self).get_mut().write(buf); + let r = (*self).get_mut().write(buf); - if is_wouldblock(&r) { - self.clear_write_ready(cx)?; - return Poll::Pending; - } + if is_wouldblock(&r) { + self.clear_readiness(ev); + continue; + } - Poll::Ready(r) + return Poll::Ready(r); + } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - ready!(self.poll_write_ready(cx))?; + loop { + let ev = ready!(self.poll_write_ready(cx))?; - let r = (*self).get_mut().flush(); + let r = (*self).get_mut().flush(); - if is_wouldblock(&r) { - self.clear_write_ready(cx)?; - return Poll::Pending; - } + if is_wouldblock(&r) { + self.clear_readiness(ev); + continue; + } - Poll::Ready(r) + return Poll::Ready(r); + } } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { @@ -445,17 +296,17 @@ fn is_wouldblock<T>(r: &io::Result<T>) -> bool { } } -impl<E: Evented + fmt::Debug> fmt::Debug for PollEvented<E> { +impl<E: Source + fmt::Debug> fmt::Debug for PollEvented<E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PollEvented").field("io", &self.io).finish() } } -impl<E: Evented> Drop for PollEvented<E> { +impl<E: Source> Drop for PollEvented<E> { fn drop(&mut self) { - if let Some(io) = self.io.take() { + if let Some(mut io) = self.io.take() { // Ignore errors - let _ = self.inner.registration.deregister(&io); + let _ = self.registration.deregister(&mut io); } } } diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs new file mode 100644 index 0000000..b64d95c --- /dev/null +++ b/src/io/read_buf.rs @@ -0,0 +1,261 @@ +// This lint claims ugly casting is somehow safer than transmute, but there's +// no evidence that is the case. Shush. +#![allow(clippy::transmute_ptr_to_ptr)] + +use std::fmt; +use std::mem::{self, MaybeUninit}; + +/// A wrapper around a byte buffer that is incrementally filled and initialized. +/// +/// This type is a sort of "double cursor". It tracks three regions in the +/// buffer: a region at the beginning of the buffer that has been logically +/// filled with data, a region that has been initialized at some point but not +/// yet logically filled, and a region at the end that is fully uninitialized. +/// The filled region is guaranteed to be a subset of the initialized region. +/// +/// In summary, the contents of the buffer can be visualized as: +/// +/// ```not_rust +/// [ capacity ] +/// [ filled | unfilled ] +/// [ initialized | uninitialized ] +/// ``` +pub struct ReadBuf<'a> { + buf: &'a mut [MaybeUninit<u8>], + filled: usize, + initialized: usize, +} + +impl<'a> ReadBuf<'a> { + /// Creates a new `ReadBuf` from a fully initialized buffer. + #[inline] + pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> { + let initialized = buf.len(); + let buf = unsafe { mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buf) }; + ReadBuf { + buf, + filled: 0, + initialized, + } + } + + /// Creates a new `ReadBuf` from a fully uninitialized buffer. + /// + /// Use `assume_init` if part of the buffer is known to be already inintialized. + #[inline] + pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> { + ReadBuf { + buf, + filled: 0, + initialized: 0, + } + } + + /// Returns the total capacity of the buffer. + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + /// Returns a shared reference to the filled portion of the buffer. + #[inline] + pub fn filled(&self) -> &[u8] { + let slice = &self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable. + unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) } + } + + /// Returns a mutable reference to the filled portion of the buffer. + #[inline] + pub fn filled_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable. + unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) } + } + + /// Returns a new `ReadBuf` comprised of the unfilled section up to `n`. + #[inline] + pub fn take(&mut self, n: usize) -> ReadBuf<'_> { + let max = std::cmp::min(self.remaining(), n); + // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. + unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) } + } + + /// Returns a shared reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized(&self) -> &[u8] { + let slice = &self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable. + unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) } + } + + /// Returns a mutable reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable. + unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) } + } + + /// Returns a mutable reference to the unfilled part of the buffer without ensuring that it has been fully + /// initialized. + /// + /// # Safety + /// + /// The caller must not de-initialize portions of the buffer that have already been initialized. + #[inline] + pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit<u8>] { + &mut self.buf[self.filled..] + } + + /// Returns a mutable reference to the unfilled part of the buffer, ensuring it is fully initialized. + /// + /// Since `ReadBuf` tracks the region of the buffer that has been initialized, this is effectively "free" after + /// the first use. + #[inline] + pub fn initialize_unfilled(&mut self) -> &mut [u8] { + self.initialize_unfilled_to(self.remaining()) + } + + /// Returns a mutable reference to the first `n` bytes of the unfilled part of the buffer, ensuring it is + /// fully initialized. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `n`. + #[inline] + pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] { + assert!(self.remaining() >= n, "n overflows remaining"); + + // This can't overflow, otherwise the assert above would have failed. + let end = self.filled + n; + + if self.initialized < end { + unsafe { + self.buf[self.initialized..end] + .as_mut_ptr() + .write_bytes(0, end - self.initialized); + } + self.initialized = end; + } + + let slice = &mut self.buf[self.filled..end]; + // safety: just above, we checked that the end of the buf has + // been initialized to some value. + unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) } + } + + /// Returns the number of bytes at the end of the slice that have not yet been filled. + #[inline] + pub fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + /// Clears the buffer, resetting the filled region to empty. + /// + /// The number of initialized bytes is not changed, and the contents of the buffer are not modified. + #[inline] + pub fn clear(&mut self) { + self.filled = 0; + } + + /// Advances the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the initialized region. + #[inline] + pub fn advance(&mut self, n: usize) { + let new = self.filled.checked_add(n).expect("filled overflow"); + self.set_filled(new); + } + + /// Sets the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// Note that this can be used to *shrink* the filled region of the buffer in addition to growing it (for + /// example, by a `AsyncRead` implementation that compresses data in-place). + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the intialized region. + #[inline] + pub fn set_filled(&mut self, n: usize) { + assert!( + n <= self.initialized, + "filled must not become larger than initialized" + ); + self.filled = n; + } + + /// Asserts that the first `n` unfilled bytes of the buffer are initialized. + /// + /// `ReadBuf` assumes that bytes are never de-initialized, so this method does nothing when called with fewer + /// bytes than are already known to be initialized. + /// + /// # Safety + /// + /// The caller must ensure that `n` unfilled bytes of the buffer have already been initialized. + #[inline] + pub unsafe fn assume_init(&mut self, n: usize) { + let new = self.filled + n; + if new > self.initialized { + self.initialized = new; + } + } + + /// Appends data to the buffer, advancing the written position and possibly also the initialized position. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `buf.len()`. + #[inline] + pub fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf[self.filled..end] + .as_mut_ptr() + .cast::<u8>() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.initialized < end { + self.initialized = end; + } + self.filled = end; + } +} + +impl fmt::Debug for ReadBuf<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("initialized", &self.initialized) + .field("capacity", &self.capacity()) + .finish() + } +} diff --git a/src/io/registration.rs b/src/io/registration.rs index 77fe6db..ce6cffd 100644 --- a/src/io/registration.rs +++ b/src/io/registration.rs @@ -1,7 +1,7 @@ -use crate::io::driver::{platform, Direction, Handle}; -use crate::util::slab::Address; +use crate::io::driver::{Direction, Handle, ReadyEvent, ScheduledIo}; +use crate::util::slab; -use mio::{self, Evented}; +use mio::event::Source; use std::io; use std::task::{Context, Poll}; @@ -38,74 +38,38 @@ cfg_io_driver! { /// [`poll_read_ready`]: method@Self::poll_read_ready` /// [`poll_write_ready`]: method@Self::poll_write_ready` #[derive(Debug)] - pub struct Registration { + pub(crate) struct Registration { + /// Handle to the associated driver. handle: Handle, - address: Address, + + /// Reference to state stored by the driver. + shared: slab::Ref<ScheduledIo>, } } +unsafe impl Send for Registration {} +unsafe impl Sync for Registration {} + // ===== impl Registration ===== impl Registration { - /// Registers the I/O resource with the default reactor. - /// - /// # Return - /// - /// - `Ok` if the registration happened successfully - /// - `Err` if an error was encountered during registration - /// - /// - /// # Panics - /// - /// This function panics if thread-local runtime is not set. - /// - /// The runtime is usually set implicitly when this function is called - /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new<T>(io: &T) -> io::Result<Registration> - where - T: Evented, - { - Registration::new_with_ready(io, mio::Ready::all()) - } - - /// Registers the I/O resource with the default reactor, for a specific `mio::Ready` state. - /// `new_with_ready` should be used over `new` when you need control over the readiness state, + /// Registers the I/O resource with the default reactor, for a specific `mio::Interest`. + /// `new_with_interest` should be used over `new` when you need control over the readiness state, /// such as when a file descriptor only allows reads. This does not add `hup` or `error` so if /// you are interested in those states, you will need to add them to the readiness state passed /// to this function. /// - /// An example to listen to read only - /// - /// ```rust - /// ##[cfg(unix)] - /// mio::Ready::from_usize( - /// mio::Ready::readable().as_usize() - /// | mio::unix::UnixReady::error().as_usize() - /// | mio::unix::UnixReady::hup().as_usize() - /// ); - /// ``` - /// /// # Return /// /// - `Ok` if the registration happened successfully /// - `Err` if an error was encountered during registration - /// - /// - /// # Panics - /// - /// This function panics if thread-local runtime is not set. - /// - /// The runtime is usually set implicitly when this function is called - /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new_with_ready<T>(io: &T, ready: mio::Ready) -> io::Result<Registration> - where - T: Evented, - { - let handle = Handle::current(); - let address = if let Some(inner) = handle.inner() { - inner.add_source(io, ready)? + pub(crate) fn new_with_interest_and_handle( + io: &mut impl Source, + interest: mio::Interest, + handle: Handle, + ) -> io::Result<Registration> { + let shared = if let Some(inner) = handle.inner() { + inner.add_source(io, interest)? } else { return Err(io::Error::new( io::ErrorKind::Other, @@ -113,7 +77,7 @@ impl Registration { )); }; - Ok(Registration { handle, address }) + Ok(Registration { handle, shared }) } /// Deregisters the I/O resource from the reactor it is associated with. @@ -132,10 +96,7 @@ impl Registration { /// no longer result in notifications getting sent for this registration. /// /// `Err` is returned if an error is encountered. - pub fn deregister<T>(&mut self, io: &T) -> io::Result<()> - where - T: Evented, - { + pub(super) fn deregister(&mut self, io: &mut impl Source) -> io::Result<()> { let inner = match self.handle.inner() { Some(inner) => inner, None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")), @@ -143,198 +104,47 @@ impl Registration { inner.deregister_source(io) } - /// Polls for events on the I/O resource's read readiness stream. - /// - /// If the I/O resource receives a new read readiness event since the last - /// call to `poll_read_ready`, it is returned. If it has not, the current - /// task is notified once a new event is received. - /// - /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned, - /// the function will always return `Ready(HUP)`. This should be treated as - /// the end of the readiness stream. - /// - /// # Return value - /// - /// There are several possible return values: - /// - /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received - /// a new readiness event. The readiness value is included. - /// - /// * `Poll::Pending` means that no new readiness events have been received - /// since the last call to `poll_read_ready`. - /// - /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This could represent a permanent internal error for example. - /// - /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered - /// - /// # Panics - /// - /// This function will panic if called from outside of a task context. - pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| { - coop.made_progress(); - e - })?; - match v { - Some(v) => { - coop.made_progress(); - Poll::Ready(Ok(v)) - } - None => Poll::Pending, - } - } - - /// Consume any pending read readiness event. - /// - /// This function is identical to [`poll_read_ready`] **except** that it - /// will not notify the current task when a new event is received. As such, - /// it is safe to call this function from outside of a task context. - /// - /// [`poll_read_ready`]: method@Self::poll_read_ready - pub fn take_read_ready(&self) -> io::Result<Option<mio::Ready>> { - self.poll_ready(Direction::Read, None) - } - - /// Polls for events on the I/O resource's write readiness stream. - /// - /// If the I/O resource receives a new write readiness event since the last - /// call to `poll_write_ready`, it is returned. If it has not, the current - /// task is notified once a new event is received. - /// - /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned, - /// the function will always return `Ready(HUP)`. This should be treated as - /// the end of the readiness stream. - /// - /// # Return value - /// - /// There are several possible return values: - /// - /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received - /// a new readiness event. The readiness value is included. - /// - /// * `Poll::Pending` means that no new readiness events have been received - /// since the last call to `poll_write_ready`. - /// - /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This could represent a permanent internal error for example. - /// - /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered - /// - /// # Panics - /// - /// This function will panic if called from outside of a task context. - pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| { - coop.made_progress(); - e - })?; - match v { - Some(v) => { - coop.made_progress(); - Poll::Ready(Ok(v)) - } - None => Poll::Pending, - } - } - - /// Consumes any pending write readiness event. - /// - /// This function is identical to [`poll_write_ready`] **except** that it - /// will not notify the current task when a new event is received. As such, - /// it is safe to call this function from outside of a task context. - /// - /// [`poll_write_ready`]: method@Self::poll_write_ready - pub fn take_write_ready(&self) -> io::Result<Option<mio::Ready>> { - self.poll_ready(Direction::Write, None) + pub(super) fn clear_readiness(&self, event: ReadyEvent) { + self.shared.clear_readiness(event); } /// Polls for events on the I/O resource's `direction` readiness stream. /// /// If called with a task context, notify the task when a new event is /// received. - fn poll_ready( + pub(super) fn poll_readiness( &self, + cx: &mut Context<'_>, direction: Direction, - cx: Option<&mut Context<'_>>, - ) -> io::Result<Option<mio::Ready>> { - let inner = match self.handle.inner() { - Some(inner) => inner, - None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")), - }; - - // If the task should be notified about new events, ensure that it has - // been registered - if let Some(ref cx) = cx { - inner.register(self.address, direction, cx.waker().clone()) + ) -> Poll<io::Result<ReadyEvent>> { + if self.handle.inner().is_none() { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone"))); } - let mask = direction.mask(); - let mask_no_hup = (mask - platform::hup() - platform::error()).as_usize(); - - let sched = inner.io_dispatch.get(self.address).unwrap(); + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + let ev = ready!(self.shared.poll_readiness(cx, direction)); + coop.made_progress(); + Poll::Ready(Ok(ev)) + } +} - // This consumes the current readiness state **except** for HUP and - // error. HUP and error are excluded because a) they are final states - // and never transitition out and b) both the read AND the write - // directions need to be able to obvserve these states. - // - // # Platform-specific behavior - // - // HUP and error readiness are platform-specific. On epoll platforms, - // HUP has specific conditions that must be met by both peers of a - // connection in order to be triggered. - // - // On epoll platforms, `EPOLLERR` is signaled through - // `UnixReady::error()` and is important to be observable by both read - // AND write. A specific case that `EPOLLERR` occurs is when the read - // end of a pipe is closed. When this occurs, a peer blocked by - // writing to the pipe should be notified. - let curr_ready = sched - .set_readiness(self.address, |curr| curr & (!mask_no_hup)) - .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address)); +cfg_io_readiness! { + impl Registration { + pub(super) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> { + use std::future::Future; + use std::pin::Pin; - let mut ready = mask & mio::Ready::from_usize(curr_ready); + let fut = self.shared.readiness(interest); + pin!(fut); - if ready.is_empty() { - if let Some(cx) = cx { - // Update the task info - match direction { - Direction::Read => sched.reader.register_by_ref(cx.waker()), - Direction::Write => sched.writer.register_by_ref(cx.waker()), + crate::future::poll_fn(|cx| { + if self.handle.inner().is_none() { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone"))); } - // Try again - let curr_ready = sched - .set_readiness(self.address, |curr| curr & (!mask_no_hup)) - .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address)); - ready = mask & mio::Ready::from_usize(curr_ready); - } - } - - if ready.is_empty() { - Ok(None) - } else { - Ok(Some(ready)) + Pin::new(&mut fut).poll(cx).map(Ok) + }).await } } } - -unsafe impl Send for Registration {} -unsafe impl Sync for Registration {} - -impl Drop for Registration { - fn drop(&mut self) { - let inner = match self.handle.inner() { - Some(inner) => inner, - None => return, - }; - inner.drop_source(self.address); - } -} diff --git a/src/io/seek.rs b/src/io/seek.rs index e3b5bf6..e64205d 100644 --- a/src/io/seek.rs +++ b/src/io/seek.rs @@ -1,15 +1,23 @@ use crate::io::AsyncSeek; + +use pin_project_lite::pin_project; use std::future::Future; use std::io::{self, SeekFrom}; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -/// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Seek<'a, S: ?Sized> { - seek: &'a mut S, - pos: Option<SeekFrom>, +pin_project! { + /// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Seek<'a, S: ?Sized> { + seek: &'a mut S, + pos: Option<SeekFrom>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } pub(crate) fn seek<S>(seek: &mut S, pos: SeekFrom) -> Seek<'_, S> @@ -19,6 +27,7 @@ where Seek { seek, pos: Some(pos), + _pin: PhantomPinned, } } @@ -28,29 +37,21 @@ where { type Output = io::Result<u64>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); match me.pos { - Some(pos) => match Pin::new(&mut me.seek).start_seek(cx, pos) { - Poll::Ready(Ok(())) => { - me.pos = None; - Pin::new(&mut me.seek).poll_complete(cx) + Some(pos) => { + // ensure no seek in progress + ready!(Pin::new(&mut *me.seek).poll_complete(cx))?; + match Pin::new(&mut *me.seek).start_seek(*pos) { + Ok(()) => { + *me.pos = None; + Pin::new(&mut *me.seek).poll_complete(cx) + } + Err(e) => Poll::Ready(Err(e)), } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - }, - None => Pin::new(&mut me.seek).poll_complete(cx), + } + None => Pin::new(&mut *me.seek).poll_complete(cx), } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Seek<'_, PhantomPinned>>(); - } -} diff --git a/src/io/split.rs b/src/io/split.rs index 134b937..fd3273e 100644 --- a/src/io/split.rs +++ b/src/io/split.rs @@ -4,9 +4,8 @@ //! To restore this read/write object from its `split::ReadHalf` and //! `split::WriteHalf` use `unsplit`. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; -use bytes::{Buf, BufMut}; use std::cell::UnsafeCell; use std::fmt; use std::io; @@ -102,20 +101,11 @@ impl<T: AsyncRead> AsyncRead for ReadHalf<T> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { let mut inner = ready!(self.inner.poll_lock(cx)); inner.stream_pin().poll_read(cx, buf) } - - fn poll_read_buf<B: BufMut>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_read_buf(cx, buf) - } } impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { @@ -137,15 +127,6 @@ impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { let mut inner = ready!(self.inner.poll_lock(cx)); inner.stream_pin().poll_shutdown(cx) } - - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<Result<usize, io::Error>> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_write_buf(cx, buf) - } } impl<T> Inner<T> { diff --git a/src/io/stderr.rs b/src/io/stderr.rs index 99607dc..2f624fb 100644 --- a/src/io/stderr.rs +++ b/src/io/stderr.rs @@ -1,4 +1,5 @@ use crate::io::blocking::Blocking; +use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows; use crate::io::AsyncWrite; use std::io; @@ -35,7 +36,7 @@ cfg_io_std! { /// ``` #[derive(Debug)] pub struct Stderr { - std: Blocking<std::io::Stderr>, + std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stderr>>, } /// Constructs a new handle to the standard error of the current process. @@ -59,7 +60,7 @@ cfg_io_std! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut stderr = io::stdout(); + /// let mut stderr = io::stderr(); /// stderr.write_all(b"Print some error here.").await?; /// Ok(()) /// } @@ -67,7 +68,7 @@ cfg_io_std! { pub fn stderr() -> Stderr { let std = io::stderr(); Stderr { - std: Blocking::new(std), + std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), } } } diff --git a/src/io/stdin.rs b/src/io/stdin.rs index 325b875..c9578f1 100644 --- a/src/io/stdin.rs +++ b/src/io/stdin.rs @@ -1,5 +1,5 @@ use crate::io::blocking::Blocking; -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -63,16 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin { } impl AsyncRead for Stdin { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { Pin::new(&mut self.std).poll_read(cx, buf) } } diff --git a/src/io/stdio_common.rs b/src/io/stdio_common.rs new file mode 100644 index 0000000..d21c842 --- /dev/null +++ b/src/io/stdio_common.rs @@ -0,0 +1,220 @@ +//! Contains utilities for stdout and stderr. +use crate::io::AsyncWrite; +use std::pin::Pin; +use std::task::{Context, Poll}; +/// # Windows +/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest, +/// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF. +/// That's why, wrapped writer will always receive well-formed utf-8 bytes. +/// # Other platforms +/// passes data to `inner` as is +#[derive(Debug)] +pub(crate) struct SplitByUtf8BoundaryIfWindows<W> { + inner: W, +} + +impl<W> SplitByUtf8BoundaryIfWindows<W> { + pub(crate) fn new(inner: W) -> Self { + Self { inner } + } +} + +// this constant is defined by Unicode standard. +const MAX_BYTES_PER_CHAR: usize = 4; + +// Subject for tweaking here +const MAGIC_CONST: usize = 8; + +impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W> +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + // just a closure to avoid repetitive code + let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf); + + // 1. Only windows stdio can suffer from non-utf8. + // We also check for `test` so that we can write some tests + // for further code. Since `AsyncWrite` can always shrink + // buffer at its discretion, excessive (i.e. in tests) shrinking + // does not break correctness. + // 2. If buffer is small, it will not be shrinked. + // That's why, it's "textness" will not change, so we don't have + // to fixup it. + if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF + { + return call_inner(buf); + } + + buf = &buf[..crate::io::blocking::MAX_BUF]; + + // Now there are two possibilites. + // If caller gave is binary buffer, we **should not** shrink it + // anymore, because excessive shrinking hits performance. + // If caller gave as binary buffer, we **must** additionaly + // shrink it to strip incomplete char at the end of buffer. + // that's why check we will perform now is allowed to have + // false-positive. + + // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes. + // if they are (possibly incomplete) utf8, then we can be quite sure + // that input buffer was utf8. + + let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) { + Ok(_) => true, + Err(err) => { + let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to(); + incomplete_bytes < MAX_BYTES_PER_CHAR + } + }; + + if have_to_fix_up { + // We must pop several bytes at the end which form incomplete + // character. To achieve it, we exploit UTF8 encoding: + // for any code point, all bytes except first start with 0b10 prefix. + // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details + let trailing_incomplete_char_size = buf + .iter() + .rev() + .take(MAX_BYTES_PER_CHAR) + .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000) + .unwrap_or(0) + + 1; + buf = &buf[..buf.len() - trailing_incomplete_char_size]; + } + + call_inner(buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +#[cfg(not(loom))] +mod tests { + use crate::io::AsyncWriteExt; + use std::io; + use std::pin::Pin; + use std::task::Context; + use std::task::Poll; + + const MAX_BUF: usize = 16 * 1024; + + struct TextMockWriter; + + impl crate::io::AsyncWrite for TextMockWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + assert!(buf.len() <= MAX_BUF); + assert!(std::str::from_utf8(buf).is_ok()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + } + + struct LoggingMockWriter { + write_history: Vec<usize>, + } + + impl LoggingMockWriter { + fn new() -> Self { + LoggingMockWriter { + write_history: Vec::new(), + } + } + } + + impl crate::io::AsyncWrite for LoggingMockWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + assert!(buf.len() <= MAX_BUF); + self.write_history.push(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + } + + #[test] + fn test_splitter() { + let data = str::repeat("█", MAX_BUF); + let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter); + let fut = async move { + wr.write_all(data.as_bytes()).await.unwrap(); + }; + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(fut); + } + + #[test] + fn test_pseudo_text() { + // In this test we write a piece of binary data, whose beginning is + // text though. We then validate that even in this corner case buffer + // was not shrinked too much. + let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR; + let mut data: Vec<u8> = str::repeat("a", checked_count).into(); + data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1)); + let mut writer = LoggingMockWriter::new(); + let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer); + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(async { + splitter.write_all(&data).await.unwrap(); + }); + // Check that at most two writes were performed + assert!(writer.write_history.len() <= 2); + // Check that all has been written + assert_eq!( + writer.write_history.iter().copied().sum::<usize>(), + data.len() + ); + // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked + // from the buffer: one because it was outside of MAX_BUF boundary, and + // up to one "utf8 code point". + assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1); + } +} diff --git a/src/io/stdout.rs b/src/io/stdout.rs index 5377993..a08ed01 100644 --- a/src/io/stdout.rs +++ b/src/io/stdout.rs @@ -1,6 +1,6 @@ use crate::io::blocking::Blocking; +use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows; use crate::io::AsyncWrite; - use std::io; use std::pin::Pin; use std::task::Context; @@ -35,7 +35,7 @@ cfg_io_std! { /// ``` #[derive(Debug)] pub struct Stdout { - std: Blocking<std::io::Stdout>, + std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stdout>>, } /// Constructs a new handle to the standard output of the current process. @@ -67,7 +67,7 @@ cfg_io_std! { pub fn stdout() -> Stdout { let std = io::stdout(); Stdout { - std: Blocking::new(std), + std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), } } } diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs index 1bfab90..9e87f2f 100644 --- a/src/io/util/async_buf_read_ext.rs +++ b/src/io/util/async_buf_read_ext.rs @@ -14,7 +14,7 @@ cfg_io_util! { /// Equivalent to: /// /// ```ignore - /// async fn read_until(&mut self, buf: &mut Vec<u8>) -> io::Result<usize>; + /// async fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize>; /// ``` /// /// This function will read bytes from the underlying stream until the diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index e848a5d..0ab66c2 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -986,10 +986,12 @@ cfg_io_util! { /// /// All bytes read from this source will be appended to the specified /// buffer `buf`. This function will continuously call [`read()`] to - /// append more data to `buf` until [`read()`][read] returns `Ok(0)`. + /// append more data to `buf` until [`read()`] returns `Ok(0)`. /// /// If successful, the total number of bytes read is returned. /// + /// [`read()`]: AsyncReadExt::read + /// /// # Errors /// /// If a read error is encountered then the `read_to_end` operation @@ -1018,7 +1020,7 @@ cfg_io_util! { /// (See also the [`tokio::fs::read`] convenience function for reading from a /// file.) /// - /// [`tokio::fs::read`]: crate::fs::read::read + /// [`tokio::fs::read`]: fn@crate::fs::read fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self> where Self: Unpin, @@ -1065,7 +1067,7 @@ cfg_io_util! { /// (See also the [`crate::fs::read_to_string`] convenience function for /// reading from a file.) /// - /// [`crate::fs::read_to_string`]: crate::fs::read_to_string::read_to_string + /// [`crate::fs::read_to_string`]: fn@crate::fs::read_to_string fn read_to_string<'a>(&'a mut self, dst: &'a mut String) -> ReadToString<'a, Self> where Self: Unpin, @@ -1078,7 +1080,11 @@ cfg_io_util! { /// This function returns a new instance of `AsyncRead` which will read /// at most `limit` bytes, after which it will always return EOF /// (`Ok(0)`). Any read errors will not count towards the number of - /// bytes read and future calls to [`read()`][read] may succeed. + /// bytes read and future calls to [`read()`] may succeed. + /// + /// [`read()`]: fn@crate::io::AsyncReadExt::read + /// + /// [read]: AsyncReadExt::read /// /// # Examples /// diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs index c7a0f72..351900b 100644 --- a/src/io/util/async_seek_ext.rs +++ b/src/io/util/async_seek_ext.rs @@ -2,65 +2,73 @@ use crate::io::seek::{seek, Seek}; use crate::io::AsyncSeek; use std::io::SeekFrom; -/// An extension trait which adds utility methods to [`AsyncSeek`] types. -/// -/// As a convenience, this trait may be imported using the [`prelude`]: -/// -/// # Examples -/// -/// ``` -/// use std::io::{Cursor, SeekFrom}; -/// use tokio::prelude::*; -/// -/// #[tokio::main] -/// async fn main() -> io::Result<()> { -/// let mut cursor = Cursor::new(b"abcdefg"); -/// -/// // the `seek` method is defined by this trait -/// cursor.seek(SeekFrom::Start(3)).await?; -/// -/// let mut buf = [0; 1]; -/// let n = cursor.read(&mut buf).await?; -/// assert_eq!(n, 1); -/// assert_eq!(buf, [b'd']); -/// -/// Ok(()) -/// } -/// ``` -/// -/// See [module][crate::io] documentation for more details. -/// -/// [`AsyncSeek`]: AsyncSeek -/// [`prelude`]: crate::prelude -pub trait AsyncSeekExt: AsyncSeek { - /// Creates a future which will seek an IO object, and then yield the - /// new position in the object and the object itself. +cfg_io_util! { + /// An extension trait which adds utility methods to [`AsyncSeek`] types. /// - /// In the case of an error the buffer and the object will be discarded, with - /// the error yielded. + /// As a convenience, this trait may be imported using the [`prelude`]: /// /// # Examples /// - /// ```no_run - /// use tokio::fs::File; + /// ``` + /// use std::io::{Cursor, SeekFrom}; /// use tokio::prelude::*; /// - /// use std::io::SeekFrom; + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut cursor = Cursor::new(b"abcdefg"); + /// + /// // the `seek` method is defined by this trait + /// cursor.seek(SeekFrom::Start(3)).await?; /// - /// # async fn dox() -> std::io::Result<()> { - /// let mut file = File::open("foo.txt").await?; - /// file.seek(SeekFrom::Start(6)).await?; + /// let mut buf = [0; 1]; + /// let n = cursor.read(&mut buf).await?; + /// assert_eq!(n, 1); + /// assert_eq!(buf, [b'd']); /// - /// let mut contents = vec![0u8; 10]; - /// file.read_exact(&mut contents).await?; - /// # Ok(()) - /// # } + /// Ok(()) + /// } /// ``` - fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> - where - Self: Unpin, - { - seek(self, pos) + /// + /// See [module][crate::io] documentation for more details. + /// + /// [`AsyncSeek`]: AsyncSeek + /// [`prelude`]: crate::prelude + pub trait AsyncSeekExt: AsyncSeek { + /// Creates a future which will seek an IO object, and then yield the + /// new position in the object and the object itself. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64>; + /// ``` + /// + /// In the case of an error the buffer and the object will be discarded, with + /// the error yielded. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::prelude::*; + /// + /// use std::io::SeekFrom; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::open("foo.txt").await?; + /// file.seek(SeekFrom::Start(6)).await?; + /// + /// let mut contents = vec![0u8; 10]; + /// file.read_exact(&mut contents).await?; + /// # Ok(()) + /// # } + /// ``` + fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> + where + Self: Unpin, + { + seek(self, pos) + } } } diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index fa41097..e6ef5b2 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -119,6 +119,7 @@ cfg_io_util! { write(self, src) } + /// Writes a buffer into this writer, advancing the buffer's internal /// cursor. /// @@ -134,7 +135,7 @@ cfg_io_util! { /// internal cursor is advanced by the number of bytes written. A /// subsequent call to `write_buf` using the **same** `buf` value will /// resume from the point that the first call to `write_buf` completed. - /// A call to `write` represents *at most one* attempt to write to any + /// A call to `write_buf` represents *at most one* attempt to write to any /// wrapped object. /// /// # Return @@ -976,6 +977,8 @@ cfg_io_util! { /// no longer attempt to write to the stream. For example, the /// `TcpStream` implementation will issue a `shutdown(Write)` sys call. /// + /// [`flush`]: fn@crate::io::AsyncWriteExt::flush + /// /// # Examples /// /// ```no_run diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index a1c5990..271f61b 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -1,10 +1,8 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; -use bytes::Buf; use pin_project_lite::pin_project; -use std::io::{self, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, fmt}; @@ -44,21 +42,12 @@ impl<R: AsyncRead> BufReader<R> { /// Creates a new `BufReader` with the specified buffer capacity. pub fn with_capacity(capacity: usize, inner: R) -> Self { - unsafe { - let mut buffer = Vec::with_capacity(capacity); - buffer.set_len(capacity); - - { - // Convert to MaybeUninit - let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit<u8>]); - inner.prepare_uninitialized_buffer(b); - } - Self { - inner, - buf: buffer.into_boxed_slice(), - pos: 0, - cap: 0, - } + let buffer = vec![0; capacity]; + Self { + inner, + buf: buffer.into_boxed_slice(), + pos: 0, + cap: 0, } } @@ -110,25 +99,21 @@ impl<R: AsyncRead> AsyncRead for BufReader<R> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - if self.pos == self.cap && buf.len() >= self.buf.len() { + if self.pos == self.cap && buf.remaining() >= self.buf.len() { let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); self.discard_buffer(); return Poll::Ready(res); } - let mut rem = ready!(self.as_mut().poll_fill_buf(cx))?; - let nread = rem.read(buf)?; - self.consume(nread); - Poll::Ready(Ok(nread)) - } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[..amt]); + self.consume(amt); + Poll::Ready(Ok(())) } } @@ -142,7 +127,9 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> { // to tell the compiler that the pos..cap slice is always valid. if *me.pos >= *me.cap { debug_assert!(*me.pos == *me.cap); - *me.cap = ready!(me.inner.poll_read(cx, me.buf))?; + let mut buf = ReadBuf::new(me.buf); + ready!(me.inner.poll_read(cx, &mut buf))?; + *me.cap = buf.filled().len(); *me.pos = 0; } Poll::Ready(Ok(&me.buf[*me.pos..*me.cap])) @@ -163,14 +150,6 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { self.get_pin_mut().poll_write(cx, buf) } - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - self.get_pin_mut().poll_write_buf(cx, buf) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.get_pin_mut().poll_flush(cx) } diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs index a56a451..cc857e2 100644 --- a/src/io/util/buf_stream.rs +++ b/src/io/util/buf_stream.rs @@ -1,9 +1,8 @@ use crate::io::util::{BufReader, BufWriter}; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -137,15 +136,10 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.project().inner.poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } } impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> { diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs index efd053e..5e3d4b7 100644 --- a/src/io/util/buf_writer.rs +++ b/src/io/util/buf_writer.rs @@ -1,10 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; use std::io::{self, Write}; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -147,15 +146,10 @@ impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.get_pin_mut().poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.get_ref().prepare_uninitialized_buffer(buf) - } } impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { diff --git a/src/io/util/chain.rs b/src/io/util/chain.rs index 8ba9194..84f37fc 100644 --- a/src/io/util/chain.rs +++ b/src/io/util/chain.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; @@ -84,26 +84,20 @@ where T: AsyncRead, U: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - if self.first.prepare_uninitialized_buffer(buf) { - return true; - } - if self.second.prepare_uninitialized_buffer(buf) { - return true; - } - false - } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { let me = self.project(); if !*me.done_first { - match ready!(me.first.poll_read(cx, buf)?) { - 0 if !buf.is_empty() => *me.done_first = true, - n => return Poll::Ready(Ok(n)), + let rem = buf.remaining(); + ready!(me.first.poll_read(cx, buf))?; + if buf.remaining() == rem { + *me.done_first = true; + } else { + return Poll::Ready(Ok(())); } } me.second.poll_read(cx, buf) diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs index 7bfe296..c5981cf 100644 --- a/src/io/util/copy.rs +++ b/src/io/util/copy.rs @@ -1,30 +1,25 @@ -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { - /// A future that asynchronously copies the entire contents of a reader into a - /// writer. - /// - /// This struct is generally created by calling [`copy`][copy]. Please - /// see the documentation of `copy()` for more details. - /// - /// [copy]: copy() - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct Copy<'a, R: ?Sized, W: ?Sized> { - reader: &'a mut R, - read_done: bool, - writer: &'a mut W, - pos: usize, - cap: usize, - amt: u64, - buf: Box<[u8]>, - } +/// A future that asynchronously copies the entire contents of a reader into a +/// writer. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +struct Copy<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + read_done: bool, + writer: &'a mut W, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} +cfg_io_util! { /// Asynchronously copies the entire contents of a reader into a writer. /// /// This function returns a future that will continuously read data from @@ -58,7 +53,7 @@ cfg_io_util! { /// # Ok(()) /// # } /// ``` - pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W> + pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> where R: AsyncRead + Unpin + ?Sized, W: AsyncWrite + Unpin + ?Sized, @@ -71,7 +66,7 @@ cfg_io_util! { pos: 0, cap: 0, buf: vec![0; 2048].into_boxed_slice(), - } + }.await } } @@ -88,7 +83,9 @@ where // continue. if self.pos == self.cap && !self.read_done { let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + let mut buf = ReadBuf::new(&mut me.buf); + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?; + let n = buf.filled().len(); if n == 0 { self.read_done = true; } else { @@ -122,14 +119,3 @@ where } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Copy<'_, PhantomPinned, PhantomPinned>>(); - } -} diff --git a/src/io/util/copy_buf.rs b/src/io/util/copy_buf.rs new file mode 100644 index 0000000..6831580 --- /dev/null +++ b/src/io/util/copy_buf.rs @@ -0,0 +1,102 @@ +use crate::io::{AsyncBufRead, AsyncWrite}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// A future that asynchronously copies the entire contents of a reader into a + /// writer. + /// + /// This struct is generally created by calling [`copy_buf`][copy_buf]. Please + /// see the documentation of `copy_buf()` for more details. + /// + /// [copy_buf]: copy_buf() + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + struct CopyBuf<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + writer: &'a mut W, + amt: u64, + } + + /// Asynchronously copies the entire contents of a reader into a writer. + /// + /// This function returns a future that will continuously read data from + /// `reader` and then write it into `writer` in a streaming fashion until + /// `reader` returns EOF. + /// + /// On success, the total number of bytes that were copied from `reader` to + /// `writer` is returned. + /// + /// + /// # Errors + /// + /// The returned future will finish with an error will return an error + /// immediately if any call to `poll_fill_buf` or `poll_write` returns an + /// error. + /// + /// # Examples + /// + /// ``` + /// use tokio::io; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut reader: &[u8] = b"hello"; + /// let mut writer: Vec<u8> = vec![]; + /// + /// io::copy_buf(&mut reader, &mut writer).await?; + /// + /// assert_eq!(b"hello", &writer[..]); + /// # Ok(()) + /// # } + /// ``` + pub async fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> + where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, + { + CopyBuf { + reader, + writer, + amt: 0, + }.await + } +} + +impl<R, W> Future for CopyBuf<'_, R, W> +where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<u64>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let me = &mut *self; + let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?; + if buffer.is_empty() { + ready!(Pin::new(&mut self.writer).poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + + let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?; + if i == 0 { + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); + } + self.amt += i as u64; + Pin::new(&mut *self.reader).consume(i); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + use std::marker::PhantomPinned; + crate::is_unpin::<CopyBuf<'_, PhantomPinned, PhantomPinned>>(); + } +} diff --git a/src/io/util/empty.rs b/src/io/util/empty.rs index 576058d..f964d18 100644 --- a/src/io/util/empty.rs +++ b/src/io/util/empty.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use std::fmt; use std::io; @@ -47,16 +47,13 @@ cfg_io_util! { } impl AsyncRead for Empty { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - _: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(Ok(0)) + _: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) } } diff --git a/src/io/util/flush.rs b/src/io/util/flush.rs index 534a516..88d60b8 100644 --- a/src/io/util/flush.rs +++ b/src/io/util/flush.rs @@ -1,18 +1,24 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future used to fully flush an I/O object. /// /// Created by the [`AsyncWriteExt::flush`][flush] function. /// [flush]: crate::io::AsyncWriteExt::flush #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Flush<'a, A: ?Sized> { a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +27,10 @@ pub(super) fn flush<A>(a: &mut A) -> Flush<'_, A> where A: AsyncWrite + Unpin + ?Sized, { - Flush { a } + Flush { + a, + _pin: PhantomPinned, + } } impl<A> Future for Flush<'_, A> @@ -30,19 +39,8 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); Pin::new(&mut *me.a).poll_flush(cx) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Flush<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs index ee27400..b41f04a 100644 --- a/src/io/util/lines.rs +++ b/src/io/util/lines.rs @@ -83,8 +83,7 @@ impl<R> Lines<R> where R: AsyncBufRead, { - #[doc(hidden)] - pub fn poll_next_line( + fn poll_next_line( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<String>>> { diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs new file mode 100644 index 0000000..e91a932 --- /dev/null +++ b/src/io/util/mem.rs @@ -0,0 +1,223 @@ +//! In-process memory IO types. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::loom::sync::Mutex; + +use bytes::{Buf, BytesMut}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll, Waker}, +}; + +/// A bidirectional pipe to read and write bytes in memory. +/// +/// A pair of `DuplexStream`s are created together, and they act as a "channel" +/// that can be used as in-memory IO types. Writing to one of the pairs will +/// allow that data to be read from the other, and vice versa. +/// +/// # Example +/// +/// ``` +/// # async fn ex() -> std::io::Result<()> { +/// # use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// let (mut client, mut server) = tokio::io::duplex(64); +/// +/// client.write_all(b"ping").await?; +/// +/// let mut buf = [0u8; 4]; +/// server.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"ping"); +/// +/// server.write_all(b"pong").await?; +/// +/// client.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"pong"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct DuplexStream { + read: Arc<Mutex<Pipe>>, + write: Arc<Mutex<Pipe>>, +} + +/// A unidirectional IO over a piece of memory. +/// +/// Data can be written to the pipe, and reading will return that data. +#[derive(Debug)] +struct Pipe { + /// The buffer storing the bytes written, also read from. + /// + /// Using a `BytesMut` because it has efficient `Buf` and `BufMut` + /// functionality already. Additionally, it can try to copy data in the + /// same buffer if there read index has advanced far enough. + buffer: BytesMut, + /// Determines if the write side has been closed. + is_closed: bool, + /// The maximum amount of bytes that can be written before returning + /// `Poll::Pending`. + max_buf_size: usize, + /// If the `read` side has been polled and is pending, this is the waker + /// for that parked task. + read_waker: Option<Waker>, + /// If the `write` side has filled the `max_buf_size` and returned + /// `Poll::Pending`, this is the waker for that parked task. + write_waker: Option<Waker>, +} + +// ===== impl DuplexStream ===== + +/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets. +/// +/// The `max_buf_size` argument is the maximum amount of bytes that can be +/// written to a side before the write returns `Poll::Pending`. +pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { + let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + + ( + DuplexStream { + read: one.clone(), + write: two.clone(), + }, + DuplexStream { + read: two, + write: one, + }, + ) +} + +impl AsyncRead for DuplexStream { + // Previous rustc required this `self` to be `mut`, even though newer + // versions recognize it isn't needed to call `lock()`. So for + // compatibility, we include the `mut` and `allow` the lint. + // + // See https://github.com/rust-lang/rust/issues/73592 + #[allow(unused_mut)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.read.lock()).poll_read(cx, buf) + } +} + +impl AsyncWrite for DuplexStream { + #[allow(unused_mut)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Pin::new(&mut *self.write.lock()).poll_write(cx, buf) + } + + #[allow(unused_mut)] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_flush(cx) + } + + #[allow(unused_mut)] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_shutdown(cx) + } +} + +impl Drop for DuplexStream { + fn drop(&mut self) { + // notify the other side of the closure + self.write.lock().close(); + } +} + +// ===== impl Pipe ===== + +impl Pipe { + fn new(max_buf_size: usize) -> Self { + Pipe { + buffer: BytesMut::new(), + is_closed: false, + max_buf_size, + read_waker: None, + write_waker: None, + } + } + + fn close(&mut self) { + self.is_closed = true; + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if self.buffer.has_remaining() { + let max = self.buffer.remaining().min(buf.remaining()); + buf.put_slice(&self.buffer[..max]); + self.buffer.advance(max); + if max > 0 { + // The passed `buf` might have been empty, don't wake up if + // no bytes have been moved. + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + Poll::Ready(Ok(())) + } else if self.is_closed { + Poll::Ready(Ok(())) + } else { + self.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let len = buf.len().min(avail); + self.buffer.extend_from_slice(&buf[..len]); + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + self.close(); + Poll::Ready(Ok(())) + } +} diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs index c4754ab..e75ea03 100644 --- a/src/io/util/mod.rs +++ b/src/io/util/mod.rs @@ -25,7 +25,10 @@ cfg_io_util! { mod chain; mod copy; - pub use copy::{copy, Copy}; + pub use copy::copy; + + mod copy_buf; + pub use copy_buf::copy_buf; mod empty; pub use empty::{empty, Empty}; @@ -35,6 +38,9 @@ cfg_io_util! { mod lines; pub use lines::Lines; + mod mem; + pub use mem::{duplex, DuplexStream}; + mod read; mod read_buf; mod read_exact; @@ -60,11 +66,6 @@ cfg_io_util! { mod split; pub use split::Split; - cfg_stream! { - mod stream_reader; - pub use stream_reader::{stream_reader, StreamReader}; - } - mod take; pub use take::Take; diff --git a/src/io/util/read.rs b/src/io/util/read.rs index a8ca370..edc9d5a 100644 --- a/src/io/util/read.rs +++ b/src/io/util/read.rs @@ -1,7 +1,9 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -15,10 +17,14 @@ pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R> where R: AsyncRead + Unpin + ?Sized, { - Read { reader, buf } + Read { + reader, + buf, + _pin: PhantomPinned, + } } -cfg_io_util! { +pin_project! { /// A future which can be used to easily read available number of bytes to fill /// a buffer. /// @@ -28,6 +34,9 @@ cfg_io_util! { pub struct Read<'a, R: ?Sized> { reader: &'a mut R, buf: &'a mut [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -37,19 +46,10 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.reader).poll_read(cx, me.buf) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Read<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + let mut buf = ReadBuf::new(*me.buf); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?; + Poll::Ready(Ok(buf.filled().len())) } } diff --git a/src/io/util/read_buf.rs b/src/io/util/read_buf.rs index 6ee3d24..696deef 100644 --- a/src/io/util/read_buf.rs +++ b/src/io/util/read_buf.rs @@ -1,8 +1,10 @@ use crate::io::AsyncRead; use bytes::BufMut; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; @@ -11,16 +13,22 @@ where R: AsyncRead + Unpin, B: BufMut, { - ReadBuf { reader, buf } + ReadBuf { + reader, + buf, + _pin: PhantomPinned, + } } -cfg_io_util! { +pin_project! { /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadBuf<'a, R, B> { reader: &'a mut R, buf: &'a mut B, + #[pin] + _pin: PhantomPinned, } } @@ -31,8 +39,34 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + use crate::io::ReadBuf; + use std::mem::MaybeUninit; + + let me = self.project(); + + if !me.buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = me.buf.bytes_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + me.buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) } } diff --git a/src/io/util/read_exact.rs b/src/io/util/read_exact.rs index 86b8412..1e8150e 100644 --- a/src/io/util/read_exact.rs +++ b/src/io/util/read_exact.rs @@ -1,7 +1,9 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -17,12 +19,12 @@ where { ReadExact { reader, - buf, - pos: 0, + buf: ReadBuf::new(buf), + _pin: PhantomPinned, } } -cfg_io_util! { +pin_project! { /// Creates a future which will read exactly enough bytes to fill `buf`, /// returning an error if EOF is hit sooner. /// @@ -31,8 +33,10 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadExact<'a, A: ?Sized> { reader: &'a mut A, - buf: &'a mut [u8], - pos: usize, + buf: ReadBuf<'a>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -46,32 +50,20 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let mut me = self.project(); + loop { // if our buffer is empty, then we need to read some data to continue. - if self.pos < self.buf.len() { - let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf[me.pos..]))?; - me.pos += n; - if n == 0 { + let rem = me.buf.remaining(); + if rem != 0 { + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + if me.buf.remaining() == rem { return Err(eof()).into(); } - } - - if self.pos >= self.buf.len() { - return Poll::Ready(Ok(self.pos)); + } else { + return Poll::Ready(Ok(me.buf.capacity())); } } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadExact<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs index 9d37dc7..5b9fb7b 100644 --- a/src/io/util/read_int.rs +++ b/src/io/util/read_int.rs @@ -1,10 +1,11 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use bytes::Buf; use pin_project_lite::pin_project; use std::future::Future; use std::io; use std::io::ErrorKind::UnexpectedEof; +use std::marker::PhantomPinned; use std::mem::size_of; use std::pin::Pin; use std::task::{Context, Poll}; @@ -16,11 +17,15 @@ macro_rules! reader { ($name:ident, $ty:ty, $reader:ident, $bytes:expr) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<R> { #[pin] src: R, buf: [u8; $bytes], read: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -30,6 +35,7 @@ macro_rules! reader { src, buf: [0; $bytes], read: 0, + _pin: PhantomPinned, } } } @@ -48,17 +54,19 @@ macro_rules! reader { } while *me.read < $bytes as u8 { - *me.read += match me - .src - .as_mut() - .poll_read(cx, &mut me.buf[*me.read as usize..]) - { + let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]); + + *me.read += match me.src.as_mut().poll_read(cx, &mut buf) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => { - return Poll::Ready(Err(UnexpectedEof.into())); + Poll::Ready(Ok(())) => { + let n = buf.filled().len(); + if n == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + n as u8 } - Poll::Ready(Ok(n)) => n as u8, }; } @@ -75,15 +83,22 @@ macro_rules! reader8 { pin_project! { /// Future returned from `read_u8` #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<R> { #[pin] reader: R, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<R> $name<R> { pub(crate) fn new(reader: R) -> $name<R> { - $name { reader } + $name { + reader, + _pin: PhantomPinned, + } } } @@ -97,12 +112,17 @@ macro_rules! reader8 { let me = self.project(); let mut buf = [0; 1]; - match me.reader.poll_read(cx, &mut buf[..]) { + let mut buf = ReadBuf::new(&mut buf); + match me.reader.poll_read(cx, &mut buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => Poll::Ready(Err(UnexpectedEof.into())), - Poll::Ready(Ok(1)) => Poll::Ready(Ok(buf[0] as $ty)), - Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Ok(())) => { + if buf.filled().len() == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + Poll::Ready(Ok(buf.filled()[0] as $ty)) + } } } } diff --git a/src/io/util/read_line.rs b/src/io/util/read_line.rs index d625a76..d38ffaf 100644 --- a/src/io/util/read_line.rs +++ b/src/io/util/read_line.rs @@ -1,26 +1,32 @@ use crate::io::util::read_until::read_until_internal; use crate::io::AsyncBufRead; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; +use std::string::FromUtf8Error; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// Future for the [`read_line`](crate::io::AsyncBufReadExt::read_line) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadLine<'a, R: ?Sized> { reader: &'a mut R, - /// This is the buffer we were provided. It will be replaced with an empty string - /// while reading to postpone utf-8 handling until after reading. + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. output: &'a mut String, - /// The actual allocation of the string is moved into a vector instead. + // The actual allocation of the string is moved into this vector instead. buf: Vec<u8>, - /// The number of bytes appended to buf. This can be less than buf.len() if - /// the buffer was not empty when the operation was started. + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -33,6 +39,7 @@ where buf: mem::replace(string, String::new()).into_bytes(), output: string, read: 0, + _pin: PhantomPinned, } } @@ -42,31 +49,33 @@ fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_re *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); } -pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( - reader: Pin<&mut R>, - cx: &mut Context<'_>, +/// This handles the various failure cases and puts the string back into `output`. +/// +/// The `truncate_on_io_error` bool is necessary because `read_to_string` and `read_line` +/// disagree on what should happen when an IO error occurs. +pub(super) fn finish_string_read( + io_res: io::Result<usize>, + utf8_res: Result<String, FromUtf8Error>, + read: usize, output: &mut String, - buf: &mut Vec<u8>, - read: &mut usize, + truncate_on_io_error: bool, ) -> Poll<io::Result<usize>> { - let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); - let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); - - // At this point both buf and output are empty. The allocation is in utf8_res. - - debug_assert!(buf.is_empty()); match (io_res, utf8_res) { (Ok(num_bytes), Ok(string)) => { - debug_assert_eq!(*read, 0); + debug_assert_eq!(read, 0); *output = string; Poll::Ready(Ok(num_bytes)) } (Err(io_err), Ok(string)) => { *output = string; + if truncate_on_io_error { + let original_len = output.len() - read; + output.truncate(original_len); + } Poll::Ready(Err(io_err)) } (Ok(num_bytes), Err(utf8_err)) => { - debug_assert_eq!(*read, 0); + debug_assert_eq!(read, 0); put_back_original_data(output, utf8_err.into_bytes(), num_bytes); Poll::Ready(Err(io::Error::new( @@ -75,35 +84,36 @@ pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( ))) } (Err(io_err), Err(utf8_err)) => { - put_back_original_data(output, utf8_err.into_bytes(), *read); + put_back_original_data(output, utf8_err.into_bytes(), read); Poll::Ready(Err(io_err)) } } } -impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { - type Output = io::Result<usize>; +pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( + reader: Pin<&mut R>, + cx: &mut Context<'_>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, +) -> Poll<io::Result<usize>> { + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - output, - buf, - read, - } = &mut *self; + // At this point both buf and output are empty. The allocation is in utf8_res. - read_line_internal(Pin::new(reader), cx, output, buf, read) - } + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, false) } -#[cfg(test)] -mod tests { - use super::*; +impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadLine<'_, PhantomPinned>>(); + read_line_internal(Pin::new(*me.reader), cx, me.output, me.buf, me.read) } } diff --git a/src/io/util/read_to_end.rs b/src/io/util/read_to_end.rs index a2cd99b..a974625 100644 --- a/src/io/util/read_to_end.rs +++ b/src/io/util/read_to_end.rs @@ -1,92 +1,105 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; -use std::mem::MaybeUninit; +use std::marker::PhantomPinned; +use std::mem::{self, MaybeUninit}; use std::pin::Pin; use std::task::{Context, Poll}; -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] -pub struct ReadToEnd<'a, R: ?Sized> { - reader: &'a mut R, - buf: &'a mut Vec<u8>, - start_len: usize, +pin_project! { + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadToEnd<'a, R: ?Sized> { + reader: &'a mut R, + buf: &'a mut Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } -pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, R> +pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) -> ReadToEnd<'a, R> where R: AsyncRead + Unpin + ?Sized, { - let start_len = buf.len(); ReadToEnd { reader, - buf, - start_len, + buf: buffer, + read: 0, + _pin: PhantomPinned, } } -struct Guard<'a> { - buf: &'a mut Vec<u8>, - len: usize, -} - -impl Drop for Guard<'_> { - fn drop(&mut self) { - unsafe { - self.buf.set_len(self.len); +pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>( + buf: &mut Vec<u8>, + mut reader: Pin<&mut R>, + num_read: &mut usize, + cx: &mut Context<'_>, +) -> Poll<io::Result<usize>> { + loop { + // safety: The caller promised to prepare the buffer. + let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx)); + match ret { + Err(err) => return Poll::Ready(Err(err)), + Ok(0) => return Poll::Ready(Ok(mem::replace(num_read, 0))), + Ok(num) => { + *num_read += num; + } } } } -// This uses an adaptive system to extend the vector when it fills. We want to -// avoid paying to allocate and zero a huge chunk of memory if the reader only -// has 4 bytes while still making large reads if the reader does have a ton -// of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every -// time is 4,500 times (!) slower than this if the reader has a very small -// amount of data to return. -// -// Because we're extending the buffer with uninitialized data for trusted -// readers, we need to make sure to truncate that if any of this panics. -pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>( - mut rd: Pin<&mut R>, - cx: &mut Context<'_>, +/// Tries to read from the provided AsyncRead. +/// +/// The length of the buffer is increased by the number of bytes read. +fn poll_read_to_end<R: AsyncRead + ?Sized>( buf: &mut Vec<u8>, - start_len: usize, + read: Pin<&mut R>, + cx: &mut Context<'_>, ) -> Poll<io::Result<usize>> { - let mut g = Guard { - len: buf.len(), - buf, - }; - let ret; - loop { - if g.len == g.buf.len() { - unsafe { - g.buf.reserve(32); - let capacity = g.buf.capacity(); - g.buf.set_len(capacity); + // This uses an adaptive system to extend the vector when it fills. We want to + // avoid paying to allocate and zero a huge chunk of memory if the reader only + // has 4 bytes while still making large reads if the reader does have a ton + // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every + // time is 4,500 times (!) slower than this if the reader has a very small + // amount of data to return. + reserve(buf, 32); - let b = &mut *(&mut g.buf[g.len..] as *mut [u8] as *mut [MaybeUninit<u8>]); + let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf)); - rd.prepare_uninitialized_buffer(b); - } - } + ready!(read.poll_read(cx, &mut unused_capacity))?; - match ready!(rd.as_mut().poll_read(cx, &mut g.buf[g.len..])) { - Ok(0) => { - ret = Poll::Ready(Ok(g.len - start_len)); - break; - } - Ok(n) => g.len += n, - Err(e) => { - ret = Poll::Ready(Err(e)); - break; - } - } + let n = unused_capacity.filled().len(); + let new_len = buf.len() + n; + + // This should no longer even be possible in safe Rust. An implementor + // would need to have unsafely *replaced* the buffer inside `ReadBuf`, + // which... yolo? + assert!(new_len <= buf.capacity()); + unsafe { + buf.set_len(new_len); } + Poll::Ready(Ok(n)) +} - ret +/// Allocates more memory and ensures that the unused capacity is prepared for use +/// with the `AsyncRead`. +fn reserve(buf: &mut Vec<u8>, bytes: usize) { + if buf.capacity() - buf.len() >= bytes { + return; + } + buf.reserve(bytes); +} + +/// Returns the unused capacity of the provided vector. +fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] { + let uninit = bytes::BufMut::bytes_mut(buf); + unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) } } impl<A> Future for ReadToEnd<'_, A> @@ -95,19 +108,9 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let this = &mut *self; - read_to_end_internal(Pin::new(&mut this.reader), cx, this.buf, this.start_len) - } -} - -#[cfg(test)] -mod tests { - use super::*; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadToEnd<'_, PhantomPinned>>(); + read_to_end_internal(me.buf, Pin::new(*me.reader), me.read, cx) } } diff --git a/src/io/util/read_to_string.rs b/src/io/util/read_to_string.rs index cab0505..e463203 100644 --- a/src/io/util/read_to_string.rs +++ b/src/io/util/read_to_string.rs @@ -1,58 +1,71 @@ +use crate::io::util::read_line::finish_string_read; use crate::io::util::read_to_end::read_to_end_internal; use crate::io::AsyncRead; +use pin_project_lite::pin_project; use std::future::Future; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; -cfg_io_util! { +pin_project! { /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadToString<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut String, - bytes: Vec<u8>, - start_len: usize, + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + // The actual allocation of the string is moved into this vector instead. + buf: Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } -pub(crate) fn read_to_string<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadToString<'a, R> +pub(crate) fn read_to_string<'a, R>( + reader: &'a mut R, + string: &'a mut String, +) -> ReadToString<'a, R> where R: AsyncRead + ?Sized + Unpin, { - let start_len = buf.len(); + let buf = mem::replace(string, String::new()).into_bytes(); ReadToString { reader, - bytes: mem::replace(buf, String::new()).into_bytes(), buf, - start_len, + output: string, + read: 0, + _pin: PhantomPinned, } } -fn read_to_string_internal<R: AsyncRead + ?Sized>( +/// # Safety +/// +/// Before first calling this method, the unused capacity must have been +/// prepared for use with the provided AsyncRead. This can be done using the +/// `prepare_buffer` function in `read_to_end.rs`. +unsafe fn read_to_string_internal<R: AsyncRead + ?Sized>( reader: Pin<&mut R>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, cx: &mut Context<'_>, - buf: &mut String, - bytes: &mut Vec<u8>, - start_len: usize, ) -> Poll<io::Result<usize>> { - let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; - match String::from_utf8(mem::replace(bytes, Vec::new())) { - Ok(string) => { - debug_assert!(buf.is_empty()); - *buf = string; - Poll::Ready(Ok(ret)) - } - Err(e) => { - *bytes = e.into_bytes(); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - "stream did not contain valid UTF-8", - ))) - } - } + let io_res = ready!(read_to_end_internal(buf, reader, read, cx)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, true) } impl<A> Future for ReadToString<'_, A> @@ -61,31 +74,10 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - buf, - bytes, - start_len, - } = &mut *self; - let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len); - if let Poll::Ready(Err(_)) = ret { - // Put back the original string. - bytes.truncate(*start_len); - **buf = String::from_utf8(mem::replace(bytes, Vec::new())) - .expect("original string no longer utf-8"); - } - ret - } -} - -#[cfg(test)] -mod tests { - use super::*; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadToString<'_, PhantomPinned>>(); + // safety: The constructor of ReadToString called `prepare_buffer`. + unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) } } } diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs index 78dac8c..3599cff 100644 --- a/src/io/util/read_until.rs +++ b/src/io/util/read_until.rs @@ -1,12 +1,14 @@ use crate::io::AsyncBufRead; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. /// The delimeter is included in the resulting vector. #[derive(Debug)] @@ -15,9 +17,12 @@ cfg_io_util! { reader: &'a mut R, delimeter: u8, buf: &'a mut Vec<u8>, - /// The number of bytes appended to buf. This can be less than buf.len() if - /// the buffer was not empty when the operation was started. + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -34,6 +39,7 @@ where delimeter, buf, read: 0, + _pin: PhantomPinned, } } @@ -66,24 +72,8 @@ pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>( impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - delimeter, - buf, - read, - } = &mut *self; - read_until_internal(Pin::new(reader), cx, *delimeter, buf, read) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadUntil<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read) } } diff --git a/src/io/util/repeat.rs b/src/io/util/repeat.rs index eeef7cc..1142765 100644 --- a/src/io/util/repeat.rs +++ b/src/io/util/repeat.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -47,19 +47,17 @@ cfg_io_util! { } impl AsyncRead for Repeat { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - for byte in &mut *buf { - *byte = self.byte; + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // TODO: could be faster, but should we unsafe it? + while buf.remaining() != 0 { + buf.put_slice(&[self.byte]); } - Poll::Ready(Ok(buf.len())) + Poll::Ready(Ok(())) } } diff --git a/src/io/util/shutdown.rs b/src/io/util/shutdown.rs index 33ac0ac..6d30b00 100644 --- a/src/io/util/shutdown.rs +++ b/src/io/util/shutdown.rs @@ -1,18 +1,24 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future used to shutdown an I/O object. /// /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. /// [shutdown]: crate::io::AsyncWriteExt::shutdown + #[must_use = "futures do nothing unless you `.await` or poll them"] #[derive(Debug)] pub struct Shutdown<'a, A: ?Sized> { a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +27,10 @@ pub(super) fn shutdown<A>(a: &mut A) -> Shutdown<'_, A> where A: AsyncWrite + Unpin + ?Sized, { - Shutdown { a } + Shutdown { + a, + _pin: PhantomPinned, + } } impl<A> Future for Shutdown<'_, A> @@ -30,19 +39,8 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; - Pin::new(&mut *me.a).poll_shutdown(cx) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Shutdown<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + Pin::new(me.a).poll_shutdown(cx) } } diff --git a/src/io/util/split.rs b/src/io/util/split.rs index f552ed5..492e26a 100644 --- a/src/io/util/split.rs +++ b/src/io/util/split.rs @@ -65,8 +65,7 @@ impl<R> Split<R> where R: AsyncBufRead, { - #[doc(hidden)] - pub fn poll_next_segment( + fn poll_next_segment( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<Vec<u8>>>> { diff --git a/src/io/util/stream_reader.rs b/src/io/util/stream_reader.rs deleted file mode 100644 index b98f8bd..0000000 --- a/src/io/util/stream_reader.rs +++ /dev/null @@ -1,184 +0,0 @@ -use crate::io::{AsyncBufRead, AsyncRead}; -use crate::stream::Stream; -use bytes::{Buf, BufMut}; -use pin_project_lite::pin_project; -use std::io; -use std::mem::MaybeUninit; -use std::pin::Pin; -use std::task::{Context, Poll}; - -pin_project! { - /// Convert a stream of byte chunks into an [`AsyncRead`]. - /// - /// This type is usually created using the [`stream_reader`] function. - /// - /// [`AsyncRead`]: crate::io::AsyncRead - /// [`stream_reader`]: crate::io::stream_reader - #[derive(Debug)] - #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] - #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] - pub struct StreamReader<S, B> { - #[pin] - inner: S, - chunk: Option<B>, - } -} - -/// Convert a stream of byte chunks into an [`AsyncRead`](crate::io::AsyncRead). -/// -/// # Example -/// -/// ``` -/// use bytes::Bytes; -/// use tokio::io::{stream_reader, AsyncReadExt}; -/// # #[tokio::main] -/// # async fn main() -> std::io::Result<()> { -/// -/// // Create a stream from an iterator. -/// let stream = tokio::stream::iter(vec![ -/// Ok(Bytes::from_static(&[0, 1, 2, 3])), -/// Ok(Bytes::from_static(&[4, 5, 6, 7])), -/// Ok(Bytes::from_static(&[8, 9, 10, 11])), -/// ]); -/// -/// // Convert it to an AsyncRead. -/// let mut read = stream_reader(stream); -/// -/// // Read five bytes from the stream. -/// let mut buf = [0; 5]; -/// read.read_exact(&mut buf).await?; -/// assert_eq!(buf, [0, 1, 2, 3, 4]); -/// -/// // Read the rest of the current chunk. -/// assert_eq!(read.read(&mut buf).await?, 3); -/// assert_eq!(&buf[..3], [5, 6, 7]); -/// -/// // Read the next chunk. -/// assert_eq!(read.read(&mut buf).await?, 4); -/// assert_eq!(&buf[..4], [8, 9, 10, 11]); -/// -/// // We have now reached the end. -/// assert_eq!(read.read(&mut buf).await?, 0); -/// -/// # Ok(()) -/// # } -/// ``` -#[cfg_attr(docsrs, doc(cfg(feature = "stream")))] -#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] -pub fn stream_reader<S, B>(stream: S) -> StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - StreamReader::new(stream) -} - -impl<S, B> StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - /// Convert the provided stream into an `AsyncRead`. - fn new(stream: S) -> Self { - Self { - inner: stream, - chunk: None, - } - } - /// Do we have a chunk and is it non-empty? - fn has_chunk(self: Pin<&mut Self>) -> bool { - if let Some(chunk) = self.project().chunk { - chunk.remaining() > 0 - } else { - false - } - } -} - -impl<S, B> AsyncRead for StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - let inner_buf = match self.as_mut().poll_fill_buf(cx) { - Poll::Ready(Ok(buf)) => buf, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - let len = std::cmp::min(inner_buf.len(), buf.len()); - (&mut buf[..len]).copy_from_slice(&inner_buf[..len]); - - self.consume(len); - Poll::Ready(Ok(len)) - } - fn poll_read_buf<BM: BufMut>( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut BM, - ) -> Poll<io::Result<usize>> - where - Self: Sized, - { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let inner_buf = match self.as_mut().poll_fill_buf(cx) { - Poll::Ready(Ok(buf)) => buf, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - let len = std::cmp::min(inner_buf.len(), buf.remaining_mut()); - buf.put_slice(&inner_buf[..len]); - - self.consume(len); - Poll::Ready(Ok(len)) - } - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } -} - -impl<S, B> AsyncBufRead for StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { - loop { - if self.as_mut().has_chunk() { - // This unwrap is very sad, but it can't be avoided. - let buf = self.project().chunk.as_ref().unwrap().bytes(); - return Poll::Ready(Ok(buf)); - } else { - match self.as_mut().project().inner.poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - // Go around the loop in case the chunk is empty. - *self.as_mut().project().chunk = Some(chunk); - } - Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)), - Poll::Ready(None) => return Poll::Ready(Ok(&[])), - Poll::Pending => return Poll::Pending, - } - } - } - } - fn consume(self: Pin<&mut Self>, amt: usize) { - if amt > 0 { - self.project() - .chunk - .as_mut() - .expect("No chunk present") - .advance(amt); - } - } -} diff --git a/src/io/util/take.rs b/src/io/util/take.rs index 5d6bd90..b5e90c9 100644 --- a/src/io/util/take.rs +++ b/src/io/util/take.rs @@ -1,7 +1,6 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, io}; @@ -76,24 +75,27 @@ impl<R: AsyncRead> Take<R> { } impl<R: AsyncRead> AsyncRead for Take<R> { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<Result<usize, io::Error>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), io::Error>> { if self.limit_ == 0 { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } let me = self.project(); - let max = std::cmp::min(buf.len() as u64, *me.limit_) as usize; - let n = ready!(me.inner.poll_read(cx, &mut buf[..max]))?; + let mut b = buf.take(*me.limit_ as usize); + ready!(me.inner.poll_read(cx, &mut b))?; + let n = b.filled().len(); + + // We need to update the original ReadBuf + unsafe { + buf.assume_init(n); + } + buf.advance(n); *me.limit_ -= n as u64; - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } diff --git a/src/io/util/write.rs b/src/io/util/write.rs index 433a421..92169eb 100644 --- a/src/io/util/write.rs +++ b/src/io/util/write.rs @@ -1,17 +1,22 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future to write some of the buffer to an `AsyncWrite`. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Write<'a, W: ?Sized> { writer: &'a mut W, buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +26,11 @@ pub(crate) fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> Write<'a, W> where W: AsyncWrite + Unpin + ?Sized, { - Write { writer, buf } + Write { + writer, + buf, + _pin: PhantomPinned, + } } impl<W> Future for Write<'_, W> @@ -30,8 +39,8 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); Pin::new(&mut *me.writer).poll_write(cx, me.buf) } } diff --git a/src/io/util/write_all.rs b/src/io/util/write_all.rs index 898006c..e59d41e 100644 --- a/src/io/util/write_all.rs +++ b/src/io/util/write_all.rs @@ -1,17 +1,22 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WriteAll<'a, W: ?Sized> { writer: &'a mut W, buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -19,7 +24,11 @@ pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a, where W: AsyncWrite + Unpin + ?Sized, { - WriteAll { writer, buf } + WriteAll { + writer, + buf, + _pin: PhantomPinned, + } } impl<W> Future for WriteAll<'_, W> @@ -28,13 +37,13 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = self.project(); while !me.buf.is_empty() { - let n = ready!(Pin::new(&mut me.writer).poll_write(cx, me.buf))?; + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?; { - let (_, rest) = mem::replace(&mut me.buf, &[]).split_at(n); - me.buf = rest; + let (_, rest) = mem::replace(&mut *me.buf, &[]).split_at(n); + *me.buf = rest; } if n == 0 { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); @@ -44,14 +53,3 @@ where Poll::Ready(Ok(())) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<WriteAll<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/write_buf.rs b/src/io/util/write_buf.rs index cedfde6..1310e5c 100644 --- a/src/io/util/write_buf.rs +++ b/src/io/util/write_buf.rs @@ -1,18 +1,22 @@ use crate::io::AsyncWrite; use bytes::Buf; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future to write some of the buffer to an `AsyncWrite`. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WriteBuf<'a, W, B> { writer: &'a mut W, buf: &'a mut B, + #[pin] + _pin: PhantomPinned, } } @@ -23,7 +27,11 @@ where W: AsyncWrite + Unpin, B: Buf, { - WriteBuf { writer, buf } + WriteBuf { + writer, + buf, + _pin: PhantomPinned, + } } impl<W, B> Future for WriteBuf<'_, W, B> @@ -33,8 +41,15 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + + if !me.buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?; + me.buf.advance(n); + Poll::Ready(Ok(n)) } } diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs index ee992de..13bc191 100644 --- a/src/io/util/write_int.rs +++ b/src/io/util/write_int.rs @@ -4,6 +4,7 @@ use bytes::BufMut; use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem::size_of; use std::pin::Pin; use std::task::{Context, Poll}; @@ -15,20 +16,25 @@ macro_rules! writer { ($name:ident, $ty:ty, $writer:ident, $bytes:expr) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<W> { #[pin] dst: W, buf: [u8; $bytes], written: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<W> $name<W> { pub(crate) fn new(w: W, value: $ty) -> Self { - let mut writer = $name { + let mut writer = Self { buf: [0; $bytes], written: 0, dst: w, + _pin: PhantomPinned, }; BufMut::$writer(&mut &mut writer.buf[..], value); writer @@ -72,16 +78,24 @@ macro_rules! writer8 { ($name:ident, $ty:ty) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<W> { #[pin] dst: W, byte: $ty, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<W> $name<W> { pub(crate) fn new(dst: W, byte: $ty) -> Self { - Self { dst, byte } + Self { + dst, + byte, + _pin: PhantomPinned, + } } } @@ -1,4 +1,4 @@ -#![doc(html_root_url = "https://docs.rs/tokio/0.2.22")] +#![doc(html_root_url = "https://docs.rs/tokio/0.3.1")] #![allow( clippy::cognitive_complexity, clippy::large_enum_variant, @@ -10,22 +10,21 @@ rust_2018_idioms, unreachable_pub )] -#![deny(intra_doc_link_resolution_failure)] +#![cfg_attr(docsrs, deny(broken_intra_doc_links))] #![doc(test( no_crate_inject, attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) ))] #![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(docsrs, feature(doc_alias))] -//! A runtime for writing reliable, asynchronous, and slim applications. +//! A runtime for writing reliable network applications without compromising speed. //! //! Tokio is an event-driven, non-blocking I/O platform for writing asynchronous //! applications with the Rust programming language. At a high level, it //! provides a few major components: //! //! * Tools for [working with asynchronous tasks][tasks], including -//! [synchronization primitives and channels][sync] and [timeouts, delays, and +//! [synchronization primitives and channels][sync] and [timeouts, sleeps, and //! intervals][time]. //! * APIs for [performing asynchronous I/O][io], including [TCP and UDP][net] sockets, //! [filesystem][fs] operations, and [process] and [signal] management. @@ -45,7 +44,7 @@ //! [signal]: crate::signal //! [fs]: crate::fs //! [runtime]: crate::runtime -//! [website]: https://tokio.rs/docs/overview/ +//! [website]: https://tokio.rs/tokio/tutorial //! //! # A Tour of Tokio //! @@ -58,59 +57,9 @@ //! enabling the `full` feature flag: //! //! ```toml -//! tokio = { version = "0.2", features = ["full"] } +//! tokio = { version = "0.3", features = ["full"] } //! ``` //! -//! ## Feature flags -//! -//! Tokio uses a set of [feature flags] to reduce the amount of compiled code. It -//! is possible to just enable certain features over others. By default, Tokio -//! does not enable any features but allows one to enable a subset for their use -//! case. Below is a list of the available feature flags. You may also notice -//! above each function, struct and trait there is listed one or more feature flags -//! that are required for that item to be used. If you are new to Tokio it is -//! recommended that you use the `full` feature flag which will enable all public APIs. -//! Beware though that this will pull in many extra dependencies that you may not -//! need. -//! -//! - `full`: Enables all Tokio public API features listed below. -//! - `rt-core`: Enables `tokio::spawn` and the basic (single-threaded) scheduler. -//! - `rt-threaded`: Enables the heavier, multi-threaded, work-stealing scheduler. -//! - `rt-util`: Enables non-scheduler utilities. -//! - `io-driver`: Enables the `mio` based IO driver. -//! - `io-util`: Enables the IO based `Ext` traits. -//! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types. -//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`. -//! - `tcp`: Enables all `tokio::net::tcp` types. -//! - `udp`: Enables all `tokio::net::udp` types. -//! - `uds`: Enables all `tokio::net::unix` types. -//! - `time`: Enables `tokio::time` types and allows the schedulers to enable -//! the built in timer. -//! - `process`: Enables `tokio::process` types. -//! - `macros`: Enables `#[tokio::main]` and `#[tokio::test]` macros. -//! - `sync`: Enables all `tokio::sync` types. -//! - `stream`: Enables optional `Stream` implementations for types within Tokio. -//! - `signal`: Enables all `tokio::signal` types. -//! - `fs`: Enables `tokio::fs` types. -//! - `dns`: Enables async `tokio::net::ToSocketAddrs`. -//! - `test-util`: Enables testing based infrastructure for the Tokio runtime. -//! - `blocking`: Enables `block_in_place` and `spawn_blocking`. -//! -//! _Note: `AsyncRead` and `AsyncWrite` traits do not require any features and are -//! always available._ -//! -//! ### Internal features -//! -//! These features do not expose any new API, but influence internal -//! implementation aspects of Tokio, and can pull in additional -//! dependencies. They are not included in `full`: -//! -//! - `parking_lot`: As a potential optimization, use the _parking_lot_ crate's -//! synchronization primitives internally. MSRV may increase according to the -//! _parking_lot_ release in use. -//! -//! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section -//! //! ### Authoring applications //! //! Tokio is great for writing applications and most users in this case shouldn't @@ -123,7 +72,7 @@ //! This example shows the quickest way to get started with Tokio. //! //! ```toml -//! tokio = { version = "0.2", features = ["full"] } +//! tokio = { version = "0.3", features = ["full"] } //! ``` //! //! ### Authoring libraries @@ -139,7 +88,7 @@ //! needs to `tokio::spawn` and use a `TcpStream`. //! //! ```toml -//! tokio = { version = "0.2", features = ["rt-core", "tcp"] } +//! tokio = { version = "0.3", features = ["rt", "net"] } //! ``` //! //! ## Working With Tasks @@ -153,7 +102,7 @@ //! * Functions for [running blocking operations][blocking] in an asynchronous //! task context. //! -//! The [`tokio::task`] module is present only when the "rt-core" feature flag +//! The [`tokio::task`] module is present only when the "rt" feature flag //! is enabled. //! //! [tasks]: task/index.html#what-are-tasks @@ -184,25 +133,26 @@ //! //! The [`tokio::time`] module provides utilities for tracking time and //! scheduling work. This includes functions for setting [timeouts][timeout] for -//! tasks, [delaying][delay] work to run in the future, or [repeating an operation at an +//! tasks, [sleeping][sleep] work to run in the future, or [repeating an operation at an //! interval][interval]. //! //! In order to use `tokio::time`, the "time" feature flag must be enabled. //! //! [`tokio::time`]: crate::time -//! [delay]: crate::time::delay_for() +//! [sleep]: crate::time::sleep() //! [interval]: crate::time::interval() //! [timeout]: crate::time::timeout() //! //! Finally, Tokio provides a _runtime_ for executing asynchronous tasks. Most //! applications can use the [`#[tokio::main]`][main] macro to run their code on the -//! Tokio runtime. In use-cases where manual control over the runtime is -//! required, the [`tokio::runtime`] module provides APIs for configuring and -//! managing runtimes. -//! -//! Using the runtime requires the "rt-core" or "rt-threaded" feature flags, to -//! enable the basic [single-threaded scheduler][rt-core] and the [thread-pool -//! scheduler][rt-threaded], respectively. See the [`runtime` module +//! Tokio runtime. However, this macro provides only basic configuration options. As +//! an alternative, the [`tokio::runtime`] module provides more powerful APIs for configuring +//! and managing runtimes. You should use that module if the `#[tokio::main]` macro doesn't +//! provide the functionality you need. +//! +//! Using the runtime requires the "rt" or "rt-multi-thread" feature flags, to +//! enable the basic [single-threaded scheduler][rt] and the [thread-pool +//! scheduler][rt-multi-thread], respectively. See the [`runtime` module //! documentation][rt-features] for details. In addition, the "macros" feature //! flag enables the `#[tokio::main]` and `#[tokio::test]` attributes. //! @@ -210,8 +160,8 @@ //! [`tokio::runtime`]: crate::runtime //! [`Builder`]: crate::runtime::Builder //! [`Runtime`]: crate::runtime::Runtime -//! [rt-core]: runtime/index.html#basic-scheduler -//! [rt-threaded]: runtime/index.html#threaded-scheduler +//! [rt]: runtime/index.html#basic-scheduler +//! [rt-multi-thread]: runtime/index.html#threaded-scheduler //! [rt-features]: runtime/index.html#runtime-scheduler //! //! ## CPU-bound tasks and blocking code @@ -268,8 +218,7 @@ //! the [`AsyncRead`], [`AsyncWrite`], and [`AsyncBufRead`] traits. In addition, //! when the "io-util" feature flag is enabled, it also provides combinators and //! functions for working with these traits, forming as an asynchronous -//! counterpart to [`std::io`]. When the "io-driver" feature flag is enabled, it -//! also provides utilities for library authors implementing I/O resources. +//! counterpart to [`std::io`]. //! //! Tokio also includes APIs for performing various kinds of I/O and interacting //! with the operating system asynchronously. These include: @@ -307,7 +256,7 @@ //! //! #[tokio::main] //! async fn main() -> Result<(), Box<dyn std::error::Error>> { -//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?; +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; //! //! loop { //! let (mut socket, _) = listener.accept().await?; @@ -337,6 +286,50 @@ //! } //! } //! ``` +//! +//! ## Feature flags +//! +//! Tokio uses a set of [feature flags] to reduce the amount of compiled code. It +//! is possible to just enable certain features over others. By default, Tokio +//! does not enable any features but allows one to enable a subset for their use +//! case. Below is a list of the available feature flags. You may also notice +//! above each function, struct and trait there is listed one or more feature flags +//! that are required for that item to be used. If you are new to Tokio it is +//! recommended that you use the `full` feature flag which will enable all public APIs. +//! Beware though that this will pull in many extra dependencies that you may not +//! need. +//! +//! - `full`: Enables all Tokio public API features listed below. +//! - `rt`: Enables `tokio::spawn`, the basic (current thread) scheduler, +//! and non-scheduler utilities. +//! - `rt-multi-thread`: Enables the heavier, multi-threaded, work-stealing scheduler. +//! - `io-util`: Enables the IO based `Ext` traits. +//! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types. +//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`. +//! - `time`: Enables `tokio::time` types and allows the schedulers to enable +//! the built in timer. +//! - `process`: Enables `tokio::process` types. +//! - `macros`: Enables `#[tokio::main]` and `#[tokio::test]` macros. +//! - `sync`: Enables all `tokio::sync` types. +//! - `stream`: Enables optional `Stream` implementations for types within Tokio. +//! - `signal`: Enables all `tokio::signal` types. +//! - `fs`: Enables `tokio::fs` types. +//! - `test-util`: Enables testing based infrastructure for the Tokio runtime. +//! +//! _Note: `AsyncRead` and `AsyncWrite` traits do not require any features and are +//! always available._ +//! +//! ### Internal features +//! +//! These features do not expose any new API, but influence internal +//! implementation aspects of Tokio, and can pull in additional +//! dependencies. +//! +//! - `parking_lot`: As a potential optimization, use the _parking_lot_ crate's +//! synchronization primitives internally. MSRV may increase according to the +//! _parking_lot_ release in use. +//! +//! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section // Includes re-exports used by macros. // @@ -350,8 +343,7 @@ cfg_fs! { pub mod fs; } -#[doc(hidden)] -pub mod future; +mod future; pub mod io; pub mod net; @@ -365,7 +357,12 @@ cfg_process! { pub mod process; } -pub mod runtime; +#[cfg(any(feature = "net", feature = "fs", feature = "io-std"))] +mod blocking; + +cfg_rt! { + pub mod runtime; +} pub(crate) mod coop; @@ -373,6 +370,13 @@ cfg_signal! { pub mod signal; } +cfg_signal_internal! { + #[cfg(not(feature = "signal"))] + #[allow(dead_code)] + #[allow(unreachable_pub)] + pub(crate) mod signal; +} + cfg_stream! { pub mod stream; } @@ -384,8 +388,8 @@ cfg_not_sync! { mod sync; } -cfg_rt_core! { - pub mod task; +pub mod task; +cfg_rt! { pub use task::spawn; } @@ -402,31 +406,31 @@ cfg_macros! { #[doc(hidden)] pub use tokio_macros::select_priv_declare_output_enum; - doc_rt_core! { - cfg_rt_threaded! { + cfg_rt! { + cfg_rt_multi_thread! { // This is the docs.rs case (with all features) so make sure macros // is included in doc(cfg). #[cfg(not(test))] // Work around for rust-lang/rust#62127 #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] - pub use tokio_macros::main_threaded as main; + pub use tokio_macros::main; #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] - pub use tokio_macros::test_threaded as test; + pub use tokio_macros::test; } - cfg_not_rt_threaded! { + cfg_not_rt_multi_thread! { #[cfg(not(test))] // Work around for rust-lang/rust#62127 - pub use tokio_macros::main_basic as main; - pub use tokio_macros::test_basic as test; + pub use tokio_macros::main_rt as main; + pub use tokio_macros::test_rt as test; } } - // Maintains old behavior - cfg_not_rt_core! { + // Always fail if rt is not enabled. + cfg_not_rt! { #[cfg(not(test))] - pub use tokio_macros::main; - pub use tokio_macros::test; + pub use tokio_macros::main_fail as main; + pub use tokio_macros::test_fail as test; } } diff --git a/src/loom/mocked.rs b/src/loom/mocked.rs index 7891395..367d59b 100644 --- a/src/loom/mocked.rs +++ b/src/loom/mocked.rs @@ -1,5 +1,32 @@ pub(crate) use loom::*; +pub(crate) mod sync { + + pub(crate) use loom::sync::MutexGuard; + + #[derive(Debug)] + pub(crate) struct Mutex<T>(loom::sync::Mutex<T>); + + #[allow(dead_code)] + impl<T> Mutex<T> { + #[inline] + pub(crate) fn new(t: T) -> Mutex<T> { + Mutex(loom::sync::Mutex::new(t)) + } + + #[inline] + pub(crate) fn lock(&self) -> MutexGuard<'_, T> { + self.0.lock().unwrap() + } + + #[inline] + pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> { + self.0.try_lock().ok() + } + } + pub(crate) use loom::sync::*; +} + pub(crate) mod rand { pub(crate) fn seed() -> u64 { 1 diff --git a/src/loom/mod.rs b/src/loom/mod.rs index 56a41f2..5957b53 100644 --- a/src/loom/mod.rs +++ b/src/loom/mod.rs @@ -1,6 +1,8 @@ //! This module abstracts over `loom` and `std::sync` depending on whether we //! are running tests or not. +#![allow(unused)] + #[cfg(not(all(test, loom)))] mod std; #[cfg(not(all(test, loom)))] diff --git a/src/loom/std/atomic_ptr.rs b/src/loom/std/atomic_ptr.rs index f7fd56c..236645f 100644 --- a/src/loom/std/atomic_ptr.rs +++ b/src/loom/std/atomic_ptr.rs @@ -1,5 +1,5 @@ use std::fmt; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; /// `AtomicPtr` providing an additional `load_unsync` function. pub(crate) struct AtomicPtr<T> { @@ -21,6 +21,12 @@ impl<T> Deref for AtomicPtr<T> { } } +impl<T> DerefMut for AtomicPtr<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + impl<T> fmt::Debug for AtomicPtr<T> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { self.deref().fmt(fmt) diff --git a/src/loom/std/atomic_u16.rs b/src/loom/std/atomic_u16.rs index 7039097..c1c5312 100644 --- a/src/loom/std/atomic_u16.rs +++ b/src/loom/std/atomic_u16.rs @@ -11,7 +11,7 @@ unsafe impl Send for AtomicU16 {} unsafe impl Sync for AtomicU16 {} impl AtomicU16 { - pub(crate) fn new(val: u16) -> AtomicU16 { + pub(crate) const fn new(val: u16) -> AtomicU16 { let inner = UnsafeCell::new(std::sync::atomic::AtomicU16::new(val)); AtomicU16 { inner } } diff --git a/src/loom/std/atomic_u32.rs b/src/loom/std/atomic_u32.rs index 6f786c5..61f95fb 100644 --- a/src/loom/std/atomic_u32.rs +++ b/src/loom/std/atomic_u32.rs @@ -11,7 +11,7 @@ unsafe impl Send for AtomicU32 {} unsafe impl Sync for AtomicU32 {} impl AtomicU32 { - pub(crate) fn new(val: u32) -> AtomicU32 { + pub(crate) const fn new(val: u32) -> AtomicU32 { let inner = UnsafeCell::new(std::sync::atomic::AtomicU32::new(val)); AtomicU32 { inner } } diff --git a/src/loom/std/atomic_u8.rs b/src/loom/std/atomic_u8.rs index 4fcd0df..408aea3 100644 --- a/src/loom/std/atomic_u8.rs +++ b/src/loom/std/atomic_u8.rs @@ -11,7 +11,7 @@ unsafe impl Send for AtomicU8 {} unsafe impl Sync for AtomicU8 {} impl AtomicU8 { - pub(crate) fn new(val: u8) -> AtomicU8 { + pub(crate) const fn new(val: u8) -> AtomicU8 { let inner = UnsafeCell::new(std::sync::atomic::AtomicU8::new(val)); AtomicU8 { inner } } diff --git a/src/loom/std/atomic_usize.rs b/src/loom/std/atomic_usize.rs index 0fe998f..0d5f36e 100644 --- a/src/loom/std/atomic_usize.rs +++ b/src/loom/std/atomic_usize.rs @@ -11,7 +11,7 @@ unsafe impl Send for AtomicUsize {} unsafe impl Sync for AtomicUsize {} impl AtomicUsize { - pub(crate) fn new(val: usize) -> AtomicUsize { + pub(crate) const fn new(val: usize) -> AtomicUsize { let inner = UnsafeCell::new(std::sync::atomic::AtomicUsize::new(val)); AtomicUsize { inner } } diff --git a/src/loom/std/mod.rs b/src/loom/std/mod.rs index 60ee56a..9525286 100644 --- a/src/loom/std/mod.rs +++ b/src/loom/std/mod.rs @@ -6,6 +6,7 @@ mod atomic_u32; mod atomic_u64; mod atomic_u8; mod atomic_usize; +mod mutex; #[cfg(feature = "parking_lot")] mod parking_lot; mod unsafe_cell; @@ -14,7 +15,12 @@ pub(crate) mod cell { pub(crate) use super::unsafe_cell::UnsafeCell; } -#[cfg(any(feature = "sync", feature = "io-driver"))] +#[cfg(any( + feature = "net", + feature = "process", + feature = "signal", + feature = "sync", +))] pub(crate) mod future { pub(crate) use crate::sync::AtomicWaker; } @@ -55,9 +61,10 @@ pub(crate) mod sync { #[cfg(not(feature = "parking_lot"))] #[allow(unused_imports)] - pub(crate) use std::sync::{ - Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, - }; + pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult}; + + #[cfg(not(feature = "parking_lot"))] + pub(crate) use crate::loom::std::mutex::Mutex; pub(crate) mod atomic { pub(crate) use crate::loom::std::atomic_ptr::AtomicPtr; @@ -72,12 +79,12 @@ pub(crate) mod sync { } pub(crate) mod sys { - #[cfg(feature = "rt-threaded")] + #[cfg(feature = "rt-multi-thread")] pub(crate) fn num_cpus() -> usize { usize::max(1, num_cpus::get()) } - #[cfg(not(feature = "rt-threaded"))] + #[cfg(not(feature = "rt-multi-thread"))] pub(crate) fn num_cpus() -> usize { 1 } diff --git a/src/loom/std/mutex.rs b/src/loom/std/mutex.rs new file mode 100644 index 0000000..bf14d62 --- /dev/null +++ b/src/loom/std/mutex.rs @@ -0,0 +1,31 @@ +use std::sync::{self, MutexGuard, TryLockError}; + +/// Adapter for `std::Mutex` that removes the poisoning aspects +// from its api +#[derive(Debug)] +pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>); + +#[allow(dead_code)] +impl<T> Mutex<T> { + #[inline] + pub(crate) fn new(t: T) -> Mutex<T> { + Mutex(sync::Mutex::new(t)) + } + + #[inline] + pub(crate) fn lock(&self) -> MutexGuard<'_, T> { + match self.0.lock() { + Ok(guard) => guard, + Err(p_err) => p_err.into_inner(), + } + } + + #[inline] + pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> { + match self.0.try_lock() { + Ok(guard) => Some(guard), + Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()), + Err(TryLockError::WouldBlock) => None, + } + } +} diff --git a/src/loom/std/parking_lot.rs b/src/loom/std/parking_lot.rs index 25d94af..c03190f 100644 --- a/src/loom/std/parking_lot.rs +++ b/src/loom/std/parking_lot.rs @@ -3,7 +3,7 @@ //! //! This can be extended to additional types/methods as required. -use std::sync::{LockResult, TryLockError, TryLockResult}; +use std::sync::LockResult; use std::time::Duration; // Types that do not need wrapping @@ -27,16 +27,20 @@ impl<T> Mutex<T> { } #[inline] - pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> { - Ok(self.0.lock()) + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "parking_lot",))))] + pub(crate) const fn const_new(t: T) -> Mutex<T> { + Mutex(parking_lot::const_mutex(t)) } #[inline] - pub(crate) fn try_lock(&self) -> TryLockResult<MutexGuard<'_, T>> { - match self.0.try_lock() { - Some(guard) => Ok(guard), - None => Err(TryLockError::WouldBlock), - } + pub(crate) fn lock(&self) -> MutexGuard<'_, T> { + self.0.lock() + } + + #[inline] + pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> { + self.0.try_lock() } // Note: Additional methods `is_poisoned` and `into_inner`, can be diff --git a/src/loom/std/unsafe_cell.rs b/src/loom/std/unsafe_cell.rs index f2b03d8..66c1d79 100644 --- a/src/loom/std/unsafe_cell.rs +++ b/src/loom/std/unsafe_cell.rs @@ -2,7 +2,7 @@ pub(crate) struct UnsafeCell<T>(std::cell::UnsafeCell<T>); impl<T> UnsafeCell<T> { - pub(crate) fn new(data: T) -> UnsafeCell<T> { + pub(crate) const fn new(data: T) -> UnsafeCell<T> { UnsafeCell(std::cell::UnsafeCell::new(data)) } diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs index 4b77544..2792911 100644 --- a/src/macros/cfg.rs +++ b/src/macros/cfg.rs @@ -1,97 +1,30 @@ #![allow(unused_macros)] -macro_rules! cfg_resource_drivers { - ($($item:item)*) => { - $( - #[cfg(any(feature = "io-driver", feature = "time"))] - $item - )* - } -} - -macro_rules! cfg_blocking { - ($($item:item)*) => { - $( - #[cfg(feature = "blocking")] - #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))] - $item - )* - } -} - -/// Enables blocking API internals -macro_rules! cfg_blocking_impl { - ($($item:item)*) => { - $( - #[cfg(any( - feature = "blocking", - feature = "fs", - feature = "dns", - feature = "io-std", - feature = "rt-threaded", - ))] - $item - )* - } -} - -/// Enables blocking API internals -macro_rules! cfg_blocking_impl_or_task { - ($($item:item)*) => { - $( - #[cfg(any( - feature = "blocking", - feature = "fs", - feature = "dns", - feature = "io-std", - feature = "rt-threaded", - feature = "task", - ))] - $item - )* - } -} - /// Enables enter::block_on macro_rules! cfg_block_on { ($($item:item)*) => { $( #[cfg(any( - feature = "blocking", feature = "fs", - feature = "dns", + feature = "net", feature = "io-std", - feature = "rt-core", + feature = "rt", ))] $item )* } } -/// Enables blocking API internals -macro_rules! cfg_not_blocking_impl { - ($($item:item)*) => { - $( - #[cfg(not(any( - feature = "blocking", - feature = "fs", - feature = "dns", - feature = "io-std", - feature = "rt-threaded", - )))] - $item - )* - } -} - /// Enables internal `AtomicWaker` impl macro_rules! cfg_atomic_waker_impl { ($($item:item)*) => { $( #[cfg(any( - feature = "io-driver", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", feature = "time", - all(feature = "rt-core", feature = "rt-util") ))] #[cfg(not(loom))] $item @@ -99,16 +32,6 @@ macro_rules! cfg_atomic_waker_impl { } } -macro_rules! cfg_dns { - ($($item:item)*) => { - $( - #[cfg(feature = "dns")] - #[cfg_attr(docsrs, doc(cfg(feature = "dns")))] - $item - )* - } -} - macro_rules! cfg_fs { ($($item:item)*) => { $( @@ -128,8 +51,16 @@ macro_rules! cfg_io_blocking { macro_rules! cfg_io_driver { ($($item:item)*) => { $( - #[cfg(feature = "io-driver")] - #[cfg_attr(docsrs, doc(cfg(feature = "io-driver")))] + #[cfg(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + ))] + #[cfg_attr(docsrs, doc(cfg(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + ))))] $item )* } @@ -138,7 +69,20 @@ macro_rules! cfg_io_driver { macro_rules! cfg_not_io_driver { ($($item:item)*) => { $( - #[cfg(not(feature = "io-driver"))] + #[cfg(not(any( + feature = "net", + feature = "process", + all(unix, feature = "signal"), + )))] + $item + )* + } +} + +macro_rules! cfg_io_readiness { + ($($item:item)*) => { + $( + #[cfg(feature = "net")] $item )* } @@ -193,115 +137,142 @@ macro_rules! cfg_macros { } } -macro_rules! cfg_process { +macro_rules! cfg_net { ($($item:item)*) => { $( - #[cfg(feature = "process")] - #[cfg_attr(docsrs, doc(cfg(feature = "process")))] - #[cfg(not(loom))] + #[cfg(feature = "net")] + #[cfg_attr(docsrs, doc(cfg(feature = "net")))] $item )* } } -macro_rules! cfg_signal { +macro_rules! cfg_net_unix { ($($item:item)*) => { $( - #[cfg(feature = "signal")] - #[cfg_attr(docsrs, doc(cfg(feature = "signal")))] - #[cfg(not(loom))] + #[cfg(all(unix, feature = "net"))] + #[cfg_attr(docsrs, doc(cfg(feature = "net")))] $item )* } } -macro_rules! cfg_stream { +macro_rules! cfg_process { ($($item:item)*) => { $( - #[cfg(feature = "stream")] - #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] + #[cfg(feature = "process")] + #[cfg_attr(docsrs, doc(cfg(feature = "process")))] + #[cfg(not(loom))] $item )* } } -macro_rules! cfg_sync { +macro_rules! cfg_process_driver { + ($($item:item)*) => { + #[cfg(unix)] + #[cfg(not(loom))] + cfg_process! { $($item)* } + } +} + +macro_rules! cfg_not_process_driver { ($($item:item)*) => { $( - #[cfg(feature = "sync")] - #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] + #[cfg(not(all(unix, not(loom), feature = "process")))] $item )* } } -macro_rules! cfg_not_sync { +macro_rules! cfg_signal { ($($item:item)*) => { - $( #[cfg(not(feature = "sync"))] $item )* + $( + #[cfg(feature = "signal")] + #[cfg_attr(docsrs, doc(cfg(feature = "signal")))] + #[cfg(not(loom))] + $item + )* } } -macro_rules! cfg_rt_core { +macro_rules! cfg_signal_internal { ($($item:item)*) => { $( - #[cfg(feature = "rt-core")] + #[cfg(any(feature = "signal", all(unix, feature = "process")))] + #[cfg(not(loom))] $item )* } } -macro_rules! doc_rt_core { +macro_rules! cfg_not_signal_internal { ($($item:item)*) => { $( - #[cfg(feature = "rt-core")] - #[cfg_attr(docsrs, doc(cfg(feature = "rt-core")))] + #[cfg(any(loom, not(unix), not(any(feature = "signal", all(unix, feature = "process")))))] $item )* } } -macro_rules! cfg_not_rt_core { +macro_rules! cfg_stream { ($($item:item)*) => { - $( #[cfg(not(feature = "rt-core"))] $item )* + $( + #[cfg(feature = "stream")] + #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] + $item + )* } } -macro_rules! cfg_rt_threaded { +macro_rules! cfg_sync { ($($item:item)*) => { $( - #[cfg(feature = "rt-threaded")] - #[cfg_attr(docsrs, doc(cfg(feature = "rt-threaded")))] + #[cfg(feature = "sync")] + #[cfg_attr(docsrs, doc(cfg(feature = "sync")))] $item )* } } -macro_rules! cfg_rt_util { +macro_rules! cfg_not_sync { + ($($item:item)*) => { + $( #[cfg(not(feature = "sync"))] $item )* + } +} + +macro_rules! cfg_rt { ($($item:item)*) => { $( - #[cfg(feature = "rt-util")] - #[cfg_attr(docsrs, doc(cfg(feature = "rt-util")))] + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] $item )* } } -macro_rules! cfg_not_rt_threaded { +macro_rules! cfg_not_rt { ($($item:item)*) => { - $( #[cfg(not(feature = "rt-threaded"))] $item )* + $( #[cfg(not(feature = "rt"))] $item )* } } -macro_rules! cfg_tcp { +macro_rules! cfg_rt_multi_thread { ($($item:item)*) => { $( - #[cfg(feature = "tcp")] - #[cfg_attr(docsrs, doc(cfg(feature = "tcp")))] + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] $item )* } } +macro_rules! cfg_not_rt_multi_thread { + ($($item:item)*) => { + $( #[cfg(not(feature = "rt-multi-thread"))] $item )* + } +} + macro_rules! cfg_test_util { ($($item:item)*) => { $( @@ -334,36 +305,6 @@ macro_rules! cfg_not_time { } } -macro_rules! cfg_udp { - ($($item:item)*) => { - $( - #[cfg(feature = "udp")] - #[cfg_attr(docsrs, doc(cfg(feature = "udp")))] - $item - )* - } -} - -macro_rules! cfg_uds { - ($($item:item)*) => { - $( - #[cfg(all(unix, feature = "uds"))] - #[cfg_attr(docsrs, doc(cfg(feature = "uds")))] - $item - )* - } -} - -macro_rules! cfg_unstable { - ($($item:item)*) => { - $( - #[cfg(tokio_unstable)] - #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - $item - )* - } -} - macro_rules! cfg_trace { ($($item:item)*) => { $( @@ -387,16 +328,15 @@ macro_rules! cfg_coop { ($($item:item)*) => { $( #[cfg(any( - feature = "blocking", - feature = "dns", feature = "fs", - feature = "io-driver", feature = "io-std", + feature = "net", feature = "process", - feature = "rt-core", + feature = "rt", + feature = "signal", feature = "sync", feature = "stream", - feature = "time" + feature = "time", ))] $item )* diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 2643c36..b0af521 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -16,7 +16,7 @@ mod ready; mod thread_local; #[macro_use] -#[cfg(feature = "rt-core")] +#[cfg(feature = "rt")] pub(crate) mod scoped_tls; cfg_macros! { diff --git a/src/macros/scoped_tls.rs b/src/macros/scoped_tls.rs index 886f9d4..a00aae2 100644 --- a/src/macros/scoped_tls.rs +++ b/src/macros/scoped_tls.rs @@ -23,9 +23,7 @@ macro_rules! scoped_thread_local { /// Type representing a thread local storage key corresponding to a reference /// to the type parameter `T`. pub(crate) struct ScopedKey<T> { - #[doc(hidden)] pub(crate) inner: &'static LocalKey<Cell<*const ()>>, - #[doc(hidden)] pub(crate) _marker: marker::PhantomData<T>, } diff --git a/src/macros/select.rs b/src/macros/select.rs index 52c8fdd..b63abdd 100644 --- a/src/macros/select.rs +++ b/src/macros/select.rs @@ -63,9 +63,9 @@ /// Given that `if` preconditions are used to disable `select!` branches, some /// caution must be used to avoid missing values. /// -/// For example, here is **incorrect** usage of `delay` with `if`. The objective +/// For example, here is **incorrect** usage of `sleep` with `if`. The objective /// is to repeatedly run an asynchronous task for up to 50 milliseconds. -/// However, there is a potential for the `delay` completion to be missed. +/// However, there is a potential for the `sleep` completion to be missed. /// /// ```no_run /// use tokio::time::{self, Duration}; @@ -76,11 +76,11 @@ /// /// #[tokio::main] /// async fn main() { -/// let mut delay = time::delay_for(Duration::from_millis(50)); +/// let mut sleep = time::sleep(Duration::from_millis(50)); /// -/// while !delay.is_elapsed() { +/// while !sleep.is_elapsed() { /// tokio::select! { -/// _ = &mut delay, if !delay.is_elapsed() => { +/// _ = &mut sleep, if !sleep.is_elapsed() => { /// println!("operation timed out"); /// } /// _ = some_async_work() => { @@ -91,11 +91,11 @@ /// } /// ``` /// -/// In the above example, `delay.is_elapsed()` may return `true` even if -/// `delay.poll()` never returned `Ready`. This opens up a potential race -/// condition where `delay` expires between the `while !delay.is_elapsed()` +/// In the above example, `sleep.is_elapsed()` may return `true` even if +/// `sleep.poll()` never returned `Ready`. This opens up a potential race +/// condition where `sleep` expires between the `while !sleep.is_elapsed()` /// check and the call to `select!` resulting in the `some_async_work()` call to -/// run uninterrupted despite the delay having elapsed. +/// run uninterrupted despite the sleep having elapsed. /// /// One way to write the above example without the race would be: /// @@ -103,17 +103,17 @@ /// use tokio::time::{self, Duration}; /// /// async fn some_async_work() { -/// # time::delay_for(Duration::from_millis(10)).await; +/// # time::sleep(Duration::from_millis(10)).await; /// // do work /// } /// /// #[tokio::main] /// async fn main() { -/// let mut delay = time::delay_for(Duration::from_millis(50)); +/// let mut sleep = time::sleep(Duration::from_millis(50)); /// /// loop { /// tokio::select! { -/// _ = &mut delay => { +/// _ = &mut sleep => { /// println!("operation timed out"); /// break; /// } @@ -226,7 +226,7 @@ /// #[tokio::main] /// async fn main() { /// let mut stream = stream::iter(vec![1, 2, 3]); -/// let mut delay = time::delay_for(Duration::from_secs(1)); +/// let mut sleep = time::sleep(Duration::from_secs(1)); /// /// loop { /// tokio::select! { @@ -237,7 +237,7 @@ /// break; /// } /// } -/// _ = &mut delay => { +/// _ = &mut sleep => { /// println!("timeout"); /// break; /// } @@ -366,6 +366,7 @@ macro_rules! select { } match branch { $( + #[allow(unreachable_code)] $crate::count!( $($skip)* ) => { // First, if the future has previously been // disabled, do not poll it again. This is done @@ -403,6 +404,7 @@ macro_rules! select { // The future returned a value, check if matches // the specified pattern. #[allow(unused_variables)] + #[allow(unused_mut)] match &out { $bind => {} _ => continue, diff --git a/src/macros/support.rs b/src/macros/support.rs index fc1cdfc..7f11bc6 100644 --- a/src/macros/support.rs +++ b/src/macros/support.rs @@ -1,5 +1,6 @@ cfg_macros! { - pub use crate::future::{maybe_done, poll_fn}; + pub use crate::future::poll_fn; + pub use crate::future::maybe_done::maybe_done; pub use crate::util::thread_rng_n; } diff --git a/src/net/addr.rs b/src/net/addr.rs index 5ba898a..7cbe531 100644 --- a/src/net/addr.rs +++ b/src/net/addr.rs @@ -9,7 +9,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV /// /// Implementations of `ToSocketAddrs` for string types require a DNS lookup. /// These implementations are only provided when Tokio is used with the -/// **`dns`** feature flag. +/// **`net`** feature flag. /// /// # Calling /// @@ -23,6 +23,15 @@ pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {} type ReadyFuture<T> = future::Ready<io::Result<T>>; +cfg_net! { + pub(crate) fn to_socket_addrs<T>(arg: T) -> T::Future + where + T: ToSocketAddrs, + { + arg.to_socket_addrs(sealed::Internal) + } +} + // ===== impl &impl ToSocketAddrs ===== impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {} @@ -34,8 +43,8 @@ where type Iter = T::Iter; type Future = T::Future; - fn to_socket_addrs(&self) -> Self::Future { - (**self).to_socket_addrs() + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + (**self).to_socket_addrs(sealed::Internal) } } @@ -47,7 +56,7 @@ impl sealed::ToSocketAddrsPriv for SocketAddr { type Iter = std::option::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { let iter = Some(*self).into_iter(); future::ok(iter) } @@ -61,8 +70,8 @@ impl sealed::ToSocketAddrsPriv for SocketAddrV4 { type Iter = std::option::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { - SocketAddr::V4(*self).to_socket_addrs() + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + SocketAddr::V4(*self).to_socket_addrs(sealed::Internal) } } @@ -74,8 +83,8 @@ impl sealed::ToSocketAddrsPriv for SocketAddrV6 { type Iter = std::option::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { - SocketAddr::V6(*self).to_socket_addrs() + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + SocketAddr::V6(*self).to_socket_addrs(sealed::Internal) } } @@ -87,7 +96,7 @@ impl sealed::ToSocketAddrsPriv for (IpAddr, u16) { type Iter = std::option::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { let iter = Some(SocketAddr::from(*self)).into_iter(); future::ok(iter) } @@ -101,9 +110,9 @@ impl sealed::ToSocketAddrsPriv for (Ipv4Addr, u16) { type Iter = std::option::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { let (ip, port) = *self; - SocketAddrV4::new(ip, port).to_socket_addrs() + SocketAddrV4::new(ip, port).to_socket_addrs(sealed::Internal) } } @@ -115,9 +124,9 @@ impl sealed::ToSocketAddrsPriv for (Ipv6Addr, u16) { type Iter = std::option::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { let (ip, port) = *self; - SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs() + SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs(sealed::Internal) } } @@ -129,13 +138,13 @@ impl sealed::ToSocketAddrsPriv for &[SocketAddr] { type Iter = std::vec::IntoIter<SocketAddr>; type Future = ReadyFuture<Self::Iter>; - fn to_socket_addrs(&self) -> Self::Future { + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { let iter = self.to_vec().into_iter(); future::ok(iter) } } -cfg_dns! { +cfg_net! { // ===== impl str ===== impl ToSocketAddrs for str {} @@ -144,23 +153,23 @@ cfg_dns! { type Iter = sealed::OneOrMore; type Future = sealed::MaybeReady; - fn to_socket_addrs(&self) -> Self::Future { - use crate::runtime::spawn_blocking; + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + use crate::blocking::spawn_blocking; use sealed::MaybeReady; // First check if the input parses as a socket address let res: Result<SocketAddr, _> = self.parse(); if let Ok(addr) = res { - return MaybeReady::Ready(Some(addr)); + return MaybeReady(sealed::State::Ready(Some(addr))); } // Run DNS lookup on the blocking pool let s = self.to_owned(); - MaybeReady::Blocking(spawn_blocking(move || { + MaybeReady(sealed::State::Blocking(spawn_blocking(move || { std::net::ToSocketAddrs::to_socket_addrs(&s) - })) + }))) } } @@ -172,8 +181,8 @@ cfg_dns! { type Iter = sealed::OneOrMore; type Future = sealed::MaybeReady; - fn to_socket_addrs(&self) -> Self::Future { - use crate::runtime::spawn_blocking; + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + use crate::blocking::spawn_blocking; use sealed::MaybeReady; let (host, port) = *self; @@ -183,21 +192,34 @@ cfg_dns! { let addr = SocketAddrV4::new(addr, port); let addr = SocketAddr::V4(addr); - return MaybeReady::Ready(Some(addr)); + return MaybeReady(sealed::State::Ready(Some(addr))); } if let Ok(addr) = host.parse::<Ipv6Addr>() { let addr = SocketAddrV6::new(addr, port, 0, 0); let addr = SocketAddr::V6(addr); - return MaybeReady::Ready(Some(addr)); + return MaybeReady(sealed::State::Ready(Some(addr))); } let host = host.to_owned(); - MaybeReady::Blocking(spawn_blocking(move || { + MaybeReady(sealed::State::Blocking(spawn_blocking(move || { std::net::ToSocketAddrs::to_socket_addrs(&(&host[..], port)) - })) + }))) + } + } + + // ===== impl (String, u16) ===== + + impl ToSocketAddrs for (String, u16) {} + + impl sealed::ToSocketAddrsPriv for (String, u16) { + type Iter = sealed::OneOrMore; + type Future = sealed::MaybeReady; + + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + (self.0.as_str(), self.1).to_socket_addrs(sealed::Internal) } } @@ -209,8 +231,8 @@ cfg_dns! { type Iter = <str as sealed::ToSocketAddrsPriv>::Iter; type Future = <str as sealed::ToSocketAddrsPriv>::Future; - fn to_socket_addrs(&self) -> Self::Future { - (&self[..]).to_socket_addrs() + fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + (&self[..]).to_socket_addrs(sealed::Internal) } } } @@ -224,27 +246,31 @@ pub(crate) mod sealed { use std::io; use std::net::SocketAddr; - cfg_dns! { - use crate::task::JoinHandle; - - use std::option; - use std::pin::Pin; - use std::task::{Context, Poll}; - use std::vec; - } - #[doc(hidden)] pub trait ToSocketAddrsPriv { type Iter: Iterator<Item = SocketAddr> + Send + 'static; type Future: Future<Output = io::Result<Self::Iter>> + Send + 'static; - fn to_socket_addrs(&self) -> Self::Future; + fn to_socket_addrs(&self, internal: Internal) -> Self::Future; } - cfg_dns! { + #[allow(missing_debug_implementations)] + pub struct Internal; + + cfg_net! { + use crate::blocking::JoinHandle; + + use std::option; + use std::pin::Pin; + use std::task::{Context, Poll}; + use std::vec; + #[doc(hidden)] #[derive(Debug)] - pub enum MaybeReady { + pub struct MaybeReady(pub(super) State); + + #[derive(Debug)] + pub(super) enum State { Ready(Option<SocketAddr>), Blocking(JoinHandle<io::Result<vec::IntoIter<SocketAddr>>>), } @@ -260,12 +286,12 @@ pub(crate) mod sealed { type Output = io::Result<OneOrMore>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - match *self { - MaybeReady::Ready(ref mut i) => { + match self.0 { + State::Ready(ref mut i) => { let iter = OneOrMore::One(i.take().into_iter()); Poll::Ready(Ok(iter)) } - MaybeReady::Blocking(ref mut rx) => { + State::Blocking(ref mut rx) => { let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More); Poll::Ready(res) diff --git a/src/net/lookup_host.rs b/src/net/lookup_host.rs index 3098b46..2886184 100644 --- a/src/net/lookup_host.rs +++ b/src/net/lookup_host.rs @@ -1,5 +1,5 @@ -cfg_dns! { - use crate::net::addr::ToSocketAddrs; +cfg_net! { + use crate::net::addr::{self, ToSocketAddrs}; use std::io; use std::net::SocketAddr; @@ -33,6 +33,6 @@ cfg_dns! { where T: ToSocketAddrs { - host.to_socket_addrs().await + addr::to_socket_addrs(host).await } } diff --git a/src/net/mod.rs b/src/net/mod.rs index eb24ac0..b7365e6 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -23,27 +23,26 @@ //! [`UnixDatagram`]: UnixDatagram mod addr; +#[cfg(feature = "net")] +pub(crate) use addr::to_socket_addrs; pub use addr::ToSocketAddrs; -cfg_dns! { +cfg_net! { mod lookup_host; pub use lookup_host::lookup_host; -} -cfg_tcp! { pub mod tcp; pub use tcp::listener::TcpListener; + pub use tcp::socket::TcpSocket; pub use tcp::stream::TcpStream; -} -cfg_udp! { pub mod udp; pub use udp::socket::UdpSocket; } -cfg_uds! { +cfg_net_unix! { pub mod unix; - pub use unix::datagram::UnixDatagram; + pub use unix::datagram::socket::UnixDatagram; pub use unix::listener::UnixListener; pub use unix::stream::UnixStream; } diff --git a/src/net/tcp/incoming.rs b/src/net/tcp/incoming.rs deleted file mode 100644 index 062be1e..0000000 --- a/src/net/tcp/incoming.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::net::tcp::{TcpListener, TcpStream}; - -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Stream returned by the `TcpListener::incoming` function representing the -/// stream of sockets received from a listener. -#[must_use = "streams do nothing unless polled"] -#[derive(Debug)] -pub struct Incoming<'a> { - inner: &'a mut TcpListener, -} - -impl Incoming<'_> { - pub(crate) fn new(listener: &mut TcpListener) -> Incoming<'_> { - Incoming { inner: listener } - } - - /// Attempts to poll `TcpStream` by polling inner `TcpListener` to accept - /// connection. - /// - /// If `TcpListener` isn't ready yet, `Poll::Pending` is returned and - /// current task will be notified by a waker. - pub fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<io::Result<TcpStream>> { - let (socket, _) = ready!(self.inner.poll_accept(cx))?; - Poll::Ready(Ok(socket)) - } -} - -#[cfg(feature = "stream")] -impl crate::stream::Stream for Incoming<'_> { - type Item = io::Result<TcpStream>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let (socket, _) = ready!(self.inner.poll_accept(cx))?; - Poll::Ready(Some(Ok(socket))) - } -} diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index fd79b25..3f9bca0 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -1,7 +1,6 @@ -use crate::future::poll_fn; use crate::io::PollEvented; -use crate::net::tcp::{Incoming, TcpStream}; -use crate::net::ToSocketAddrs; +use crate::net::tcp::TcpStream; +use crate::net::{to_socket_addrs, ToSocketAddrs}; use std::convert::TryFrom; use std::fmt; @@ -9,7 +8,7 @@ use std::io; use std::net::{self, SocketAddr}; use std::task::{Context, Poll}; -cfg_tcp! { +cfg_net! { /// A TCP socket server, listening for connections. /// /// You can accept a new connection by using the [`accept`](`TcpListener::accept`) method. Alternatively `TcpListener` @@ -40,7 +39,7 @@ cfg_tcp! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; /// /// loop { /// let (socket, _) = listener.accept().await?; @@ -81,7 +80,7 @@ impl TcpListener { /// method. /// /// The address type can be any implementor of the [`ToSocketAddrs`] trait. - /// Note that strings only implement this trait when the **`dns`** feature + /// Note that strings only implement this trait when the **`net`** feature /// is enabled, as strings may contain domain names that need to be resolved. /// /// If `addr` yields multiple addresses, bind will be attempted with each of @@ -110,27 +109,8 @@ impl TcpListener { /// Ok(()) /// } /// ``` - /// - /// Without the `dns` feature: - /// - /// ```no_run - /// use tokio::net::TcpListener; - /// use std::net::Ipv4Addr; - /// - /// use std::io; - /// - /// #[tokio::main] - /// async fn main() -> io::Result<()> { - /// let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 2345)).await?; - /// - /// // use the listener - /// - /// # let _ = listener; - /// Ok(()) - /// } - /// ``` pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> { - let addrs = addr.to_socket_addrs().await?; + let addrs = to_socket_addrs(addr).await?; let mut last_err = None; @@ -150,7 +130,7 @@ impl TcpListener { } fn bind_addr(addr: SocketAddr) -> io::Result<TcpListener> { - let listener = mio::net::TcpListener::bind(&addr)?; + let listener = mio::net::TcpListener::bind(addr)?; TcpListener::new(listener) } @@ -171,7 +151,7 @@ impl TcpListener { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; /// /// match listener.accept().await { /// Ok((_socket, addr)) => println!("new client: {:?}", addr), @@ -181,66 +161,53 @@ impl TcpListener { /// Ok(()) /// } /// ``` - pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> { - poll_fn(|cx| self.poll_accept(cx)).await + pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + let (mio, addr) = self + .io + .async_io(mio::Interest::READABLE, |sock| sock.accept()) + .await?; + + let stream = TcpStream::new(mio)?; + Ok((stream, addr)) } - /// Attempts to poll `SocketAddr` and `TcpStream` bound to this address. + /// Polls to accept a new incoming connection to this listener. /// - /// In case if I/O resource isn't ready yet, `Poll::Pending` is returned and + /// If there is no connection to accept, `Poll::Pending` is returned and the /// current task will be notified by a waker. - pub fn poll_accept( - &mut self, - cx: &mut Context<'_>, - ) -> Poll<io::Result<(TcpStream, SocketAddr)>> { - let (io, addr) = ready!(self.poll_accept_std(cx))?; - - let io = mio::net::TcpStream::from_stream(io)?; - let io = TcpStream::new(io)?; - - Poll::Ready(Ok((io, addr))) - } - - fn poll_accept_std( - &mut self, - cx: &mut Context<'_>, - ) -> Poll<io::Result<(net::TcpStream, SocketAddr)>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; + /// + /// When ready, the most recent task that called `poll_accept` is notified. + /// The caller is responsble to ensure that `poll_accept` is called from a + /// single task. Failing to do this could result in tasks hanging. + pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; - match self.io.get_ref().accept_std() { - Ok(pair) => Poll::Ready(Ok(pair)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending + match self.io.get_ref().accept() { + Ok((io, addr)) => { + let io = TcpStream::new(io)?; + return Poll::Ready(Ok((io, addr))); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), } - Err(e) => Poll::Ready(Err(e)), } } /// Creates a new TCP listener from the standard library's TCP listener. /// - /// This method can be used when the `Handle::tcp_listen` method isn't - /// sufficient because perhaps some more configuration is needed in terms of - /// before the calls to `bind` and `listen`. + /// This function is intended to be used to wrap a TCP listener from the + /// standard library in the Tokio equivalent. The conversion assumes nothing + /// about the underlying listener; it is left up to the user to set it in + /// non-blocking mode. /// /// This API is typically paired with the `net2` crate and the `TcpBuilder` /// type to build up and customize a listener before it's shipped off to the /// backing event loop. This allows configuration of options like /// `SO_REUSEPORT`, binding to multiple addresses, etc. /// - /// The `addr` argument here is one of the addresses that `listener` is - /// bound to and the listener will only be guaranteed to accept connections - /// of the same address type currently. - /// - /// The platform specific behavior of this function looks like: - /// - /// * On Unix, the socket is placed into nonblocking mode and connections - /// can be accepted as normal - /// - /// * On Windows, the address is stored internally and all future accepts - /// will only be for the same IP version as `addr` specified. That is, if - /// `addr` is an IPv4 address then all sockets accepted will be IPv4 as - /// well (same for IPv6). /// /// # Examples /// @@ -262,14 +229,14 @@ impl TcpListener { /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> { - let io = mio::net::TcpListener::from_std(listener)?; + let io = mio::net::TcpListener::from_std(listener); let io = PollEvented::new(io)?; Ok(TcpListener { io }) } - fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> { + pub(crate) fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> { let io = PollEvented::new(listener)?; Ok(TcpListener { io }) } @@ -301,46 +268,6 @@ impl TcpListener { self.io.get_ref().local_addr() } - /// Returns a stream over the connections being received on this listener. - /// - /// Note that `TcpListener` also directly implements `Stream`. - /// - /// The returned stream will never return `None` and will also not yield the - /// peer's `SocketAddr` structure. Iterating over it is equivalent to - /// calling accept in a loop. - /// - /// # Errors - /// - /// Note that accepting a connection can lead to various errors and not all - /// of them are necessarily fatal ‒ for example having too many open file - /// descriptors or the other side closing the connection while it waits in - /// an accept queue. These would terminate the stream if not handled in any - /// way. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::{net::TcpListener, stream::StreamExt}; - /// - /// #[tokio::main] - /// async fn main() { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); - /// let mut incoming = listener.incoming(); - /// - /// while let Some(stream) = incoming.next().await { - /// match stream { - /// Ok(stream) => { - /// println!("new client!"); - /// } - /// Err(e) => { /* connection failed */ } - /// } - /// } - /// } - /// ``` - pub fn incoming(&mut self) -> Incoming<'_> { - Incoming::new(self) - } - /// Gets the value of the `IP_TTL` option for this socket. /// /// For more information about this option, see [`set_ttl`]. @@ -398,29 +325,12 @@ impl TcpListener { impl crate::stream::Stream for TcpListener { type Item = io::Result<TcpStream>; - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Option<Self::Item>> { + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let (socket, _) = ready!(self.poll_accept(cx))?; Poll::Ready(Some(Ok(socket))) } } -impl TryFrom<TcpListener> for mio::net::TcpListener { - type Error = io::Error; - - /// Consumes value, returning the mio I/O object. - /// - /// See [`PollEvented::into_inner`] for more details about - /// resource deregistration that happens during the call. - /// - /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner - fn try_from(value: TcpListener) -> Result<Self, Self::Error> { - value.io.into_inner() - } -} - impl TryFrom<net::TcpListener> for TcpListener { type Error = io::Error; @@ -453,14 +363,12 @@ mod sys { #[cfg(windows)] mod sys { - // TODO: let's land these upstream with mio and then we can add them here. - // - // use std::os::windows::prelude::*; - // use super::{TcpListener; - // - // impl AsRawHandle for TcpListener { - // fn as_raw_handle(&self) -> RawHandle { - // self.listener.io().as_raw_handle() - // } - // } + use super::TcpListener; + use std::os::windows::prelude::*; + + impl AsRawSocket for TcpListener { + fn as_raw_socket(&self) -> RawSocket { + self.io.get_ref().as_raw_socket() + } + } } diff --git a/src/net/tcp/mod.rs b/src/net/tcp/mod.rs index 7ad36eb..7f0f6d9 100644 --- a/src/net/tcp/mod.rs +++ b/src/net/tcp/mod.rs @@ -1,10 +1,8 @@ //! TCP utility types pub(crate) mod listener; -pub(crate) use listener::TcpListener; -mod incoming; -pub use incoming::Incoming; +pub(crate) mod socket; mod split; pub use split::{ReadHalf, WriteHalf}; diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs new file mode 100644 index 0000000..5b0f802 --- /dev/null +++ b/src/net/tcp/socket.rs @@ -0,0 +1,349 @@ +use crate::net::{TcpListener, TcpStream}; + +use std::fmt; +use std::io; +use std::net::SocketAddr; + +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; + +/// A TCP socket that has not yet been converted to a `TcpStream` or +/// `TcpListener`. +/// +/// `TcpSocket` wraps an operating system socket and enables the caller to +/// configure the socket before establishing a TCP connection or accepting +/// inbound connections. The caller is able to set socket option and explicitly +/// bind the socket with a socket address. +/// +/// The underlying socket is closed when the `TcpSocket` value is dropped. +/// +/// `TcpSocket` should only be used directly if the default configuration used +/// by `TcpStream::connect` and `TcpListener::bind` does not meet the required +/// use case. +/// +/// Calling `TcpStream::connect("127.0.0.1:8080")` is equivalent to: +/// +/// ```no_run +/// use tokio::net::TcpSocket; +/// +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let addr = "127.0.0.1:8080".parse().unwrap(); +/// +/// let socket = TcpSocket::new_v4()?; +/// let stream = socket.connect(addr).await?; +/// # drop(stream); +/// +/// Ok(()) +/// } +/// ``` +/// +/// Calling `TcpListener::bind("127.0.0.1:8080")` is equivalent to: +/// +/// ```no_run +/// use tokio::net::TcpSocket; +/// +/// use std::io; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let addr = "127.0.0.1:8080".parse().unwrap(); +/// +/// let socket = TcpSocket::new_v4()?; +/// // On platforms with Berkeley-derived sockets, this allows to quickly +/// // rebind a socket, without needing to wait for the OS to clean up the +/// // previous one. +/// // +/// // On Windows, this allows rebinding sockets which are actively in use, +/// // which allows “socket hijacking”, so we explicitly don't set it here. +/// // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse +/// socket.set_reuseaddr(true)?; +/// socket.bind(addr)?; +/// +/// let listener = socket.listen(1024)?; +/// # drop(listener); +/// +/// Ok(()) +/// } +/// ``` +/// +/// Setting socket options not explicitly provided by `TcpSocket` may be done by +/// accessing the `RawFd`/`RawSocket` using [`AsRawFd`]/[`AsRawSocket`] and +/// setting the option with a crate like [`socket2`]. +/// +/// [`RawFd`]: https://doc.rust-lang.org/std/os/unix/io/type.RawFd.html +/// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html +/// [`AsRawFd`]: https://doc.rust-lang.org/std/os/unix/io/trait.AsRawFd.html +/// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html +/// [`socket2`]: https://docs.rs/socket2/ +pub struct TcpSocket { + inner: mio::net::TcpSocket, +} + +impl TcpSocket { + /// Create a new socket configured for IPv4. + /// + /// Calls `socket(2)` with `AF_INET` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created `TcpSocket` is returned. If an error is + /// encountered, it is returned instead. + /// + /// # Examples + /// + /// Create a new IPv4 socket and start listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(128)?; + /// # drop(listener); + /// Ok(()) + /// } + /// ``` + pub fn new_v4() -> io::Result<TcpSocket> { + let inner = mio::net::TcpSocket::new_v4()?; + Ok(TcpSocket { inner }) + } + + /// Create a new socket configured for IPv6. + /// + /// Calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`. + /// + /// # Returns + /// + /// On success, the newly created `TcpSocket` is returned. If an error is + /// encountered, it is returned instead. + /// + /// # Examples + /// + /// Create a new IPv6 socket and start listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "[::1]:8080".parse().unwrap(); + /// let socket = TcpSocket::new_v6()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(128)?; + /// # drop(listener); + /// Ok(()) + /// } + /// ``` + pub fn new_v6() -> io::Result<TcpSocket> { + let inner = mio::net::TcpSocket::new_v6()?; + Ok(TcpSocket { inner }) + } + + /// Allow the socket to bind to an in-use address. + /// + /// Behavior is platform specific. Refer to the target platform's + /// documentation for more details. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.set_reuseaddr(true)?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> { + self.inner.set_reuseaddr(reuseaddr) + } + + /// Bind the socket to the given address. + /// + /// This calls the `bind(2)` operating-system function. Behavior is + /// platform specific. Refer to the target platform's documentation for more + /// details. + /// + /// # Examples + /// + /// Bind a socket before listening. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + self.inner.bind(addr) + } + + /// Establish a TCP connection with a peer at the specified socket address. + /// + /// The `TcpSocket` is consumed. Once the connection is established, a + /// connected [`TcpStream`] is returned. If the connection fails, the + /// encountered error is returned. + /// + /// [`TcpStream`]: TcpStream + /// + /// This calls the `connect(2)` operating-system function. Behavior is + /// platform specific. Refer to the target platform's documentation for more + /// details. + /// + /// # Examples + /// + /// Connecting to a peer. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// let stream = socket.connect(addr).await?; + /// # drop(stream); + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> { + let mio = self.inner.connect(addr)?; + TcpStream::connect_mio(mio).await + } + + /// Convert the socket into a `TcpListener`. + /// + /// `backlog` defines the maximum number of pending connections are queued + /// by the operating system at any given time. Connection are removed from + /// the queue with [`TcpListener::accept`]. When the queue is full, the + /// operationg-system will start rejecting connections. + /// + /// [`TcpListener::accept`]: TcpListener::accept + /// + /// This calls the `listen(2)` operating-system function, marking the socket + /// as a passive socket. Behavior is platform specific. Refer to the target + /// platform's documentation for more details. + /// + /// # Examples + /// + /// Create a `TcpListener`. + /// + /// ```no_run + /// use tokio::net::TcpSocket; + /// + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let addr = "127.0.0.1:8080".parse().unwrap(); + /// + /// let socket = TcpSocket::new_v4()?; + /// socket.bind(addr)?; + /// + /// let listener = socket.listen(1024)?; + /// # drop(listener); + /// + /// Ok(()) + /// } + /// ``` + pub fn listen(self, backlog: u32) -> io::Result<TcpListener> { + let mio = self.inner.listen(backlog)?; + TcpListener::new(mio) + } +} + +impl fmt::Debug for TcpSocket { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(fmt) + } +} + +#[cfg(unix)] +impl AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(unix)] +impl FromRawFd for TcpSocket { + /// Converts a `RawFd` to a `TcpSocket`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket { + let inner = mio::net::TcpSocket::from_raw_fd(fd); + TcpSocket { inner } + } +} + +#[cfg(windows)] +impl IntoRawSocket for TcpSocket { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_raw_socket() + } +} + +#[cfg(windows)] +impl AsRawSocket for TcpSocket { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +impl FromRawSocket for TcpSocket { + /// Converts a `RawSocket` to a `TcpStream`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket { + let inner = mio::net::TcpSocket::from_raw_socket(socket); + TcpSocket { inner } + } +} diff --git a/src/net/tcp/split.rs b/src/net/tcp/split.rs index 469056a..9a257f8 100644 --- a/src/net/tcp/split.rs +++ b/src/net/tcp/split.rs @@ -9,17 +9,15 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::TcpStream; -use bytes::Buf; use std::io; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; -/// Read half of a [`TcpStream`], created by [`split`]. +/// Borrowed read half of a [`TcpStream`], created by [`split`]. /// /// Reading from a `ReadHalf` is usually done using the convenience methods found on the /// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. @@ -31,12 +29,12 @@ use std::task::{Context, Poll}; #[derive(Debug)] pub struct ReadHalf<'a>(&'a TcpStream); -/// Write half of a [`TcpStream`], created by [`split`]. +/// Borrowed write half of a [`TcpStream`], created by [`split`]. /// /// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will /// shut down the TCP stream in the write direction. /// -/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found +/// Writing to an `WriteHalf` is usually done using the convenience methods found /// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. /// /// [`TcpStream`]: TcpStream @@ -83,7 +81,7 @@ impl ReadHalf<'_> { /// /// [`TcpStream::poll_peek`]: TcpStream::poll_peek pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { - self.0.poll_peek2(cx, buf) + self.0.poll_peek(cx, buf) } /// Receives data on the socket from the remote address to which it is @@ -131,15 +129,11 @@ impl ReadHalf<'_> { } impl AsyncRead for ReadHalf<'_> { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.0.poll_read_priv(cx, buf) } } @@ -153,14 +147,6 @@ impl AsyncWrite for WriteHalf<'_> { self.0.poll_write_priv(cx, buf) } - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - self.0.poll_write_buf_priv(cx, buf) - } - #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { // tcp flush is a no-op diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs index 3f6ee33..4b4e263 100644 --- a/src/net/tcp/split_owned.rs +++ b/src/net/tcp/split_owned.rs @@ -9,12 +9,10 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::TcpStream; -use bytes::Buf; use std::error::Error; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::sync::Arc; @@ -37,10 +35,9 @@ pub struct OwnedReadHalf { /// Owned write half of a [`TcpStream`], created by [`into_split`]. /// -/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will -/// shut down the TCP stream in the write direction. -/// -/// Dropping the write half will shutdown the write half of the TCP stream. +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will +/// shut down the TCP stream in the write direction. Dropping the write half +/// will also shut down the write half of the TCP stream. /// /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found /// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. @@ -77,13 +74,13 @@ pub(crate) fn reunite( write.forget(); // This unwrap cannot fail as the api does not allow creating more than two Arcs, // and we just dropped the other half. - Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc")) + Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite")) } else { Err(ReuniteError(read, write)) } } -/// Error indicating two halves were not from the same socket, and thus could +/// Error indicating that two halves were not from the same socket, and thus could /// not be reunited. #[derive(Debug)] pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); @@ -139,7 +136,7 @@ impl OwnedReadHalf { /// /// [`TcpStream::poll_peek`]: TcpStream::poll_peek pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { - self.inner.poll_peek2(cx, buf) + self.inner.poll_peek(cx, buf) } /// Receives data on the socket from the remote address to which it is @@ -187,15 +184,11 @@ impl OwnedReadHalf { } impl AsyncRead for OwnedReadHalf { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.inner.poll_read_priv(cx, buf) } } @@ -210,7 +203,9 @@ impl OwnedWriteHalf { reunite(other, self) } - /// Drop the write half, but don't issue a TCP shutdown. + /// Destroy the write half, but don't close the write half of the stream + /// until the read half is dropped. If the read half has already been + /// dropped, this closes the stream. pub fn forget(mut self) { self.shutdown_on_drop = false; drop(self); @@ -234,14 +229,6 @@ impl AsyncWrite for OwnedWriteHalf { self.inner.poll_write_priv(cx, buf) } - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - self.inner.poll_write_buf_priv(cx, buf) - } - #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { // tcp flush is a no-op @@ -250,7 +237,11 @@ impl AsyncWrite for OwnedWriteHalf { // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - self.inner.shutdown(Shutdown::Write).into() + let res = self.inner.shutdown(Shutdown::Write); + if res.is_ok() { + Pin::into_inner(self).shutdown_on_drop = false; + } + res.into() } } diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index cc81e11..f90e9a3 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -1,21 +1,17 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented}; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::net::tcp::split::{split, ReadHalf, WriteHalf}; use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; -use crate::net::ToSocketAddrs; +use crate::net::{to_socket_addrs, ToSocketAddrs}; -use bytes::Buf; -use iovec::IoVec; use std::convert::TryFrom; use std::fmt; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; -use std::net::{self, Shutdown, SocketAddr}; +use std::net::{Shutdown, SocketAddr}; use std::pin::Pin; use std::task::{Context, Poll}; -use std::time::Duration; -cfg_tcp! { +cfg_net! { /// A TCP stream between a local and a remote socket. /// /// A TCP stream can either be created by connecting to an endpoint, via the @@ -26,8 +22,8 @@ cfg_tcp! { /// traits. Examples import these traits through [the prelude]. /// /// [`connect`]: method@TcpStream::connect - /// [accepting]: method@super::TcpListener::accept - /// [listener]: struct@super::TcpListener + /// [accepting]: method@crate::net::TcpListener::accept + /// [listener]: struct@crate::net::TcpListener /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt /// [the prelude]: crate::prelude @@ -65,7 +61,7 @@ impl TcpStream { /// /// `addr` is an address of the remote host. Anything which implements the /// [`ToSocketAddrs`] trait can be supplied as the address. Note that - /// strings only implement this trait when the **`dns`** feature is enabled, + /// strings only implement this trait when the **`net`** feature is enabled, /// as strings may contain domain names that need to be resolved. /// /// If `addr` yields multiple addresses, connect will be attempted with each @@ -94,32 +90,12 @@ impl TcpStream { /// } /// ``` /// - /// Without the `dns` feature: - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// use tokio::prelude::*; - /// use std::error::Error; - /// use std::net::Ipv4Addr; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), Box<dyn Error>> { - /// // Connect to a peer - /// let mut stream = TcpStream::connect((Ipv4Addr::new(127, 0, 0, 1), 8080)).await?; - /// - /// // Write some data. - /// stream.write_all(b"hello world!").await?; - /// - /// Ok(()) - /// } - /// ``` - /// /// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait. /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> { - let addrs = addr.to_socket_addrs().await?; + let addrs = to_socket_addrs(addr).await?; let mut last_err = None; @@ -140,7 +116,11 @@ impl TcpStream { /// Establishes a connection to the specified `addr`. async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> { - let sys = mio::net::TcpStream::connect(&addr)?; + let sys = mio::net::TcpStream::connect(addr)?; + TcpStream::connect_mio(sys).await + } + + pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> { let stream = TcpStream::new(sys)?; // Once we've connected, wait for the stream to be writable as @@ -188,37 +168,13 @@ impl TcpStream { /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> { - let io = mio::net::TcpStream::from_stream(stream)?; + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn from_std(stream: std::net::TcpStream) -> io::Result<TcpStream> { + let io = mio::net::TcpStream::from_std(stream); let io = PollEvented::new(io)?; Ok(TcpStream { io }) } - // Connects `TcpStream` asynchronously that may be built with a net2 `TcpBuilder`. - // - // This should be removed in favor of some in-crate TcpSocket builder API. - #[doc(hidden)] - pub async fn connect_std(stream: net::TcpStream, addr: &SocketAddr) -> io::Result<TcpStream> { - let io = mio::net::TcpStream::connect_stream(stream, addr)?; - let io = PollEvented::new(io)?; - let stream = TcpStream { io }; - - // Once we've connected, wait for the stream to be writable as - // that's when the actual connection has been initiated. Once we're - // writable we check for `take_socket_error` to see if the connect - // actually hit an error or not. - // - // If all that succeeded then we ship everything on up. - poll_fn(|cx| stream.io.poll_write_ready(cx)).await?; - - if let Some(e) = stream.io.get_ref().take_error()? { - return Err(e); - } - - Ok(stream) - } - /// Returns the local address that this stream is bound to. /// /// # Examples @@ -281,7 +237,7 @@ impl TcpStream { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?; + /// let stream = TcpStream::connect("127.0.0.1:8000").await?; /// let mut buf = [0; 10]; /// /// poll_fn(|cx| { @@ -291,24 +247,17 @@ impl TcpStream { /// Ok(()) /// } /// ``` - pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { - self.poll_peek2(cx, buf) - } - - pub(super) fn poll_peek2( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().peek(buf) { - Ok(ret) => Poll::Ready(Ok(ret)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending + pub fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + match self.io.get_ref().peek(buf) { + Ok(ret) => return Poll::Ready(Ok(ret)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(e) => return Poll::Ready(Err(e)), } - Err(e) => Poll::Ready(Err(e)), } } @@ -349,8 +298,10 @@ impl TcpStream { /// /// [`read`]: fn@crate::io::AsyncReadExt::read /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt - pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_peek(cx, buf)).await + pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .async_io(mio::Interest::READABLE, |io| io.peek(buf)) + .await } /// Shuts down the read, write, or both halves of this connection. @@ -427,144 +378,6 @@ impl TcpStream { self.io.get_ref().set_nodelay(nodelay) } - /// Gets the value of the `SO_RCVBUF` option on this socket. - /// - /// For more information about this option, see [`set_recv_buffer_size`]. - /// - /// [`set_recv_buffer_size`]: TcpStream::set_recv_buffer_size - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// println!("{:?}", stream.recv_buffer_size()?); - /// # Ok(()) - /// # } - /// ``` - pub fn recv_buffer_size(&self) -> io::Result<usize> { - self.io.get_ref().recv_buffer_size() - } - - /// Sets the value of the `SO_RCVBUF` option on this socket. - /// - /// Changes the size of the operating system's receive buffer associated - /// with the socket. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// stream.set_recv_buffer_size(100)?; - /// # Ok(()) - /// # } - /// ``` - pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> { - self.io.get_ref().set_recv_buffer_size(size) - } - - /// Gets the value of the `SO_SNDBUF` option on this socket. - /// - /// For more information about this option, see [`set_send_buffer_size`]. - /// - /// [`set_send_buffer_size`]: TcpStream::set_send_buffer_size - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// println!("{:?}", stream.send_buffer_size()?); - /// # Ok(()) - /// # } - /// ``` - pub fn send_buffer_size(&self) -> io::Result<usize> { - self.io.get_ref().send_buffer_size() - } - - /// Sets the value of the `SO_SNDBUF` option on this socket. - /// - /// Changes the size of the operating system's send buffer associated with - /// the socket. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// stream.set_send_buffer_size(100)?; - /// # Ok(()) - /// # } - /// ``` - pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> { - self.io.get_ref().set_send_buffer_size(size) - } - - /// Returns whether keepalive messages are enabled on this socket, and if so - /// the duration of time between them. - /// - /// For more information about this option, see [`set_keepalive`]. - /// - /// [`set_keepalive`]: TcpStream::set_keepalive - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// println!("{:?}", stream.keepalive()?); - /// # Ok(()) - /// # } - /// ``` - pub fn keepalive(&self) -> io::Result<Option<Duration>> { - self.io.get_ref().keepalive() - } - - /// Sets whether keepalive messages are enabled to be sent on this socket. - /// - /// On Unix, this option will set the `SO_KEEPALIVE` as well as the - /// `TCP_KEEPALIVE` or `TCP_KEEPIDLE` option (depending on your platform). - /// On Windows, this will set the `SIO_KEEPALIVE_VALS` option. - /// - /// If `None` is specified then keepalive messages are disabled, otherwise - /// the duration specified will be the time to remain idle before sending a - /// TCP keepalive probe. - /// - /// Some platforms specify this value in seconds, so sub-second - /// specifications may be omitted. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// stream.set_keepalive(None)?; - /// # Ok(()) - /// # } - /// ``` - pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> { - self.io.get_ref().set_keepalive(keepalive) - } - /// Gets the value of the `IP_TTL` option for this socket. /// /// For more information about this option, see [`set_ttl`]. @@ -608,57 +421,9 @@ impl TcpStream { self.io.get_ref().set_ttl(ttl) } - /// Reads the linger duration for this socket by getting the `SO_LINGER` - /// option. - /// - /// For more information about this option, see [`set_linger`]. - /// - /// [`set_linger`]: TcpStream::set_linger - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// println!("{:?}", stream.linger()?); - /// # Ok(()) - /// # } - /// ``` - pub fn linger(&self) -> io::Result<Option<Duration>> { - self.io.get_ref().linger() - } - - /// Sets the linger duration of this socket by setting the `SO_LINGER` - /// option. - /// - /// This option controls the action taken when a stream has unsent messages - /// and the stream is closed. If `SO_LINGER` is set, the system - /// shall block the process until it can transmit the data or until the - /// time expires. - /// - /// If `SO_LINGER` is not specified, and the stream is closed, the system - /// handles the call in a way that allows the process to continue as quickly - /// as possible. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// - /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// stream.set_linger(None)?; - /// # Ok(()) - /// # } - /// ``` - pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> { - self.io.get_ref().set_linger(dur) - } - + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] /// Splits a `TcpStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. /// @@ -666,7 +431,7 @@ impl TcpStream { /// moved into independently spawned tasks. /// /// [`into_split`]: TcpStream::into_split() - pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { + pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) { split(self) } @@ -676,10 +441,11 @@ impl TcpStream { /// Unlike [`split`], the owned halves can be moved to separate tasks, however /// this comes at the cost of a heap allocation. /// - /// **Note::** Dropping the write half will shutdown the write half of the TCP - /// stream. This is equivalent to calling `shutdown(Write)` on the `TcpStream`. + /// **Note:** Dropping the write half will shut down the write half of the TCP + /// stream. This is equivalent to calling [`shutdown(Write)`] on the `TcpStream`. /// /// [`split`]: TcpStream::split() + /// [`shutdown(Write)`]: fn@crate::net::TcpStream::shutdown pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { split_owned(self) } @@ -698,16 +464,30 @@ impl TcpStream { pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().read(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + // Safety: `TcpStream::read` will not peek at the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + match self.io.get_ref().read(b) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Ok(n) => { + // Safety: We trust `TcpStream::read` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(())); + } + Err(e) => return Poll::Ready(Err(e)), } - x => Poll::Ready(x), } } @@ -716,143 +496,27 @@ impl TcpStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().write(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), } - x => Poll::Ready(x), - } - } - - pub(super) fn poll_write_buf_priv<B: Buf>( - &self, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - use std::io::IoSlice; - - ready!(self.io.poll_write_ready(cx))?; - - // The `IoVec` (v0.1.x) type can't have a zero-length size, so create - // a dummy version from a 1-length slice which we'll overwrite with - // the `bytes_vectored` method. - static S: &[u8] = &[0]; - const MAX_BUFS: usize = 64; - - // IoSlice isn't Copy, so we must expand this manually ;_; - let mut slices: [IoSlice<'_>; MAX_BUFS] = [ - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - IoSlice::new(S), - ]; - let cnt = buf.bytes_vectored(&mut slices); - - let iovec = <&IoVec>::from(S); - let mut vecs = [iovec; MAX_BUFS]; - for i in 0..cnt { - vecs[i] = (*slices[i]).into(); - } - - match self.io.get_ref().write_bufs(&vecs[..cnt]) { - Ok(n) => { - buf.advance(n); - Poll::Ready(Ok(n)) - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending - } - Err(e) => Poll::Ready(Err(e)), } } } -impl TryFrom<TcpStream> for mio::net::TcpStream { - type Error = io::Error; - - /// Consumes value, returning the mio I/O object. - /// - /// See [`PollEvented::into_inner`] for more details about - /// resource deregistration that happens during the call. - /// - /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner - fn try_from(value: TcpStream) -> Result<Self, Self::Error> { - value.io.into_inner() - } -} - -impl TryFrom<net::TcpStream> for TcpStream { +impl TryFrom<std::net::TcpStream> for TcpStream { type Error = io::Error; /// Consumes stream, returning the tokio I/O object. /// /// This is equivalent to /// [`TcpStream::from_std(stream)`](TcpStream::from_std). - fn try_from(stream: net::TcpStream) -> Result<Self, Self::Error> { + fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> { Self::from_std(stream) } } @@ -860,15 +524,11 @@ impl TryFrom<net::TcpStream> for TcpStream { // ===== impl Read / Write ===== impl AsyncRead for TcpStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.poll_read_priv(cx, buf) } } @@ -882,14 +542,6 @@ impl AsyncWrite for TcpStream { self.poll_write_priv(cx, buf) } - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - self.poll_write_buf_priv(cx, buf) - } - #[inline] fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { // tcp flush is a no-op @@ -922,14 +574,12 @@ mod sys { #[cfg(windows)] mod sys { - // TODO: let's land these upstream with mio and then we can add them here. - // - // use std::os::windows::prelude::*; - // use super::TcpStream; - // - // impl AsRawHandle for TcpStream { - // fn as_raw_handle(&self) -> RawHandle { - // self.io.get_ref().as_raw_handle() - // } - // } + use super::TcpStream; + use std::os::windows::prelude::*; + + impl AsRawSocket for TcpStream { + fn as_raw_socket(&self) -> RawSocket { + self.io.get_ref().as_raw_socket() + } + } } diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs index d43121a..c9bb0f8 100644 --- a/src/net/udp/mod.rs +++ b/src/net/udp/mod.rs @@ -1,7 +1,3 @@ //! UDP utility types. pub(crate) mod socket; -pub(crate) use socket::UdpSocket; - -mod split; -pub use split::{RecvHalf, ReuniteError, SendHalf}; diff --git a/src/net/udp/socket.rs b/src/net/udp/socket.rs index 97090a2..d13e92b 100644 --- a/src/net/udp/socket.rs +++ b/src/net/udp/socket.rs @@ -1,16 +1,110 @@ -use crate::future::poll_fn; use crate::io::PollEvented; -use crate::net::udp::split::{split, RecvHalf, SendHalf}; -use crate::net::ToSocketAddrs; +use crate::net::{to_socket_addrs, ToSocketAddrs}; use std::convert::TryFrom; use std::fmt; use std::io; use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr}; -use std::task::{Context, Poll}; -cfg_udp! { +cfg_net! { /// A UDP socket + /// + /// UDP is "connectionless", unlike TCP. Meaning, regardless of what address you've bound to, a `UdpSocket` + /// is free to communicate with many different remotes. In tokio there are basically two main ways to use `UdpSocket`: + /// + /// * one to many: [`bind`](`UdpSocket::bind`) and use [`send_to`](`UdpSocket::send_to`) + /// and [`recv_from`](`UdpSocket::recv_from`) to communicate with many different addresses + /// * one to one: [`connect`](`UdpSocket::connect`) and associate with a single address, using [`send`](`UdpSocket::send`) + /// and [`recv`](`UdpSocket::recv`) to communicate only with that remote address + /// + /// `UdpSocket` can also be used concurrently to `send_to` and `recv_from` in different tasks, + /// all that's required is that you `Arc<UdpSocket>` and clone a reference for each task. + /// + /// # Streams + /// + /// If you need to listen over UDP and produce a [`Stream`](`crate::stream::Stream`), you can look + /// at [`UdpFramed`]. + /// + /// [`UdpFramed`]: https://docs.rs/tokio-util/latest/tokio_util/udp/struct.UdpFramed.html + /// + /// # Example: one to many (bind) + /// + /// Using `bind` we can create a simple echo server that sends and recv's with many different clients: + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080").await?; + /// let mut buf = [0; 1024]; + /// loop { + /// let (len, addr) = sock.recv_from(&mut buf).await?; + /// println!("{:?} bytes received from {:?}", len, addr); + /// + /// let len = sock.send_to(&buf[..len], addr).await?; + /// println!("{:?} bytes sent", len); + /// } + /// } + /// ``` + /// + /// # Example: one to one (connect) + /// + /// Or using `connect` we can echo with a single remote address using `send` and `recv`: + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080").await?; + /// + /// let remote_addr = "127.0.0.1:59611"; + /// sock.connect(remote_addr).await?; + /// let mut buf = [0; 1024]; + /// loop { + /// let len = sock.recv(&mut buf).await?; + /// println!("{:?} bytes received from {:?}", len, remote_addr); + /// + /// let len = sock.send(&buf[..len]).await?; + /// println!("{:?} bytes sent", len); + /// } + /// } + /// ``` + /// + /// # Example: Sending/Receiving concurrently + /// + /// Because `send_to` and `recv_from` take `&self`. It's perfectly alright to `Arc<UdpSocket>` + /// and share the references to multiple tasks, in order to send/receive concurrently. Here is + /// a similar "echo" example but that supports concurrent sending/receiving: + /// + /// ```no_run + /// use tokio::{net::UdpSocket, sync::mpsc}; + /// use std::{io, net::SocketAddr, sync::Arc}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// let r = Arc::new(sock); + /// let s = r.clone(); + /// let (tx, mut rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(1_000); + /// + /// tokio::spawn(async move { + /// while let Some((bytes, addr)) = rx.recv().await { + /// let len = s.send_to(&bytes, &addr).await.unwrap(); + /// println!("{:?} bytes sent", len); + /// } + /// }); + /// + /// let mut buf = [0; 1024]; + /// loop { + /// let (len, addr) = r.recv_from(&mut buf).await?; + /// println!("{:?} bytes received from {:?}", len, addr); + /// tx.send((buf[..len].to_vec(), addr)).await.unwrap(); + /// } + /// } + /// ``` + /// pub struct UdpSocket { io: PollEvented<mio::net::UdpSocket>, } @@ -19,8 +113,23 @@ cfg_udp! { impl UdpSocket { /// This function will create a new UDP socket and attempt to bind it to /// the `addr` provided. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080").await?; + /// // use `sock` + /// # let _ = sock; + /// Ok(()) + /// } + /// ``` pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> { - let addrs = addr.to_socket_addrs().await?; + let addrs = to_socket_addrs(addr).await?; let mut last_err = None; for addr in addrs { @@ -39,7 +148,7 @@ impl UdpSocket { } fn bind_addr(addr: SocketAddr) -> io::Result<UdpSocket> { - let sys = mio::net::UdpSocket::bind(&addr)?; + let sys = mio::net::UdpSocket::bind(addr)?; UdpSocket::new(sys) } @@ -64,21 +173,45 @@ impl UdpSocket { /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap(); + /// let std_sock = std::net::UdpSocket::bind(addr)?; + /// let sock = UdpSocket::from_std(std_sock)?; + /// // use `sock` + /// # Ok(()) + /// # } + /// ``` pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> { - let io = mio::net::UdpSocket::from_socket(socket)?; - let io = PollEvented::new(io)?; - Ok(UdpSocket { io }) - } - - /// Splits the `UdpSocket` into a receive half and a send half. The two parts - /// can be used to receive and send datagrams concurrently, even from two - /// different tasks. - pub fn split(self) -> (RecvHalf, SendHalf) { - split(self) + let io = mio::net::UdpSocket::from_std(socket); + UdpSocket::new(io) } /// Returns the local address that this socket is bound to. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap(); + /// let sock = UdpSocket::bind(addr).await?; + /// // the address the socket is bound to + /// let local_addr = sock.local_addr()?; + /// # Ok(()) + /// # } + /// ``` pub fn local_addr(&self) -> io::Result<SocketAddr> { self.io.get_ref().local_addr() } @@ -86,8 +219,29 @@ impl UdpSocket { /// Connects the UDP socket setting the default destination for send() and /// limiting packets that are read via recv from the address specified in /// `addr`. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// + /// let remote_addr = "127.0.0.1:59600".parse::<SocketAddr>().unwrap(); + /// sock.connect(remote_addr).await?; + /// let mut buf = [0u8; 32]; + /// // recv from remote_addr + /// let len = sock.recv(&mut buf).await?; + /// // send to remote_addr + /// let _len = sock.send(&buf[..len]).await?; + /// # Ok(()) + /// # } + /// ``` pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> { - let addrs = addr.to_socket_addrs().await?; + let addrs = to_socket_addrs(addr).await?; let mut last_err = None; for addr in addrs { @@ -112,31 +266,24 @@ impl UdpSocket { /// will resolve to an error if the socket is not connected. /// /// [`connect`]: method@Self::connect - pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_send(cx, buf)).await - } - - // Poll IO functions that takes `&self` are provided for the split API. - // - // They are not public because (taken from the doc of `PollEvented`): - // - // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the - // caller must ensure that there are at most two tasks that use a - // `PollEvented` instance concurrently. One for reading and one for writing. - // While violating this requirement is "safe" from a Rust memory model point - // of view, it will result in unexpected behavior in the form of lost - // notifications and tasks hanging. - #[doc(hidden)] - pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().send(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending - } - x => Poll::Ready(x), - } + pub async fn send(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .async_io(mio::Interest::WRITABLE, |sock| sock.send(buf)) + .await + } + + /// Try to send data on the socket to the remote address to which it is + /// connected. + /// + /// # Returns + /// + /// If successfull, the number of bytes sent is returned. Users + /// should ensure that when the remote cannot receive, the + /// [`ErrorKind::WouldBlock`] is properly handled. + /// + /// [`ErrorKind::WouldBlock`]: std::io::ErrorKind::WouldBlock + pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> { + self.io.get_ref().send(buf) } /// Returns a future that receives a single datagram message on the socket from @@ -151,21 +298,10 @@ impl UdpSocket { /// will fail if the socket is not connected. /// /// [`connect`]: method@Self::connect - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_recv(cx, buf)).await - } - - #[doc(hidden)] - pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().recv(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending - } - x => Poll::Ready(x), - } + pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .async_io(mio::Interest::READABLE, |sock| sock.recv(buf)) + .await } /// Returns a future that sends data on the socket to the given address. @@ -173,11 +309,27 @@ impl UdpSocket { /// /// The future will resolve to an error if the IP version of the socket does /// not match that of `target`. - pub async fn send_to<A: ToSocketAddrs>(&mut self, buf: &[u8], target: A) -> io::Result<usize> { - let mut addrs = target.to_socket_addrs().await?; + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// let buf = b"hello world"; + /// let remote_addr = "127.0.0.1:58000".parse::<SocketAddr>().unwrap(); + /// let _len = sock.send_to(&buf[..], remote_addr).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> io::Result<usize> { + let mut addrs = to_socket_addrs(target).await?; match addrs.next() { - Some(target) => poll_fn(|cx| self.poll_send_to(cx, buf, &target)).await, + Some(target) => self.send_to_addr(buf, target).await, None => Err(io::Error::new( io::ErrorKind::InvalidInput, "no addresses to send data to", @@ -185,23 +337,42 @@ impl UdpSocket { } } - // TODO: Public or not? - #[doc(hidden)] - pub fn poll_send_to( - &self, - cx: &mut Context<'_>, - buf: &[u8], - target: &SocketAddr, - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().send_to(buf, target) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending - } - x => Poll::Ready(x), - } + /// Try to send data on the socket to the given address, but if the send is blocked + /// this will return right away. + /// + /// # Returns + /// + /// If successfull, returns the number of bytes sent + /// + /// Users should ensure that when the remote cannot receive, the + /// [`ErrorKind::WouldBlock`] is properly handled. An error can also occur + /// if the IP version of the socket does not match that of `target`. + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// let buf = b"hello world"; + /// let remote_addr = "127.0.0.1:58000".parse::<SocketAddr>().unwrap(); + /// let _len = sock.try_send_to(&buf[..], remote_addr)?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`ErrorKind::WouldBlock`]: std::io::ErrorKind::WouldBlock + pub fn try_send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> { + self.io.get_ref().send_to(buf, target) + } + + async fn send_to_addr(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> { + self.io + .async_io(mio::Interest::WRITABLE, |sock| sock.send_to(buf, target)) + .await } /// Returns a future that receives a single datagram on the socket. On success, @@ -210,25 +381,26 @@ impl UdpSocket { /// The function must be called with valid byte array `buf` of sufficient size /// to hold the message bytes. If a message is too long to fit in the supplied /// buffer, excess bytes may be discarded. - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - poll_fn(|cx| self.poll_recv_from(cx, buf)).await - } - - #[doc(hidden)] - pub fn poll_recv_from( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<Result<(usize, SocketAddr), io::Error>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().recv_from(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending - } - x => Poll::Ready(x), - } + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::{io, net::SocketAddr}; + /// + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?; + /// let mut buf = [0u8; 32]; + /// let (len, addr) = sock.recv_from(&mut buf).await?; + /// println!("received {:?} bytes from {:?}", len, addr); + /// # Ok(()) + /// # } + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.io + .async_io(mio::Interest::READABLE, |sock| sock.recv_from(buf)) + .await } /// Gets the value of the `SO_BROADCAST` option for this socket. @@ -315,6 +487,20 @@ impl UdpSocket { /// For more information about this option, see [`set_ttl`]. /// /// [`set_ttl`]: method@Self::set_ttl + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::io; + /// + /// # async fn dox() -> io::Result<()> { + /// let sock = UdpSocket::bind("127.0.0.1:8080").await?; + /// + /// println!("{:?}", sock.ttl()?); + /// # Ok(()) + /// # } + /// ``` pub fn ttl(&self) -> io::Result<u32> { self.io.get_ref().ttl() } @@ -323,6 +509,20 @@ impl UdpSocket { /// /// This value sets the time-to-live field that is used in every packet sent /// from this socket. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UdpSocket; + /// # use std::io; + /// + /// # async fn dox() -> io::Result<()> { + /// let sock = UdpSocket::bind("127.0.0.1:8080").await?; + /// sock.set_ttl(60)?; + /// + /// # Ok(()) + /// # } + /// ``` pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { self.io.get_ref().set_ttl(ttl) } @@ -366,28 +566,14 @@ impl UdpSocket { } } -impl TryFrom<UdpSocket> for mio::net::UdpSocket { - type Error = io::Error; - - /// Consumes value, returning the mio I/O object. - /// - /// See [`PollEvented::into_inner`] for more details about - /// resource deregistration that happens during the call. - /// - /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner - fn try_from(value: UdpSocket) -> Result<Self, Self::Error> { - value.io.into_inner() - } -} - -impl TryFrom<net::UdpSocket> for UdpSocket { +impl TryFrom<std::net::UdpSocket> for UdpSocket { type Error = io::Error; /// Consumes stream, returning the tokio I/O object. /// /// This is equivalent to /// [`UdpSocket::from_std(stream)`](UdpSocket::from_std). - fn try_from(stream: net::UdpSocket) -> Result<Self, Self::Error> { + fn try_from(stream: std::net::UdpSocket) -> Result<Self, Self::Error> { Self::from_std(stream) } } @@ -412,14 +598,12 @@ mod sys { #[cfg(windows)] mod sys { - // TODO: let's land these upstream with mio and then we can add them here. - // - // use std::os::windows::prelude::*; - // use super::UdpSocket; - // - // impl AsRawHandle for UdpSocket { - // fn as_raw_handle(&self) -> RawHandle { - // self.io.get_ref().as_raw_handle() - // } - // } + use super::UdpSocket; + use std::os::windows::prelude::*; + + impl AsRawSocket for UdpSocket { + fn as_raw_socket(&self) -> RawSocket { + self.io.get_ref().as_raw_socket() + } + } } diff --git a/src/net/udp/split.rs b/src/net/udp/split.rs deleted file mode 100644 index e8d434a..0000000 --- a/src/net/udp/split.rs +++ /dev/null @@ -1,148 +0,0 @@ -//! [`UdpSocket`](crate::net::UdpSocket) split support. -//! -//! The [`split`](method@crate::net::UdpSocket::split) method splits a -//! `UdpSocket` into a receive half and a send half, which can be used to -//! receive and send datagrams concurrently, even from two different tasks. -//! -//! The halves provide access to the underlying socket, implementing -//! `AsRef<UdpSocket>`. This allows you to call `UdpSocket` methods that takes -//! `&self`, e.g., to get local address, to get and set socket options, to join -//! or leave multicast groups, etc. -//! -//! The halves can be reunited to the original socket with their `reunite` -//! methods. - -use crate::future::poll_fn; -use crate::net::udp::UdpSocket; - -use std::error::Error; -use std::fmt; -use std::io; -use std::net::SocketAddr; -use std::sync::Arc; - -/// The send half after [`split`](super::UdpSocket::split). -/// -/// Use [`send_to`](method@Self::send_to) or [`send`](method@Self::send) to send -/// datagrams. -#[derive(Debug)] -pub struct SendHalf(Arc<UdpSocket>); - -/// The recv half after [`split`](super::UdpSocket::split). -/// -/// Use [`recv_from`](method@Self::recv_from) or [`recv`](method@Self::recv) to receive -/// datagrams. -#[derive(Debug)] -pub struct RecvHalf(Arc<UdpSocket>); - -pub(crate) fn split(socket: UdpSocket) -> (RecvHalf, SendHalf) { - let shared = Arc::new(socket); - let send = shared.clone(); - let recv = shared; - (RecvHalf(recv), SendHalf(send)) -} - -/// Error indicating two halves were not from the same socket, and thus could -/// not be `reunite`d. -#[derive(Debug)] -pub struct ReuniteError(pub SendHalf, pub RecvHalf); - -impl fmt::Display for ReuniteError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "tried to reunite halves that are not from the same socket" - ) - } -} - -impl Error for ReuniteError {} - -fn reunite(s: SendHalf, r: RecvHalf) -> Result<UdpSocket, ReuniteError> { - if Arc::ptr_eq(&s.0, &r.0) { - drop(r); - // Only two instances of the `Arc` are ever created, one for the - // receiver and one for the sender, and those `Arc`s are never exposed - // externally. And so when we drop one here, the other one must be the - // only remaining one. - Ok(Arc::try_unwrap(s.0).expect("udp: try_unwrap failed in reunite")) - } else { - Err(ReuniteError(s, r)) - } -} - -impl RecvHalf { - /// Attempts to put the two "halves" of a `UdpSocket` back together and - /// recover the original socket. Succeeds only if the two "halves" - /// originated from the same call to `UdpSocket::split`. - pub fn reunite(self, other: SendHalf) -> Result<UdpSocket, ReuniteError> { - reunite(other, self) - } - - /// Returns a future that receives a single datagram on the socket. On success, - /// the future resolves to the number of bytes read and the origin. - /// - /// The function must be called with valid byte array `buf` of sufficient size - /// to hold the message bytes. If a message is too long to fit in the supplied - /// buffer, excess bytes may be discarded. - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - poll_fn(|cx| self.0.poll_recv_from(cx, buf)).await - } - - /// Returns a future that receives a single datagram message on the socket from - /// the remote address to which it is connected. On success, the future will resolve - /// to the number of bytes read. - /// - /// The function must be called with valid byte array `buf` of sufficient size to - /// hold the message bytes. If a message is too long to fit in the supplied buffer, - /// excess bytes may be discarded. - /// - /// The [`connect`] method will connect this socket to a remote address. The future - /// will fail if the socket is not connected. - /// - /// [`connect`]: super::UdpSocket::connect - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { - poll_fn(|cx| self.0.poll_recv(cx, buf)).await - } -} - -impl SendHalf { - /// Attempts to put the two "halves" of a `UdpSocket` back together and - /// recover the original socket. Succeeds only if the two "halves" - /// originated from the same call to `UdpSocket::split`. - pub fn reunite(self, other: RecvHalf) -> Result<UdpSocket, ReuniteError> { - reunite(self, other) - } - - /// Returns a future that sends data on the socket to the given address. - /// On success, the future will resolve to the number of bytes written. - /// - /// The future will resolve to an error if the IP version of the socket does - /// not match that of `target`. - pub async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> { - poll_fn(|cx| self.0.poll_send_to(cx, buf, target)).await - } - - /// Returns a future that sends data on the socket to the remote address to which it is connected. - /// On success, the future will resolve to the number of bytes written. - /// - /// The [`connect`] method will connect this socket to a remote address. The future - /// will resolve to an error if the socket is not connected. - /// - /// [`connect`]: super::UdpSocket::connect - pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> { - poll_fn(|cx| self.0.poll_send(cx, buf)).await - } -} - -impl AsRef<UdpSocket> for SendHalf { - fn as_ref(&self) -> &UdpSocket { - &self.0 - } -} - -impl AsRef<UdpSocket> for RecvHalf { - fn as_ref(&self) -> &UdpSocket { - &self.0 - } -} diff --git a/src/net/unix/datagram.rs b/src/net/unix/datagram.rs deleted file mode 100644 index ff0f424..0000000 --- a/src/net/unix/datagram.rs +++ /dev/null @@ -1,242 +0,0 @@ -use crate::future::poll_fn; -use crate::io::PollEvented; - -use std::convert::TryFrom; -use std::fmt; -use std::io; -use std::net::Shutdown; -use std::os::unix::io::{AsRawFd, RawFd}; -use std::os::unix::net::{self, SocketAddr}; -use std::path::Path; -use std::task::{Context, Poll}; - -cfg_uds! { - /// An I/O object representing a Unix datagram socket. - pub struct UnixDatagram { - io: PollEvented<mio_uds::UnixDatagram>, - } -} - -impl UnixDatagram { - /// Creates a new `UnixDatagram` bound to the specified path. - pub fn bind<P>(path: P) -> io::Result<UnixDatagram> - where - P: AsRef<Path>, - { - let socket = mio_uds::UnixDatagram::bind(path)?; - UnixDatagram::new(socket) - } - - /// Creates an unnamed pair of connected sockets. - /// - /// This function will create a pair of interconnected Unix sockets for - /// communicating back and forth between one another. Each socket will - /// be associated with the default event loop's handle. - pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> { - let (a, b) = mio_uds::UnixDatagram::pair()?; - let a = UnixDatagram::new(a)?; - let b = UnixDatagram::new(b)?; - - Ok((a, b)) - } - - /// Consumes a `UnixDatagram` in the standard library and returns a - /// nonblocking `UnixDatagram` from this crate. - /// - /// The returned datagram will be associated with the given event loop - /// specified by `handle` and is ready to perform I/O. - /// - /// # Panics - /// - /// This function panics if thread-local runtime is not set. - /// - /// The runtime is usually set implicitly when this function is called - /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn from_std(datagram: net::UnixDatagram) -> io::Result<UnixDatagram> { - let socket = mio_uds::UnixDatagram::from_datagram(datagram)?; - let io = PollEvented::new(socket)?; - Ok(UnixDatagram { io }) - } - - fn new(socket: mio_uds::UnixDatagram) -> io::Result<UnixDatagram> { - let io = PollEvented::new(socket)?; - Ok(UnixDatagram { io }) - } - - /// Creates a new `UnixDatagram` which is not bound to any address. - pub fn unbound() -> io::Result<UnixDatagram> { - let socket = mio_uds::UnixDatagram::unbound()?; - UnixDatagram::new(socket) - } - - /// Connects the socket to the specified address. - /// - /// The `send` method may be used to send data to the specified address. - /// `recv` and `recv_from` will only receive data from that address. - pub fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> { - self.io.get_ref().connect(path) - } - - /// Sends data on the socket to the socket's peer. - pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_send_priv(cx, buf)).await - } - - // Poll IO functions that takes `&self` are provided for the split API. - // - // They are not public because (taken from the doc of `PollEvented`): - // - // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the - // caller must ensure that there are at most two tasks that use a - // `PollEvented` instance concurrently. One for reading and one for writing. - // While violating this requirement is "safe" from a Rust memory model point - // of view, it will result in unexpected behavior in the form of lost - // notifications and tasks hanging. - pub(crate) fn poll_send_priv( - &self, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().send(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending - } - x => Poll::Ready(x), - } - } - - /// Receives data from the socket. - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_recv_priv(cx, buf)).await - } - - pub(crate) fn poll_recv_priv( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().recv(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending - } - x => Poll::Ready(x), - } - } - - /// Sends data on the socket to the specified address. - pub async fn send_to<P>(&mut self, buf: &[u8], target: P) -> io::Result<usize> - where - P: AsRef<Path> + Unpin, - { - poll_fn(|cx| self.poll_send_to_priv(cx, buf, target.as_ref())).await - } - - pub(crate) fn poll_send_to_priv( - &self, - cx: &mut Context<'_>, - buf: &[u8], - target: &Path, - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().send_to(buf, target) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending - } - x => Poll::Ready(x), - } - } - - /// Receives data from the socket. - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - poll_fn(|cx| self.poll_recv_from_priv(cx, buf)).await - } - - pub(crate) fn poll_recv_from_priv( - &self, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<Result<(usize, SocketAddr), io::Error>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().recv_from(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending - } - x => Poll::Ready(x), - } - } - - /// Returns the local address that this socket is bound to. - pub fn local_addr(&self) -> io::Result<SocketAddr> { - self.io.get_ref().local_addr() - } - - /// Returns the address of this socket's peer. - /// - /// The `connect` method will connect the socket to a peer. - pub fn peer_addr(&self) -> io::Result<SocketAddr> { - self.io.get_ref().peer_addr() - } - - /// Returns the value of the `SO_ERROR` option. - pub fn take_error(&self) -> io::Result<Option<io::Error>> { - self.io.get_ref().take_error() - } - - /// Shuts down the read, write, or both halves of this connection. - /// - /// This function will cause all pending and future I/O calls on the - /// specified portions to immediately return with an appropriate value - /// (see the documentation of `Shutdown`). - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.io.get_ref().shutdown(how) - } -} - -impl TryFrom<UnixDatagram> for mio_uds::UnixDatagram { - type Error = io::Error; - - /// Consumes value, returning the mio I/O object. - /// - /// See [`PollEvented::into_inner`] for more details about - /// resource deregistration that happens during the call. - /// - /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner - fn try_from(value: UnixDatagram) -> Result<Self, Self::Error> { - value.io.into_inner() - } -} - -impl TryFrom<net::UnixDatagram> for UnixDatagram { - type Error = io::Error; - - /// Consumes stream, returning the tokio I/O object. - /// - /// This is equivalent to - /// [`UnixDatagram::from_std(stream)`](UnixDatagram::from_std). - fn try_from(stream: net::UnixDatagram) -> Result<Self, Self::Error> { - Self::from_std(stream) - } -} - -impl fmt::Debug for UnixDatagram { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.io.get_ref().fmt(f) - } -} - -impl AsRawFd for UnixDatagram { - fn as_raw_fd(&self) -> RawFd { - self.io.get_ref().as_raw_fd() - } -} diff --git a/src/net/unix/datagram/mod.rs b/src/net/unix/datagram/mod.rs new file mode 100644 index 0000000..6268b4a --- /dev/null +++ b/src/net/unix/datagram/mod.rs @@ -0,0 +1,3 @@ +//! Unix datagram types. + +pub(crate) mod socket; diff --git a/src/net/unix/datagram/socket.rs b/src/net/unix/datagram/socket.rs new file mode 100644 index 0000000..3ae66d1 --- /dev/null +++ b/src/net/unix/datagram/socket.rs @@ -0,0 +1,731 @@ +use crate::io::PollEvented; +use crate::net::unix::SocketAddr; + +use std::convert::TryFrom; +use std::fmt; +use std::io; +use std::net::Shutdown; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net; +use std::path::Path; + +cfg_net_unix! { + /// An I/O object representing a Unix datagram socket. + /// + /// A socket can be either named (associated with a filesystem path) or + /// unnamed. + /// + /// **Note:** named sockets are persisted even after the object is dropped + /// and the program has exited, and cannot be reconnected. It is advised + /// that you either check for and unlink the existing socket if it exists, + /// or use a temporary file that is guaranteed to not already exist. + /// + /// # Examples + /// Using named sockets, associated with a filesystem path: + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind each socket to a filesystem path + /// let tx_path = tmp.path().join("tx"); + /// let tx = UnixDatagram::bind(&tx_path)?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &tx_path); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// Using unnamed sockets, created as a pair + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hello world"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub struct UnixDatagram { + io: PollEvented<mio::net::UnixDatagram>, + } +} + +impl UnixDatagram { + /// Creates a new `UnixDatagram` bound to the specified path. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind the socket to a filesystem path + /// let socket_path = tmp.path().join("socket"); + /// let socket = UnixDatagram::bind(&socket_path)?; + /// + /// # Ok(()) + /// # } + /// ``` + pub fn bind<P>(path: P) -> io::Result<UnixDatagram> + where + P: AsRef<Path>, + { + let socket = mio::net::UnixDatagram::bind(path)?; + UnixDatagram::new(socket) + } + + /// Creates an unnamed pair of connected sockets. + /// + /// This function will create a pair of interconnected Unix sockets for + /// communicating back and forth between one another. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hail eris"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> { + let (a, b) = mio::net::UnixDatagram::pair()?; + let a = UnixDatagram::new(a)?; + let b = UnixDatagram::new(b)?; + + Ok((a, b)) + } + + /// Consumes a `UnixDatagram` in the standard library and returns a + /// nonblocking `UnixDatagram` from this crate. + /// + /// The returned datagram will be associated with the given event loop + /// specified by `handle` and is ready to perform I/O. + /// + /// # Panics + /// + /// This function panics if thread-local runtime is not set. + /// + /// The runtime is usually set implicitly when this function is called + /// from a future driven by a Tokio runtime, otherwise runtime can be set + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use std::os::unix::net::UnixDatagram as StdUDS; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind the socket to a filesystem path + /// let socket_path = tmp.path().join("socket"); + /// let std_socket = StdUDS::bind(&socket_path)?; + /// let tokio_socket = UnixDatagram::from_std(std_socket)?; + /// + /// # Ok(()) + /// # } + /// ``` + pub fn from_std(datagram: net::UnixDatagram) -> io::Result<UnixDatagram> { + let socket = mio::net::UnixDatagram::from_std(datagram); + let io = PollEvented::new(socket)?; + Ok(UnixDatagram { io }) + } + + fn new(socket: mio::net::UnixDatagram) -> io::Result<UnixDatagram> { + let io = PollEvented::new(socket)?; + Ok(UnixDatagram { io }) + } + + /// Creates a new `UnixDatagram` which is not bound to any address. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // Create an unbound socket + /// let tx = UnixDatagram::unbound()?; + /// + /// // Create another, bound socket + /// let tmp = tempdir()?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// // Send to the bound socket + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn unbound() -> io::Result<UnixDatagram> { + let socket = mio::net::UnixDatagram::unbound()?; + UnixDatagram::new(socket) + } + + /// Connects the socket to the specified address. + /// + /// The `send` method may be used to send data to the specified address. + /// `recv` and `recv_from` will only receive data from that address. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // Create an unbound socket + /// let tx = UnixDatagram::unbound()?; + /// + /// // Create another, bound socket + /// let tmp = tempdir()?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// // Connect to the bound socket + /// tx.connect(&rx_path)?; + /// + /// // Send to the bound socket + /// let bytes = b"hello world"; + /// tx.send(bytes).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> { + self.io.get_ref().connect(path) + } + + /// Sends data on the socket to the socket's peer. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hello world"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .async_io(mio::Interest::WRITABLE, |sock| sock.send(buf)) + .await + } + + /// Try to send a datagram to the peer without waiting. + /// + /// # Examples + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// use tokio::net::UnixDatagram; + /// + /// let bytes = b"bytes"; + /// // We use a socket pair so that they are assigned + /// // each other as a peer. + /// let (first, second) = UnixDatagram::pair()?; + /// + /// let size = first.try_send(bytes)?; + /// assert_eq!(size, bytes.len()); + /// + /// let mut buffer = vec![0u8; 24]; + /// let size = second.try_recv(&mut buffer)?; + /// + /// let dgram = &buffer[..size]; + /// assert_eq!(dgram, bytes); + /// # Ok(()) + /// # } + /// ``` + pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> { + self.io.get_ref().send(buf) + } + + /// Try to send a datagram to the peer without waiting. + /// + /// # Examples + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// let bytes = b"bytes"; + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir().unwrap(); + /// + /// let server_path = tmp.path().join("server"); + /// let server = UnixDatagram::bind(&server_path)?; + /// + /// let client_path = tmp.path().join("client"); + /// let client = UnixDatagram::bind(&client_path)?; + /// + /// let size = client.try_send_to(bytes, &server_path)?; + /// assert_eq!(size, bytes.len()); + /// + /// let mut buffer = vec![0u8; 24]; + /// let (size, addr) = server.try_recv_from(&mut buffer)?; + /// + /// let dgram = &buffer[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &client_path); + /// # Ok(()) + /// # } + /// ``` + pub fn try_send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize> + where + P: AsRef<Path>, + { + self.io.get_ref().send_to(buf, target) + } + + /// Receives data from the socket. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// // Since the sockets are paired, the paired send/recv + /// // functions can be used + /// let bytes = b"hello world"; + /// sock1.send(bytes).await?; + /// + /// let mut buff = vec![0u8; 24]; + /// let size = sock2.recv(&mut buff).await?; + /// + /// let dgram = &buff[..size]; + /// assert_eq!(dgram, bytes); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .async_io(mio::Interest::READABLE, |sock| sock.recv(buf)) + .await + } + + /// Try to receive a datagram from the peer without waiting. + /// + /// # Examples + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// use tokio::net::UnixDatagram; + /// + /// let bytes = b"bytes"; + /// // We use a socket pair so that they are assigned + /// // each other as a peer. + /// let (first, second) = UnixDatagram::pair()?; + /// + /// let size = first.try_send(bytes)?; + /// assert_eq!(size, bytes.len()); + /// + /// let mut buffer = vec![0u8; 24]; + /// let size = second.try_recv(&mut buffer)?; + /// + /// let dgram = &buffer[..size]; + /// assert_eq!(dgram, bytes); + /// # Ok(()) + /// # } + /// ``` + pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io.get_ref().recv(buf) + } + + /// Sends data on the socket to the specified address. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind each socket to a filesystem path + /// let tx_path = tmp.path().join("tx"); + /// let tx = UnixDatagram::bind(&tx_path)?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &tx_path); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize> + where + P: AsRef<Path>, + { + self.io + .async_io(mio::Interest::WRITABLE, |sock| { + sock.send_to(buf, target.as_ref()) + }) + .await + } + + /// Receives data from the socket. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind each socket to a filesystem path + /// let tx_path = tmp.path().join("tx"); + /// let tx = UnixDatagram::bind(&tx_path)?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// let bytes = b"hello world"; + /// tx.send_to(bytes, &rx_path).await?; + /// + /// let mut buf = vec![0u8; 24]; + /// let (size, addr) = rx.recv_from(&mut buf).await?; + /// + /// let dgram = &buf[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &tx_path); + /// + /// # Ok(()) + /// # } + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let (n, addr) = self + .io + .async_io(mio::Interest::READABLE, |sock| sock.recv_from(buf)) + .await?; + + Ok((n, SocketAddr(addr))) + } + + /// Try to receive data from the socket without waiting. + /// + /// # Examples + /// ``` + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// let bytes = b"bytes"; + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir().unwrap(); + /// + /// let server_path = tmp.path().join("server"); + /// let server = UnixDatagram::bind(&server_path)?; + /// + /// let client_path = tmp.path().join("client"); + /// let client = UnixDatagram::bind(&client_path)?; + /// + /// let size = client.try_send_to(bytes, &server_path)?; + /// assert_eq!(size, bytes.len()); + /// + /// let mut buffer = vec![0u8; 24]; + /// let (size, addr) = server.try_recv_from(&mut buffer)?; + /// + /// let dgram = &buffer[..size]; + /// assert_eq!(dgram, bytes); + /// assert_eq!(addr.as_pathname().unwrap(), &client_path); + /// # Ok(()) + /// # } + /// ``` + pub fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + let (n, addr) = self.io.get_ref().recv_from(buf)?; + Ok((n, SocketAddr(addr))) + } + + /// Returns the local address that this socket is bound to. + /// + /// # Examples + /// For a socket bound to a local path + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // We use a temporary directory so that the socket + /// // files left by the bound sockets will get cleaned up. + /// let tmp = tempdir()?; + /// + /// // Bind socket to a filesystem path + /// let socket_path = tmp.path().join("socket"); + /// let socket = UnixDatagram::bind(&socket_path)?; + /// + /// assert_eq!(socket.local_addr()?.as_pathname().unwrap(), &socket_path); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// For an unbound socket + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create an unbound socket + /// let socket = UnixDatagram::unbound()?; + /// + /// assert!(socket.local_addr()?.is_unnamed()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.io.get_ref().local_addr().map(SocketAddr) + } + + /// Returns the address of this socket's peer. + /// + /// The `connect` method will connect the socket to a peer. + /// + /// # Examples + /// For a peer with a local path + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use tempfile::tempdir; + /// + /// // Create an unbound socket + /// let tx = UnixDatagram::unbound()?; + /// + /// // Create another, bound socket + /// let tmp = tempdir()?; + /// let rx_path = tmp.path().join("rx"); + /// let rx = UnixDatagram::bind(&rx_path)?; + /// + /// // Connect to the bound socket + /// tx.connect(&rx_path)?; + /// + /// assert_eq!(tx.peer_addr()?.as_pathname().unwrap(), &rx_path); + /// + /// # Ok(()) + /// # } + /// ``` + /// + /// For an unbound peer + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create the pair of sockets + /// let (sock1, sock2) = UnixDatagram::pair()?; + /// + /// assert!(sock1.peer_addr()?.is_unnamed()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.io.get_ref().peer_addr().map(SocketAddr) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// + /// // Create an unbound socket + /// let socket = UnixDatagram::unbound()?; + /// + /// if let Ok(Some(err)) = socket.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// + /// # Ok(()) + /// # } + /// ``` + pub fn take_error(&self) -> io::Result<Option<io::Error>> { + self.io.get_ref().take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + /// + /// # Examples + /// ``` + /// # use std::error::Error; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box<dyn Error>> { + /// use tokio::net::UnixDatagram; + /// use std::net::Shutdown; + /// + /// // Create an unbound socket + /// let (socket, other) = UnixDatagram::pair()?; + /// + /// socket.shutdown(Shutdown::Both)?; + /// + /// // NOTE: the following commented out code does NOT work as expected. + /// // Due to an underlying issue, the recv call will block indefinitely. + /// // See: https://github.com/tokio-rs/tokio/issues/1679 + /// //let mut buff = vec![0u8; 24]; + /// //let size = socket.recv(&mut buff).await?; + /// //assert_eq!(size, 0); + /// + /// let send_result = socket.send(b"hello world").await; + /// assert!(send_result.is_err()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.io.get_ref().shutdown(how) + } +} + +impl TryFrom<std::os::unix::net::UnixDatagram> for UnixDatagram { + type Error = io::Error; + + /// Consumes stream, returning the Tokio I/O object. + /// + /// This is equivalent to + /// [`UnixDatagram::from_std(stream)`](UnixDatagram::from_std). + fn try_from(stream: std::os::unix::net::UnixDatagram) -> Result<Self, Self::Error> { + Self::from_std(stream) + } +} + +impl fmt::Debug for UnixDatagram { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.io.get_ref().fmt(f) + } +} + +impl AsRawFd for UnixDatagram { + fn as_raw_fd(&self) -> RawFd { + self.io.get_ref().as_raw_fd() + } +} diff --git a/src/net/unix/incoming.rs b/src/net/unix/incoming.rs deleted file mode 100644 index af49360..0000000 --- a/src/net/unix/incoming.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::net::unix::{UnixListener, UnixStream}; - -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Stream of listeners -#[derive(Debug)] -#[must_use = "streams do nothing unless polled"] -pub struct Incoming<'a> { - inner: &'a mut UnixListener, -} - -impl Incoming<'_> { - pub(crate) fn new(listener: &mut UnixListener) -> Incoming<'_> { - Incoming { inner: listener } - } - - /// Attempts to poll `UnixStream` by polling inner `UnixListener` to accept - /// connection. - /// - /// If `UnixListener` isn't ready yet, `Poll::Pending` is returned and - /// current task will be notified by a waker. Otherwise `Poll::Ready` with - /// `Result` containing `UnixStream` will be returned. - pub fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<io::Result<UnixStream>> { - let (socket, _) = ready!(self.inner.poll_accept(cx))?; - Poll::Ready(Ok(socket)) - } -} - -#[cfg(feature = "stream")] -impl crate::stream::Stream for Incoming<'_> { - type Item = io::Result<UnixStream>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let (socket, _) = ready!(self.inner.poll_accept(cx))?; - Poll::Ready(Some(Ok(socket))) - } -} diff --git a/src/net/unix/listener.rs b/src/net/unix/listener.rs index 9b76cb0..b272645 100644 --- a/src/net/unix/listener.rs +++ b/src/net/unix/listener.rs @@ -1,17 +1,15 @@ -use crate::future::poll_fn; use crate::io::PollEvented; -use crate::net::unix::{Incoming, UnixStream}; +use crate::net::unix::{SocketAddr, UnixStream}; -use mio::Ready; use std::convert::TryFrom; use std::fmt; use std::io; use std::os::unix::io::{AsRawFd, RawFd}; -use std::os::unix::net::{self, SocketAddr}; +use std::os::unix::net; use std::path::Path; use std::task::{Context, Poll}; -cfg_uds! { +cfg_net_unix! { /// A Unix socket which can accept connections from other Unix sockets. /// /// You can accept a new connection by using the [`accept`](`UnixListener::accept`) method. Alternatively `UnixListener` @@ -47,7 +45,7 @@ cfg_uds! { /// } /// ``` pub struct UnixListener { - io: PollEvented<mio_uds::UnixListener>, + io: PollEvented<mio::net::UnixListener>, } } @@ -60,12 +58,12 @@ impl UnixListener { /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. pub fn bind<P>(path: P) -> io::Result<UnixListener> where P: AsRef<Path>, { - let listener = mio_uds::UnixListener::bind(path)?; + let listener = mio::net::UnixListener::bind(path)?; let io = PollEvented::new(listener)?; Ok(UnixListener { io }) } @@ -82,16 +80,16 @@ impl UnixListener { /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. pub fn from_std(listener: net::UnixListener) -> io::Result<UnixListener> { - let listener = mio_uds::UnixListener::from_listener(listener)?; + let listener = mio::net::UnixListener::from_std(listener); let io = PollEvented::new(listener)?; Ok(UnixListener { io }) } /// Returns the local socket address of this listener. pub fn local_addr(&self) -> io::Result<SocketAddr> { - self.io.get_ref().local_addr() + self.io.get_ref().local_addr().map(SocketAddr) } /// Returns the value of the `SO_ERROR` option. @@ -100,117 +98,62 @@ impl UnixListener { } /// Accepts a new incoming connection to this listener. - pub async fn accept(&mut self) -> io::Result<(UnixStream, SocketAddr)> { - poll_fn(|cx| self.poll_accept(cx)).await + pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let (mio, addr) = self + .io + .async_io(mio::Interest::READABLE, |sock| sock.accept()) + .await?; + + let addr = SocketAddr(addr); + let stream = UnixStream::new(mio)?; + Ok((stream, addr)) } - pub(crate) fn poll_accept( - &mut self, - cx: &mut Context<'_>, - ) -> Poll<io::Result<(UnixStream, SocketAddr)>> { - let (io, addr) = ready!(self.poll_accept_std(cx))?; - - let io = mio_uds::UnixStream::from_stream(io)?; - Ok((UnixStream::new(io)?, addr)).into() - } - - fn poll_accept_std( - &mut self, - cx: &mut Context<'_>, - ) -> Poll<io::Result<(net::UnixStream, SocketAddr)>> { - ready!(self.io.poll_read_ready(cx, Ready::readable()))?; - - match self.io.get_ref().accept_std() { - Ok(None) => { - self.io.clear_read_ready(cx, Ready::readable())?; - Poll::Pending - } - Ok(Some((sock, addr))) => Ok((sock, addr)).into(), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, Ready::readable())?; - Poll::Pending + /// Polls to accept a new incoming connection to this listener. + /// + /// If there is no connection to accept, `Poll::Pending` is returned and + /// the current task will be notified by a waker. + /// + /// When ready, the most recent task that called `poll_accept` is notified. + /// The caller is responsble to ensure that `poll_accept` is called from a + /// single task. Failing to do this could result in tasks hanging. + pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UnixStream, SocketAddr)>> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + match self.io.get_ref().accept() { + Ok((sock, addr)) => { + let addr = SocketAddr(addr); + let sock = UnixStream::new(sock)?; + return Poll::Ready(Ok((sock, addr))); + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Err(err) => return Err(err).into(), } - Err(err) => Err(err).into(), } } - - /// Returns a stream over the connections being received on this listener. - /// - /// Note that `UnixListener` also directly implements `Stream`. - /// - /// The returned stream will never return `None` and will also not yield the - /// peer's `SocketAddr` structure. Iterating over it is equivalent to - /// calling accept in a loop. - /// - /// # Errors - /// - /// Note that accepting a connection can lead to various errors and not all - /// of them are necessarily fatal ‒ for example having too many open file - /// descriptors or the other side closing the connection while it waits in - /// an accept queue. These would terminate the stream if not handled in any - /// way. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::UnixListener; - /// use tokio::stream::StreamExt; - /// - /// #[tokio::main] - /// async fn main() { - /// let mut listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// let mut incoming = listener.incoming(); - /// - /// while let Some(stream) = incoming.next().await { - /// match stream { - /// Ok(stream) => { - /// println!("new client!"); - /// } - /// Err(e) => { /* connection failed */ } - /// } - /// } - /// } - /// ``` - pub fn incoming(&mut self) -> Incoming<'_> { - Incoming::new(self) - } } #[cfg(feature = "stream")] impl crate::stream::Stream for UnixListener { type Item = io::Result<UnixStream>; - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Option<Self::Item>> { + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let (socket, _) = ready!(self.poll_accept(cx))?; Poll::Ready(Some(Ok(socket))) } } -impl TryFrom<UnixListener> for mio_uds::UnixListener { - type Error = io::Error; - - /// Consumes value, returning the mio I/O object. - /// - /// See [`PollEvented::into_inner`] for more details about - /// resource deregistration that happens during the call. - /// - /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner - fn try_from(value: UnixListener) -> Result<Self, Self::Error> { - value.io.into_inner() - } -} - -impl TryFrom<net::UnixListener> for UnixListener { +impl TryFrom<std::os::unix::net::UnixListener> for UnixListener { type Error = io::Error; /// Consumes stream, returning the tokio I/O object. /// /// This is equivalent to /// [`UnixListener::from_std(stream)`](UnixListener::from_std). - fn try_from(stream: net::UnixListener) -> io::Result<Self> { + fn try_from(stream: std::os::unix::net::UnixListener) -> io::Result<Self> { Self::from_std(stream) } } diff --git a/src/net/unix/mod.rs b/src/net/unix/mod.rs index ddba60d..19ee34a 100644 --- a/src/net/unix/mod.rs +++ b/src/net/unix/mod.rs @@ -1,16 +1,18 @@ //! Unix domain socket utility types -pub(crate) mod datagram; - -mod incoming; -pub use incoming::Incoming; +pub mod datagram; pub(crate) mod listener; -pub(crate) use listener::UnixListener; mod split; pub use split::{ReadHalf, WriteHalf}; +mod split_owned; +pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError}; + +mod socketaddr; +pub use socketaddr::SocketAddr; + pub(crate) mod stream; pub(crate) use stream::UnixStream; diff --git a/src/net/unix/socketaddr.rs b/src/net/unix/socketaddr.rs new file mode 100644 index 0000000..48f7b96 --- /dev/null +++ b/src/net/unix/socketaddr.rs @@ -0,0 +1,31 @@ +use std::fmt; +use std::path::Path; + +/// An address associated with a Tokio Unix socket. +pub struct SocketAddr(pub(super) mio::net::SocketAddr); + +impl SocketAddr { + /// Returns `true` if the address is unnamed. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn is_unnamed(&self) -> bool { + self.0.is_unnamed() + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn as_pathname(&self) -> Option<&Path> { + self.0.as_pathname() + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(fmt) + } +} diff --git a/src/net/unix/split.rs b/src/net/unix/split.rs index 9b9fa5e..460bbc1 100644 --- a/src/net/unix/split.rs +++ b/src/net/unix/split.rs @@ -8,20 +8,40 @@ //! split has no associated overhead and enforces all invariants at the type //! level. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::net::UnixStream; use std::io; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; -/// Read half of a `UnixStream`. +/// Borrowed read half of a [`UnixStream`], created by [`split`]. +/// +/// Reading from a `ReadHalf` is usually done using the convenience methods found on the +/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`UnixStream`]: UnixStream +/// [`split`]: UnixStream::split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [the prelude]: crate::prelude #[derive(Debug)] pub struct ReadHalf<'a>(&'a UnixStream); -/// Write half of a `UnixStream`. +/// Borrowed write half of a [`UnixStream`], created by [`split`]. +/// +/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// shut down the UnixStream stream in the write direction. +/// +/// Writing to an `WriteHalf` is usually done using the convenience methods found +/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`UnixStream`]: UnixStream +/// [`split`]: UnixStream::split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude #[derive(Debug)] pub struct WriteHalf<'a>(&'a UnixStream); @@ -30,15 +50,11 @@ pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) { } impl AsyncRead for ReadHalf<'_> { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.0.poll_read_priv(cx, buf) } } diff --git a/src/net/unix/split_owned.rs b/src/net/unix/split_owned.rs new file mode 100644 index 0000000..ab23307 --- /dev/null +++ b/src/net/unix/split_owned.rs @@ -0,0 +1,182 @@ +//! `UnixStream` owned split support. +//! +//! A `UnixStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf` +//! with the `UnixStream::into_split` method. `OwnedReadHalf` implements +//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`. +//! +//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized +//! split has no associated overhead and enforces all invariants at the type +//! level. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::net::UnixStream; + +use std::error::Error; +use std::net::Shutdown; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; + +/// Owned read half of a [`UnixStream`], created by [`into_split`]. +/// +/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found +/// on the [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// +/// [`UnixStream`]: crate::net::UnixStream +/// [`into_split`]: crate::net::UnixStream::into_split() +/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt +/// [the prelude]: crate::prelude +#[derive(Debug)] +pub struct OwnedReadHalf { + inner: Arc<UnixStream>, +} + +/// Owned write half of a [`UnixStream`], created by [`into_split`]. +/// +/// Note that in the [`AsyncWrite`] implementation of this type, +/// [`poll_shutdown`] will shut down the stream in the write direction. +/// Dropping the write half will also shut down the write half of the stream. +/// +/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods +/// found on the [`AsyncWriteExt`] trait. Examples import this trait through +/// [the prelude]. +/// +/// [`UnixStream`]: crate::net::UnixStream +/// [`into_split`]: crate::net::UnixStream::into_split() +/// [`AsyncWrite`]: trait@crate::io::AsyncWrite +/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown +/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt +/// [the prelude]: crate::prelude +#[derive(Debug)] +pub struct OwnedWriteHalf { + inner: Arc<UnixStream>, + shutdown_on_drop: bool, +} + +pub(crate) fn split_owned(stream: UnixStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + shutdown_on_drop: true, + }; + (read, write) +} + +pub(crate) fn reunite( + read: OwnedReadHalf, + write: OwnedWriteHalf, +) -> Result<UnixStream, ReuniteError> { + if Arc::ptr_eq(&read.inner, &write.inner) { + write.forget(); + // This unwrap cannot fail as the api does not allow creating more than two Arcs, + // and we just dropped the other half. + Ok(Arc::try_unwrap(read.inner).expect("UnixStream: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(read, write)) + } +} + +/// Error indicating that two halves were not from the same socket, and thus could +/// not be reunited. +#[derive(Debug)] +pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +impl OwnedReadHalf { + /// Attempts to put the two halves of a `UnixStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: crate::net::UnixStream::into_split() + pub fn reunite(self, other: OwnedWriteHalf) -> Result<UnixStream, ReuniteError> { + reunite(self, other) + } +} + +impl AsyncRead for OwnedReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + self.inner.poll_read_priv(cx, buf) + } +} + +impl OwnedWriteHalf { + /// Attempts to put the two halves of a `UnixStream` back together and + /// recover the original socket. Succeeds only if the two halves + /// originated from the same call to [`into_split`]. + /// + /// [`into_split`]: crate::net::UnixStream::into_split() + pub fn reunite(self, other: OwnedReadHalf) -> Result<UnixStream, ReuniteError> { + reunite(other, self) + } + + /// Destroy the write half, but don't close the write half of the stream + /// until the read half is dropped. If the read half has already been + /// dropped, this closes the stream. + pub fn forget(mut self) { + self.shutdown_on_drop = false; + drop(self); + } +} + +impl Drop for OwnedWriteHalf { + fn drop(&mut self) { + if self.shutdown_on_drop { + let _ = self.inner.shutdown(Shutdown::Write); + } + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.inner.poll_write_priv(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + // flush is a no-op + Poll::Ready(Ok(())) + } + + // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { + let res = self.inner.shutdown(Shutdown::Write); + if res.is_ok() { + Pin::into_inner(self).shutdown_on_drop = false; + } + res.into() + } +} + +impl AsRef<UnixStream> for OwnedReadHalf { + fn as_ref(&self) -> &UnixStream { + &*self.inner + } +} + +impl AsRef<UnixStream> for OwnedWriteHalf { + fn as_ref(&self) -> &UnixStream { + &*self.inner + } +} diff --git a/src/net/unix/stream.rs b/src/net/unix/stream.rs index beae699..5138077 100644 --- a/src/net/unix/stream.rs +++ b/src/net/unix/stream.rs @@ -1,27 +1,28 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented}; +use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; use crate::net::unix::split::{split, ReadHalf, WriteHalf}; +use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::unix::ucred::{self, UCred}; +use crate::net::unix::SocketAddr; use std::convert::TryFrom; use std::fmt; use std::io::{self, Read, Write}; -use std::mem::MaybeUninit; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, RawFd}; -use std::os::unix::net::{self, SocketAddr}; +use std::os::unix::net; use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_uds! { +cfg_net_unix! { /// A structure representing a connected Unix socket. /// /// This socket can be connected directly with `UnixStream::connect` or accepted /// from a listener with `UnixListener::incoming`. Additionally, a pair of /// anonymous Unix sockets can be created with `UnixStream::pair`. pub struct UnixStream { - io: PollEvented<mio_uds::UnixStream>, + io: PollEvented<mio::net::UnixStream>, } } @@ -35,7 +36,7 @@ impl UnixStream { where P: AsRef<Path>, { - let stream = mio_uds::UnixStream::connect(path)?; + let stream = mio::net::UnixStream::connect(path)?; let stream = UnixStream::new(stream)?; poll_fn(|cx| stream.io.poll_write_ready(cx)).await?; @@ -54,9 +55,9 @@ impl UnixStream { /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. pub fn from_std(stream: net::UnixStream) -> io::Result<UnixStream> { - let stream = mio_uds::UnixStream::from_stream(stream)?; + let stream = mio::net::UnixStream::from_std(stream); let io = PollEvented::new(stream)?; Ok(UnixStream { io }) @@ -68,26 +69,26 @@ impl UnixStream { /// communicating back and forth between one another. Each socket will /// be associated with the default event loop's handle. pub fn pair() -> io::Result<(UnixStream, UnixStream)> { - let (a, b) = mio_uds::UnixStream::pair()?; + let (a, b) = mio::net::UnixStream::pair()?; let a = UnixStream::new(a)?; let b = UnixStream::new(b)?; Ok((a, b)) } - pub(crate) fn new(stream: mio_uds::UnixStream) -> io::Result<UnixStream> { + pub(crate) fn new(stream: mio::net::UnixStream) -> io::Result<UnixStream> { let io = PollEvented::new(stream)?; Ok(UnixStream { io }) } /// Returns the socket address of the local half of this connection. pub fn local_addr(&self) -> io::Result<SocketAddr> { - self.io.get_ref().local_addr() + self.io.get_ref().local_addr().map(SocketAddr) } /// Returns the socket address of the remote half of this connection. pub fn peer_addr(&self) -> io::Result<SocketAddr> { - self.io.get_ref().peer_addr() + self.io.get_ref().peer_addr().map(SocketAddr) } /// Returns effective credentials of the process which called `connect` or `pair`. @@ -109,24 +110,33 @@ impl UnixStream { self.io.get_ref().shutdown(how) } + // These lifetime markers also appear in the generated documentation, and make + // it more clear that this is a *borrowed* split. + #[allow(clippy::needless_lifetimes)] /// Split a `UnixStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. - pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) { + /// + /// This method is more efficient than [`into_split`], but the halves cannot be + /// moved into independently spawned tasks. + /// + /// [`into_split`]: Self::into_split() + pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) { split(self) } -} -impl TryFrom<UnixStream> for mio_uds::UnixStream { - type Error = io::Error; - - /// Consumes value, returning the mio I/O object. + /// Splits a `UnixStream` into a read half and a write half, which can be used + /// to read and write the stream concurrently. + /// + /// Unlike [`split`], the owned halves can be moved to separate tasks, however + /// this comes at the cost of a heap allocation. /// - /// See [`PollEvented::into_inner`] for more details about - /// resource deregistration that happens during the call. + /// **Note:** Dropping the write half will shut down the write half of the + /// stream. This is equivalent to calling [`shutdown(Write)`] on the `UnixStream`. /// - /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner - fn try_from(value: UnixStream) -> Result<Self, Self::Error> { - value.io.into_inner() + /// [`split`]: Self::split() + /// [`shutdown(Write)`]: fn@Self::shutdown + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + split_owned(self) } } @@ -143,15 +153,11 @@ impl TryFrom<net::UnixStream> for UnixStream { } impl AsyncRead for UnixStream { - unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.poll_read_priv(cx, buf) } } @@ -190,16 +196,30 @@ impl UnixStream { pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; - - match self.io.get_ref().read(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_read_ready(cx, mio::Ready::readable())?; - Poll::Pending + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + loop { + let ev = ready!(self.io.poll_read_ready(cx))?; + + // Safety: `UnixStream::read` will not peek at the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + match self.io.get_ref().read(b) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + Ok(n) => { + // Safety: We trust `UnixStream::read` to have filled up `n` bytes + // in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + return Poll::Ready(Ok(())); + } + Err(e) => return Poll::Ready(Err(e)), } - x => Poll::Ready(x), } } @@ -208,14 +228,15 @@ impl UnixStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - ready!(self.io.poll_write_ready(cx))?; - - match self.io.get_ref().write(buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - self.io.clear_write_ready(cx)?; - Poll::Pending + loop { + let ev = ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_readiness(ev); + } + x => return Poll::Ready(x), } - x => Poll::Ready(x), } } } diff --git a/src/net/unix/ucred.rs b/src/net/unix/ucred.rs index 466aedc..ef214a7 100644 --- a/src/net/unix/ucred.rs +++ b/src/net/unix/ucred.rs @@ -4,9 +4,21 @@ use libc::{gid_t, uid_t}; #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub struct UCred { /// UID (user ID) of the process - pub uid: uid_t, + uid: uid_t, /// GID (group ID) of the process - pub gid: gid_t, + gid: gid_t, +} + +impl UCred { + /// Gets UID (user ID) of the process. + pub fn uid(&self) -> uid_t { + self.uid + } + + /// Gets GID (group ID) of the process. + pub fn gid(&self) -> gid_t { + self.gid + } } #[cfg(any(target_os = "linux", target_os = "android"))] diff --git a/src/park/either.rs b/src/park/either.rs index 67f1e17..ee02ec1 100644 --- a/src/park/either.rs +++ b/src/park/either.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + use crate::park::{Park, Unpark}; use std::fmt; @@ -36,6 +38,13 @@ where Either::B(b) => b.park_timeout(duration).map_err(Either::B), } } + + fn shutdown(&mut self) { + match self { + Either::A(a) => a.shutdown(), + Either::B(b) => b.shutdown(), + } + } } impl<A, B> Unpark for Either<A, B> diff --git a/src/park/mod.rs b/src/park/mod.rs index 04d3051..5db26ce 100644 --- a/src/park/mod.rs +++ b/src/park/mod.rs @@ -34,17 +34,12 @@ //! * `park_timeout` does the same as `park` but allows specifying a maximum //! time to block the thread for. -cfg_resource_drivers! { - mod either; - pub(crate) use self::either::Either; +cfg_rt! { + pub(crate) mod either; } -mod thread; -pub(crate) use self::thread::ParkThread; - -cfg_block_on! { - pub(crate) use self::thread::{CachedParkThread, ParkError}; -} +#[cfg(any(feature = "rt", feature = "sync"))] +pub(crate) mod thread; use std::sync::Arc; use std::time::Duration; @@ -88,6 +83,9 @@ pub(crate) trait Park { /// an implementation detail. Refer to the documentation for the specific /// `Park` implementation fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>; + + /// Release all resources holded by the parker for proper leak-free shutdown + fn shutdown(&mut self); } /// Unblock a thread blocked by the associated `Park` instance. diff --git a/src/park/thread.rs b/src/park/thread.rs index 2e2397c..2725e45 100644 --- a/src/park/thread.rs +++ b/src/park/thread.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Arc, Condvar, Mutex}; use crate::park::{Park, Unpark}; @@ -65,6 +67,10 @@ impl Park for ParkThread { self.inner.park_timeout(duration); Ok(()) } + + fn shutdown(&mut self) { + self.inner.shutdown(); + } } // ==== impl Inner ==== @@ -83,7 +89,7 @@ impl Inner { } // Otherwise we need to coordinate going to sleep - let mut m = self.mutex.lock().unwrap(); + let mut m = self.mutex.lock(); match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { Ok(_) => {} @@ -133,7 +139,7 @@ impl Inner { return; } - let m = self.mutex.lock().unwrap(); + let m = self.mutex.lock(); match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) { Ok(_) => {} @@ -184,10 +190,14 @@ impl Inner { // Releasing `lock` before the call to `notify_one` means that when the // parked thread wakes it doesn't get woken only to have to wait for us // to release `lock`. - drop(self.mutex.lock().unwrap()); + drop(self.mutex.lock()); self.condvar.notify_one() } + + fn shutdown(&self) { + self.condvar.notify_all(); + } } impl Default for ParkThread { @@ -204,114 +214,133 @@ impl Unpark for UnparkThread { } } -cfg_block_on! { - use std::marker::PhantomData; - use std::rc::Rc; - - use std::mem; - use std::task::{RawWaker, RawWakerVTable, Waker}; +use std::future::Future; +use std::marker::PhantomData; +use std::mem; +use std::rc::Rc; +use std::task::{RawWaker, RawWakerVTable, Waker}; - /// Blocks the current thread using a condition variable. - #[derive(Debug)] - pub(crate) struct CachedParkThread { - _anchor: PhantomData<Rc<()>>, - } - - impl CachedParkThread { - /// Create a new `ParkThread` handle for the current thread. - /// - /// This type cannot be moved to other threads, so it should be created on - /// the thread that the caller intends to park. - pub(crate) fn new() -> CachedParkThread { - CachedParkThread { - _anchor: PhantomData, - } - } - - pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> { - self.with_current(|park_thread| park_thread.unpark()) - } +/// Blocks the current thread using a condition variable. +#[derive(Debug)] +pub(crate) struct CachedParkThread { + _anchor: PhantomData<Rc<()>>, +} - /// Get a reference to the `ParkThread` handle for this thread. - fn with_current<F, R>(&self, f: F) -> Result<R, ParkError> - where - F: FnOnce(&ParkThread) -> R, - { - CURRENT_PARKER.try_with(|inner| f(inner)) - .map_err(|_| ()) +impl CachedParkThread { + /// Create a new `ParkThread` handle for the current thread. + /// + /// This type cannot be moved to other threads, so it should be created on + /// the thread that the caller intends to park. + pub(crate) fn new() -> CachedParkThread { + CachedParkThread { + _anchor: PhantomData, } } - impl Park for CachedParkThread { - type Unpark = UnparkThread; - type Error = ParkError; + pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> { + self.with_current(|park_thread| park_thread.unpark()) + } - fn unpark(&self) -> Self::Unpark { - self.get_unpark().unwrap() - } + /// Get a reference to the `ParkThread` handle for this thread. + fn with_current<F, R>(&self, f: F) -> Result<R, ParkError> + where + F: FnOnce(&ParkThread) -> R, + { + CURRENT_PARKER.try_with(|inner| f(inner)).map_err(|_| ()) + } - fn park(&mut self) -> Result<(), Self::Error> { - self.with_current(|park_thread| park_thread.inner.park())?; - Ok(()) - } + pub(crate) fn block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError> { + use std::task::Context; + use std::task::Poll::Ready; - fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { - self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?; - Ok(()) - } - } + // `get_unpark()` should not return a Result + let waker = self.get_unpark()?.into_waker(); + let mut cx = Context::from_waker(&waker); + pin!(f); - impl UnparkThread { - pub(crate) fn into_waker(self) -> Waker { - unsafe { - let raw = unparker_to_raw_waker(self.inner); - Waker::from_raw(raw) + loop { + if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { + return Ok(v); } + + self.park()?; } } +} - impl Inner { - #[allow(clippy::wrong_self_convention)] - fn into_raw(this: Arc<Inner>) -> *const () { - Arc::into_raw(this) as *const () - } +impl Park for CachedParkThread { + type Unpark = UnparkThread; + type Error = ParkError; - unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> { - Arc::from_raw(ptr as *const Inner) - } + fn unpark(&self) -> Self::Unpark { + self.get_unpark().unwrap() } - unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker { - RawWaker::new( - Inner::into_raw(unparker), - &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker), - ) + fn park(&mut self) -> Result<(), Self::Error> { + self.with_current(|park_thread| park_thread.inner.park())?; + Ok(()) } - unsafe fn clone(raw: *const ()) -> RawWaker { - let unparker = Inner::from_raw(raw); + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?; + Ok(()) + } - // Increment the ref count - mem::forget(unparker.clone()); + fn shutdown(&mut self) { + let _ = self.with_current(|park_thread| park_thread.inner.shutdown()); + } +} - unparker_to_raw_waker(unparker) +impl UnparkThread { + pub(crate) fn into_waker(self) -> Waker { + unsafe { + let raw = unparker_to_raw_waker(self.inner); + Waker::from_raw(raw) + } } +} - unsafe fn drop_waker(raw: *const ()) { - let _ = Inner::from_raw(raw); +impl Inner { + #[allow(clippy::wrong_self_convention)] + fn into_raw(this: Arc<Inner>) -> *const () { + Arc::into_raw(this) as *const () } - unsafe fn wake(raw: *const ()) { - let unparker = Inner::from_raw(raw); - unparker.unpark(); + unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> { + Arc::from_raw(ptr as *const Inner) } +} - unsafe fn wake_by_ref(raw: *const ()) { - let unparker = Inner::from_raw(raw); - unparker.unpark(); +unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker { + RawWaker::new( + Inner::into_raw(unparker), + &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker), + ) +} - // We don't actually own a reference to the unparker - mem::forget(unparker); - } +unsafe fn clone(raw: *const ()) -> RawWaker { + let unparker = Inner::from_raw(raw); + + // Increment the ref count + mem::forget(unparker.clone()); + + unparker_to_raw_waker(unparker) +} + +unsafe fn drop_waker(raw: *const ()) { + let _ = Inner::from_raw(raw); +} + +unsafe fn wake(raw: *const ()) { + let unparker = Inner::from_raw(raw); + unparker.unpark(); +} + +unsafe fn wake_by_ref(raw: *const ()) { + let unparker = Inner::from_raw(raw); + unparker.unpark(); + + // We don't actually own a reference to the unparker + mem::forget(unparker); } diff --git a/src/process/mod.rs b/src/process/mod.rs index e04a435..ad64371 100644 --- a/src/process/mod.rs +++ b/src/process/mod.rs @@ -18,16 +18,15 @@ //! //! #[tokio::main] //! async fn main() -> Result<(), Box<dyn std::error::Error>> { -//! // The usage is the same as with the standard library's `Command` type, however the value -//! // returned from `spawn` is a `Result` containing a `Future`. -//! let child = Command::new("echo").arg("hello").arg("world") -//! .spawn(); +//! // The usage is similar as with the standard library's `Command` type +//! let mut child = Command::new("echo") +//! .arg("hello") +//! .arg("world") +//! .spawn() +//! .expect("failed to spawn"); //! -//! // Make sure our child succeeded in spawning and process the result -//! let future = child.expect("failed to spawn"); -//! -//! // Await until the future (and the command) completes -//! let status = future.await?; +//! // Await until the command completes +//! let status = child.wait().await?; //! println!("the command exited with: {}", status); //! Ok(()) //! } @@ -83,8 +82,8 @@ //! //! // Ensure the child process is spawned in the runtime so it can //! // make progress on its own while we await for any output. -//! tokio::spawn(async { -//! let status = child.await +//! tokio::spawn(async move { +//! let status = child.wait().await //! .expect("child process encountered an error"); //! //! println!("child status was: {}", status); @@ -100,27 +99,52 @@ //! //! # Caveats //! +//! ## Dropping/Cancellation +//! //! Similar to the behavior to the standard library, and unlike the futures //! paradigm of dropping-implies-cancellation, a spawned process will, by //! default, continue to execute even after the `Child` handle has been dropped. //! -//! The `Command::kill_on_drop` method can be used to modify this behavior +//! The [`Command::kill_on_drop`] method can be used to modify this behavior //! and kill the child process if the `Child` wrapper is dropped before it //! has exited. //! +//! ## Unix Processes +//! +//! On Unix platforms processes must be "reaped" by their parent process after +//! they have exited in order to release all OS resources. A child process which +//! has exited, but has not yet been reaped by its parent is considered a "zombie" +//! process. Such processes continue to count against limits imposed by the system, +//! and having too many zombie processes present can prevent additional processes +//! from being spawned. +//! +//! The tokio runtime will, on a best-effort basis, attempt to reap and clean up +//! any process which it has spawned. No additional guarantees are made with regards +//! how quickly or how often this procedure will take place. +//! +//! It is recommended to avoid dropping a [`Child`] process handle before it has been +//! fully `await`ed if stricter cleanup guarantees are required. +//! //! [`Command`]: crate::process::Command +//! [`Command::kill_on_drop`]: crate::process::Command::kill_on_drop +//! [`Child`]: crate::process::Child #[path = "unix/mod.rs"] #[cfg(unix)] mod imp; +#[cfg(unix)] +pub(crate) mod unix { + pub(crate) use super::imp::*; +} + #[path = "windows.rs"] #[cfg(windows)] mod imp; mod kill; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::process::kill::Kill; use std::ffi::OsStr; @@ -450,6 +474,26 @@ impl Command { /// By default, this value is assumed to be `false`, meaning the next spawned /// process will not be killed on drop, similar to the behavior of the standard /// library. + /// + /// # Caveats + /// + /// On Unix platforms processes must be "reaped" by their parent process after + /// they have exited in order to release all OS resources. A child process which + /// has exited, but has not yet been reaped by its parent is considered a "zombie" + /// process. Such processes continue to count against limits imposed by the system, + /// and having too many zombie processes present can prevent additional processes + /// from being spawned. + /// + /// Although issuing a `kill` signal to the child process is a synchronous + /// operation, the resulting zombie process cannot be `.await`ed inside of the + /// destructor to avoid blocking other tasks. The tokio runtime will, on a + /// best-effort basis, attempt to reap and clean up such processes in the + /// background, but makes no additional guarantees are made with regards + /// how quickly or how often this procedure will take place. + /// + /// If stronger guarantees are required, it is recommended to avoid dropping + /// a [`Child`] handle where possible, and instead utilize `child.wait().await` + /// or `child.kill().await` where possible. pub fn kill_on_drop(&mut self, kill_on_drop: bool) -> &mut Command { self.kill_on_drop = kill_on_drop; self @@ -534,16 +578,6 @@ impl Command { /// All I/O this child does will be associated with the current default /// event loop. /// - /// # Caveats - /// - /// Similar to the behavior to the standard library, and unlike the futures - /// paradigm of dropping-implies-cancellation, the spawned process will, by - /// default, continue to execute even after the `Child` handle has been dropped. - /// - /// The `Command::kill_on_drop` method can be used to modify this behavior - /// and kill the child process if the `Child` wrapper is dropped before it - /// has exited. - /// /// # Examples /// /// Basic usage: @@ -555,16 +589,55 @@ impl Command { /// Command::new("ls") /// .spawn() /// .expect("ls command failed to start") + /// .wait() /// .await /// .expect("ls command failed to run") /// } /// ``` + /// + /// # Caveats + /// + /// ## Dropping/Cancellation + /// + /// Similar to the behavior to the standard library, and unlike the futures + /// paradigm of dropping-implies-cancellation, a spawned process will, by + /// default, continue to execute even after the `Child` handle has been dropped. + /// + /// The [`Command::kill_on_drop`] method can be used to modify this behavior + /// and kill the child process if the `Child` wrapper is dropped before it + /// has exited. + /// + /// ## Unix Processes + /// + /// On Unix platforms processes must be "reaped" by their parent process after + /// they have exited in order to release all OS resources. A child process which + /// has exited, but has not yet been reaped by its parent is considered a "zombie" + /// process. Such processes continue to count against limits imposed by the system, + /// and having too many zombie processes present can prevent additional processes + /// from being spawned. + /// + /// The tokio runtime will, on a best-effort basis, attempt to reap and clean up + /// any process which it has spawned. No additional guarantees are made with regards + /// how quickly or how often this procedure will take place. + /// + /// It is recommended to avoid dropping a [`Child`] process handle before it has been + /// fully `await`ed if stricter cleanup guarantees are required. + /// + /// [`Command`]: crate::process::Command + /// [`Command::kill_on_drop`]: crate::process::Command::kill_on_drop + /// [`Child`]: crate::process::Child + /// + /// # Errors + /// + /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock` + /// if the system process limit is reached (which includes other applications + /// running on the system). pub fn spawn(&mut self) -> io::Result<Child> { imp::spawn_child(&mut self.std).map(|spawned_child| Child { - child: ChildDropGuard { + child: FusedChild::Child(ChildDropGuard { inner: spawned_child.child, kill_on_drop: self.kill_on_drop, - }, + }), stdin: spawned_child.stdin.map(|inner| ChildStdin { inner }), stdout: spawned_child.stdout.map(|inner| ChildStdout { inner }), stderr: spawned_child.stderr.map(|inner| ChildStderr { inner }), @@ -581,14 +654,20 @@ impl Command { /// All I/O this child does will be associated with the current default /// event loop. /// - /// If this future is dropped before the future resolves, then - /// the child will be killed, if it was spawned. + /// The destructor of the future returned by this function will kill + /// the child if [`kill_on_drop`] is set to true. + /// + /// [`kill_on_drop`]: fn@Self::kill_on_drop /// /// # Errors /// /// This future will return an error if the child process cannot be spawned /// or if there is an error while awaiting its status. /// + /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock` + /// if the system process limit is reached (which includes other applications + /// running on the system). + /// /// # Examples /// /// Basic usage: @@ -602,6 +681,7 @@ impl Command { /// .await /// .expect("ls command failed to run") /// } + /// ``` pub fn status(&mut self) -> impl Future<Output = io::Result<ExitStatus>> { let child = self.spawn(); @@ -615,7 +695,7 @@ impl Command { child.stdout.take(); child.stderr.take(); - child.await + child.wait().await } } @@ -637,9 +717,19 @@ impl Command { /// All I/O this child does will be associated with the current default /// event loop. /// - /// If this future is dropped before the future resolves, then - /// the child will be killed, if it was spawned. + /// The destructor of the future returned by this function will kill + /// the child if [`kill_on_drop`] is set to true. + /// + /// [`kill_on_drop`]: fn@Self::kill_on_drop + /// + /// # Errors + /// + /// This future will return an error if the child process cannot be spawned + /// or if there is an error while awaiting its status. /// + /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock` + /// if the system process limit is reached (which includes other applications + /// running on the system). /// # Examples /// /// Basic usage: @@ -654,6 +744,7 @@ impl Command { /// .expect("ls command failed to run"); /// println!("stderr of ls: {:?}", output.stderr); /// } + /// ``` pub fn output(&mut self) -> impl Future<Output = io::Result<Output>> { self.std.stdout(Stdio::piped()); self.std.stderr(Stdio::piped()); @@ -725,12 +816,16 @@ where } } +/// Keeps track of the exit status of a child process without worrying about +/// polling the underlying futures even after they have completed. +#[derive(Debug)] +enum FusedChild { + Child(ChildDropGuard<imp::Child>), + Done(ExitStatus), +} + /// Representation of a child process spawned onto an event loop. /// -/// This type is also a future which will yield the `ExitStatus` of the -/// underlying child process. A `Child` here also provides access to information -/// like the OS-assigned identifier and the stdio streams. -/// /// # Caveats /// Similar to the behavior to the standard library, and unlike the futures /// paradigm of dropping-implies-cancellation, a spawned process will, by @@ -739,10 +834,9 @@ where /// The `Command::kill_on_drop` method can be used to modify this behavior /// and kill the child process if the `Child` wrapper is dropped before it /// has exited. -#[must_use = "futures do nothing unless polled"] #[derive(Debug)] pub struct Child { - child: ChildDropGuard<imp::Child>, + child: FusedChild, /// The handle for writing to the child's standard input (stdin), if it has /// been captured. @@ -758,34 +852,120 @@ pub struct Child { } impl Child { - /// Returns the OS-assigned process identifier associated with this child. - pub fn id(&self) -> u32 { - self.child.inner.id() + /// Returns the OS-assigned process identifier associated with this child + /// while it is still running. + /// + /// Once the child has been polled to completion this will return `None`. + /// This is done to avoid confusion on platforms like Unix where the OS + /// identifier could be reused once the process has completed. + pub fn id(&self) -> Option<u32> { + match &self.child { + FusedChild::Child(child) => Some(child.inner.id()), + FusedChild::Done(_) => None, + } + } + + /// Attempts to force the child to exit, but does not wait for the request + /// to take effect. + /// + /// On Unix platforms, this is the equivalent to sending a SIGKILL. Note + /// that on Unix platforms it is possible for a zombie process to remain + /// after a kill is sent; to avoid this, the caller should ensure that either + /// `child.wait().await` or `child.try_wait()` is invoked successfully. + pub fn start_kill(&mut self) -> io::Result<()> { + match &mut self.child { + FusedChild::Child(child) => child.kill(), + FusedChild::Done(_) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid argument: can't kill an exited process", + )), + } } /// Forces the child to exit. /// /// This is equivalent to sending a SIGKILL on unix platforms. - pub fn kill(&mut self) -> io::Result<()> { - self.child.kill() - } - - #[doc(hidden)] - #[deprecated(note = "please use `child.stdin` instead")] - pub fn stdin(&mut self) -> &mut Option<ChildStdin> { - &mut self.stdin + /// + /// If the child has to be killed remotely, it is possible to do it using + /// a combination of the select! macro and a oneshot channel. In the following + /// example, the child will run until completion unless a message is sent on + /// the oneshot channel. If that happens, the child is killed immediately + /// using the `.kill()` method. + /// + /// ```no_run + /// use tokio::process::Command; + /// use tokio::sync::oneshot::channel; + /// + /// #[tokio::main] + /// async fn main() { + /// let (send, recv) = channel::<()>(); + /// let mut child = Command::new("sleep").arg("1").spawn().unwrap(); + /// tokio::spawn(async move { send.send(()) }); + /// tokio::select! { + /// _ = child.wait() => {} + /// _ = recv => child.kill().await.expect("kill failed"), + /// } + /// } + /// ``` + pub async fn kill(&mut self) -> io::Result<()> { + self.start_kill()?; + self.wait().await?; + Ok(()) } - #[doc(hidden)] - #[deprecated(note = "please use `child.stdout` instead")] - pub fn stdout(&mut self) -> &mut Option<ChildStdout> { - &mut self.stdout + /// Waits for the child to exit completely, returning the status that it + /// exited with. This function will continue to have the same return value + /// after it has been called at least once. + /// + /// The stdin handle to the child process, if any, will be closed + /// before waiting. This helps avoid deadlock: it ensures that the + /// child does not block waiting for input from the parent, while + /// the parent waits for the child to exit. + pub async fn wait(&mut self) -> io::Result<ExitStatus> { + match &mut self.child { + FusedChild::Done(exit) => Ok(*exit), + FusedChild::Child(child) => { + let ret = child.await; + + if let Ok(exit) = ret { + self.child = FusedChild::Done(exit); + } + + ret + } + } } - #[doc(hidden)] - #[deprecated(note = "please use `child.stderr` instead")] - pub fn stderr(&mut self) -> &mut Option<ChildStderr> { - &mut self.stderr + /// Attempts to collect the exit status of the child if it has already + /// exited. + /// + /// This function will not block the calling thread and will only + /// check to see if the child process has exited or not. If the child has + /// exited then on Unix the process ID is reaped. This function is + /// guaranteed to repeatedly return a successful exit status so long as the + /// child has already exited. + /// + /// If the child has exited, then `Ok(Some(status))` is returned. If the + /// exit status is not available at this time then `Ok(None)` is returned. + /// If an error occurs, then that error is returned. + /// + /// Note that unlike `wait`, this function will not attempt to drop stdin, + /// nor will it wake the current task if the child exits. + pub fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + match &mut self.child { + FusedChild::Done(exit) => Ok(Some(*exit)), + FusedChild::Child(guard) => { + let ret = guard.inner.try_wait(); + + if let Ok(Some(exit)) = ret { + // Avoid the overhead of trying to kill a reaped process + guard.kill_on_drop = false; + self.child = FusedChild::Done(exit); + } + + ret + } + } } /// Returns a future that will resolve to an `Output`, containing the exit @@ -819,7 +999,7 @@ impl Child { let stdout_fut = read_to_end(self.stdout.take()); let stderr_fut = read_to_end(self.stderr.take()); - let (status, stdout, stderr) = try_join3(self, stdout_fut, stderr_fut).await?; + let (status, stdout, stderr) = try_join3(self.wait(), stdout_fut, stderr_fut).await?; Ok(Output { status, @@ -829,14 +1009,6 @@ impl Child { } } -impl Future for Child { - type Output = io::Result<ExitStatus>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - Pin::new(&mut self.child).poll(cx) - } -} - /// The standard input stream for spawned children. /// /// This type implements the `AsyncWrite` trait to pass data to the stdin handle of @@ -883,31 +1055,21 @@ impl AsyncWrite for ChildStdin { } impl AsyncRead for ChildStdout { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L314 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { Pin::new(&mut self.inner).poll_read(cx, buf) } } impl AsyncRead for ChildStderr { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L375 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { Pin::new(&mut self.inner).poll_read(cx, buf) } } diff --git a/src/process/unix/driver.rs b/src/process/unix/driver.rs new file mode 100644 index 0000000..9a16cad --- /dev/null +++ b/src/process/unix/driver.rs @@ -0,0 +1,156 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +//! Process driver + +use crate::park::Park; +use crate::process::unix::orphan::ReapOrphanQueue; +use crate::process::unix::GlobalOrphanQueue; +use crate::signal::unix::driver::Driver as SignalDriver; +use crate::signal::unix::{signal_with_handle, InternalStream, Signal, SignalKind}; +use crate::sync::mpsc::error::TryRecvError; + +use std::io; +use std::time::Duration; + +/// Responsible for cleaning up orphaned child processes on Unix platforms. +#[derive(Debug)] +pub(crate) struct Driver { + park: SignalDriver, + inner: CoreDriver<Signal, GlobalOrphanQueue>, +} + +#[derive(Debug)] +struct CoreDriver<S, Q> { + sigchild: S, + orphan_queue: Q, +} + +// ===== impl CoreDriver ===== + +impl<S, Q> CoreDriver<S, Q> +where + S: InternalStream, + Q: ReapOrphanQueue, +{ + fn got_signal(&mut self) -> bool { + match self.sigchild.try_recv() { + Ok(()) => true, + Err(TryRecvError::Empty) => false, + Err(TryRecvError::Closed) => panic!("signal was deregistered"), + } + } + + fn process(&mut self) { + if self.got_signal() { + // Drain all notifications which may have been buffered + // so we can try to reap all orphans in one batch + while self.got_signal() {} + + self.orphan_queue.reap_orphans(); + } + } +} + +// ===== impl Driver ===== + +impl Driver { + /// Creates a new signal `Driver` instance that delegates wakeups to `park`. + pub(crate) fn new(park: SignalDriver) -> io::Result<Self> { + let sigchild = signal_with_handle(SignalKind::child(), park.handle())?; + let inner = CoreDriver { + sigchild, + orphan_queue: GlobalOrphanQueue, + }; + + Ok(Self { park, inner }) + } +} + +// ===== impl Park for Driver ===== + +impl Park for Driver { + type Unpark = <SignalDriver as Park>::Unpark; + type Error = io::Error; + + fn unpark(&self) -> Self::Unpark { + self.park.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.park.park()?; + self.inner.process(); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.park.park_timeout(duration)?; + self.inner.process(); + Ok(()) + } + + fn shutdown(&mut self) { + self.park.shutdown() + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::process::unix::orphan::test::MockQueue; + use crate::sync::mpsc::error::TryRecvError; + use std::task::{Context, Poll}; + + struct MockStream { + total_try_recv: usize, + values: Vec<Option<()>>, + } + + impl MockStream { + fn new(values: Vec<Option<()>>) -> Self { + Self { + total_try_recv: 0, + values, + } + } + } + + impl InternalStream for MockStream { + fn poll_recv(&mut self, _cx: &mut Context<'_>) -> Poll<Option<()>> { + unimplemented!(); + } + + fn try_recv(&mut self) -> Result<(), TryRecvError> { + self.total_try_recv += 1; + match self.values.remove(0) { + Some(()) => Ok(()), + None => Err(TryRecvError::Empty), + } + } + } + + #[test] + fn no_reap_if_no_signal() { + let mut driver = CoreDriver { + sigchild: MockStream::new(vec![None]), + orphan_queue: MockQueue::<()>::new(), + }; + + driver.process(); + + assert_eq!(1, driver.sigchild.total_try_recv); + assert_eq!(0, driver.orphan_queue.total_reaps.get()); + } + + #[test] + fn coalesce_signals_before_reaping() { + let mut driver = CoreDriver { + sigchild: MockStream::new(vec![Some(()), Some(()), None]), + orphan_queue: MockQueue::<()>::new(), + }; + + driver.process(); + + assert_eq!(3, driver.sigchild.total_try_recv); + assert_eq!(1, driver.orphan_queue.total_reaps.get()); + } +} diff --git a/src/process/unix/mod.rs b/src/process/unix/mod.rs index c25d989..db9d592 100644 --- a/src/process/unix/mod.rs +++ b/src/process/unix/mod.rs @@ -21,8 +21,10 @@ //! processes in general aren't scalable (e.g. millions) so it shouldn't be that //! bad in theory... -mod orphan; -use orphan::{OrphanQueue, OrphanQueueImpl, Wait}; +pub(crate) mod driver; + +pub(crate) mod orphan; +use orphan::{OrphanQueue, OrphanQueueImpl, ReapOrphanQueue, Wait}; mod reap; use reap::Reaper; @@ -32,19 +34,18 @@ use crate::process::kill::Kill; use crate::process::SpawnedChild; use crate::signal::unix::{signal, Signal, SignalKind}; -use mio::event::Evented; -use mio::unix::{EventedFd, UnixReady}; -use mio::{Poll as MioPoll, PollOpt, Ready, Token}; +use mio::event::Source; +use mio::unix::SourceFd; use std::fmt; use std::future::Future; use std::io; use std::os::unix::io::{AsRawFd, RawFd}; use std::pin::Pin; -use std::process::ExitStatus; +use std::process::{Child as StdChild, ExitStatus}; use std::task::Context; use std::task::Poll; -impl Wait for std::process::Child { +impl Wait for StdChild { fn id(&self) -> u32 { self.id() } @@ -54,17 +55,17 @@ impl Wait for std::process::Child { } } -impl Kill for std::process::Child { +impl Kill for StdChild { fn kill(&mut self) -> io::Result<()> { self.kill() } } lazy_static::lazy_static! { - static ref ORPHAN_QUEUE: OrphanQueueImpl<std::process::Child> = OrphanQueueImpl::new(); + static ref ORPHAN_QUEUE: OrphanQueueImpl<StdChild> = OrphanQueueImpl::new(); } -struct GlobalOrphanQueue; +pub(crate) struct GlobalOrphanQueue; impl fmt::Debug for GlobalOrphanQueue { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -72,19 +73,21 @@ impl fmt::Debug for GlobalOrphanQueue { } } -impl OrphanQueue<std::process::Child> for GlobalOrphanQueue { - fn push_orphan(&self, orphan: std::process::Child) { - ORPHAN_QUEUE.push_orphan(orphan) - } - +impl ReapOrphanQueue for GlobalOrphanQueue { fn reap_orphans(&self) { ORPHAN_QUEUE.reap_orphans() } } +impl OrphanQueue<StdChild> for GlobalOrphanQueue { + fn push_orphan(&self, orphan: StdChild) { + ORPHAN_QUEUE.push_orphan(orphan) + } +} + #[must_use = "futures do nothing unless polled"] pub(crate) struct Child { - inner: Reaper<std::process::Child, GlobalOrphanQueue, Signal>, + inner: Reaper<StdChild, GlobalOrphanQueue, Signal>, } impl fmt::Debug for Child { @@ -117,6 +120,10 @@ impl Child { pub(crate) fn id(&self) -> u32 { self.inner.id() } + + pub(crate) fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + self.inner.inner_mut().try_wait() + } } impl Kill for Child { @@ -169,32 +176,30 @@ where } } -impl<T> Evented for Fd<T> +impl<T> Source for Fd<T> where T: AsRawFd, { fn register( - &self, - poll: &MioPoll, - token: Token, - interest: Ready, - opts: PollOpt, + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, ) -> io::Result<()> { - EventedFd(&self.as_raw_fd()).register(poll, token, interest | UnixReady::hup(), opts) + SourceFd(&self.as_raw_fd()).register(registry, token, interest) } fn reregister( - &self, - poll: &MioPoll, - token: Token, - interest: Ready, - opts: PollOpt, + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, ) -> io::Result<()> { - EventedFd(&self.as_raw_fd()).reregister(poll, token, interest | UnixReady::hup(), opts) + SourceFd(&self.as_raw_fd()).reregister(registry, token, interest) } - fn deregister(&self, poll: &MioPoll) -> io::Result<()> { - EventedFd(&self.as_raw_fd()).deregister(poll) + fn deregister(&mut self, registry: &mio::Registry) -> io::Result<()> { + SourceFd(&self.as_raw_fd()).deregister(registry) } } diff --git a/src/process/unix/orphan.rs b/src/process/unix/orphan.rs index 6c449a9..8a1e127 100644 --- a/src/process/unix/orphan.rs +++ b/src/process/unix/orphan.rs @@ -20,23 +20,29 @@ impl<T: Wait> Wait for &mut T { } } -/// An interface for queueing up an orphaned process so that it can be reaped. -pub(crate) trait OrphanQueue<T> { - /// Adds an orphan to the queue. - fn push_orphan(&self, orphan: T); +/// An interface for reaping a set of orphaned processes. +pub(crate) trait ReapOrphanQueue { /// Attempts to reap every process in the queue, ignoring any errors and /// enqueueing any orphans which have not yet exited. fn reap_orphans(&self); } +impl<T: ReapOrphanQueue> ReapOrphanQueue for &T { + fn reap_orphans(&self) { + (**self).reap_orphans() + } +} + +/// An interface for queueing up an orphaned process so that it can be reaped. +pub(crate) trait OrphanQueue<T>: ReapOrphanQueue { + /// Adds an orphan to the queue. + fn push_orphan(&self, orphan: T); +} + impl<T, O: OrphanQueue<T>> OrphanQueue<T> for &O { fn push_orphan(&self, orphan: T) { (**self).push_orphan(orphan); } - - fn reap_orphans(&self) { - (**self).reap_orphans() - } } /// An implementation of `OrphanQueue`. @@ -62,42 +68,62 @@ impl<T: Wait> OrphanQueue<T> for OrphanQueueImpl<T> { fn push_orphan(&self, orphan: T) { self.queue.lock().unwrap().push(orphan) } +} +impl<T: Wait> ReapOrphanQueue for OrphanQueueImpl<T> { fn reap_orphans(&self) { let mut queue = self.queue.lock().unwrap(); let queue = &mut *queue; - let mut i = 0; - while i < queue.len() { + for i in (0..queue.len()).rev() { match queue[i].try_wait() { - Ok(Some(_)) => {} - Err(_) => { - // TODO: bubble up error some how. Is this an internal bug? - // Shoudl we panic? Is it OK for this to be silently - // dropped? - } - // Still not done yet - Ok(None) => { - i += 1; - continue; + Ok(None) => {} + Ok(Some(_)) | Err(_) => { + // The stdlib handles interruption errors (EINTR) when polling a child process. + // All other errors represent invalid inputs or pids that have already been + // reaped, so we can drop the orphan in case an error is raised. + queue.swap_remove(i); } } - - queue.remove(i); } } } #[cfg(all(test, not(loom)))] -mod test { - use super::Wait; - use super::{OrphanQueue, OrphanQueueImpl}; - use std::cell::Cell; +pub(crate) mod test { + use super::*; + use std::cell::{Cell, RefCell}; use std::io; use std::os::unix::process::ExitStatusExt; use std::process::ExitStatus; use std::rc::Rc; + pub(crate) struct MockQueue<W> { + pub(crate) all_enqueued: RefCell<Vec<W>>, + pub(crate) total_reaps: Cell<usize>, + } + + impl<W> MockQueue<W> { + pub(crate) fn new() -> Self { + Self { + all_enqueued: RefCell::new(Vec::new()), + total_reaps: Cell::new(0), + } + } + } + + impl<W> OrphanQueue<W> for MockQueue<W> { + fn push_orphan(&self, orphan: W) { + self.all_enqueued.borrow_mut().push(orphan); + } + } + + impl<W> ReapOrphanQueue for MockQueue<W> { + fn reap_orphans(&self) { + self.total_reaps.set(self.total_reaps.get() + 1); + } + } + struct MockWait { total_waits: Rc<Cell<usize>>, num_wait_until_status: usize, diff --git a/src/process/unix/reap.rs b/src/process/unix/reap.rs index 8963805..de483c4 100644 --- a/src/process/unix/reap.rs +++ b/src/process/unix/reap.rs @@ -1,6 +1,6 @@ use crate::process::imp::orphan::{OrphanQueue, Wait}; use crate::process::kill::Kill; -use crate::signal::unix::Signal; +use crate::signal::unix::InternalStream; use std::future::Future; use std::io; @@ -23,17 +23,6 @@ where signal: S, } -// Work around removal of `futures_core` dependency -pub(crate) trait Stream: Unpin { - fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>; -} - -impl Stream for Signal { - fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { - Signal::poll_recv(self, cx) - } -} - impl<W, Q, S> Deref for Reaper<W, Q, S> where W: Wait + Unpin, @@ -63,7 +52,7 @@ where self.inner.as_ref().expect("inner has gone away") } - fn inner_mut(&mut self) -> &mut W { + pub(crate) fn inner_mut(&mut self) -> &mut W { self.inner.as_mut().expect("inner has gone away") } } @@ -72,7 +61,7 @@ impl<W, Q, S> Future for Reaper<W, Q, S> where W: Wait + Unpin, Q: OrphanQueue<W> + Unpin, - S: Stream, + S: InternalStream, { type Output = io::Result<ExitStatus>; @@ -80,10 +69,8 @@ where loop { // If the child hasn't exited yet, then it's our responsibility to // ensure the current task gets notified when it might be able to - // make progress. - // - // As described in `spawn` above, we just indicate that we can - // next make progress once a SIGCHLD is received. + // make progress. We can use the delivery of a SIGCHLD signal as a + // sign that we can potentially make progress. // // However, we will register for a notification on the next signal // BEFORE we poll the child. Otherwise it is possible that the child @@ -99,7 +86,6 @@ where // should not cause significant issues with parent futures. let registered_interest = self.signal.poll_recv(cx).is_pending(); - self.orphan_queue.reap_orphans(); if let Some(status) = self.inner_mut().try_wait()? { return Poll::Ready(Ok(status)); } @@ -147,8 +133,9 @@ where mod test { use super::*; + use crate::process::unix::orphan::test::MockQueue; + use crate::sync::mpsc::error::TryRecvError; use futures::future::FutureExt; - use std::cell::{Cell, RefCell}; use std::os::unix::process::ExitStatusExt; use std::process::ExitStatus; use std::task::Context; @@ -211,7 +198,7 @@ mod test { } } - impl Stream for MockStream { + impl InternalStream for MockStream { fn poll_recv(&mut self, _cx: &mut Context<'_>) -> Poll<Option<()>> { self.total_polls += 1; match self.values.remove(0) { @@ -219,29 +206,9 @@ mod test { None => Poll::Pending, } } - } - struct MockQueue<W> { - all_enqueued: RefCell<Vec<W>>, - total_reaps: Cell<usize>, - } - - impl<W> MockQueue<W> { - fn new() -> Self { - Self { - all_enqueued: RefCell::new(Vec::new()), - total_reaps: Cell::new(0), - } - } - } - - impl<W: Wait> OrphanQueue<W> for MockQueue<W> { - fn push_orphan(&self, orphan: W) { - self.all_enqueued.borrow_mut().push(orphan); - } - - fn reap_orphans(&self) { - self.total_reaps.set(self.total_reaps.get() + 1); + fn try_recv(&mut self) -> Result<(), TryRecvError> { + unimplemented!(); } } @@ -262,7 +229,7 @@ mod test { assert!(grim.poll_unpin(&mut context).is_pending()); assert_eq!(1, grim.signal.total_polls); assert_eq!(1, grim.total_waits); - assert_eq!(1, grim.orphan_queue.total_reaps.get()); + assert_eq!(0, grim.orphan_queue.total_reaps.get()); assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); // Not yet exited, couldn't register interest the first time @@ -270,7 +237,7 @@ mod test { assert!(grim.poll_unpin(&mut context).is_pending()); assert_eq!(3, grim.signal.total_polls); assert_eq!(3, grim.total_waits); - assert_eq!(3, grim.orphan_queue.total_reaps.get()); + assert_eq!(0, grim.orphan_queue.total_reaps.get()); assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); // Exited @@ -283,7 +250,7 @@ mod test { } assert_eq!(4, grim.signal.total_polls); assert_eq!(4, grim.total_waits); - assert_eq!(4, grim.orphan_queue.total_reaps.get()); + assert_eq!(0, grim.orphan_queue.total_reaps.get()); assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); } diff --git a/src/process/windows.rs b/src/process/windows.rs index cbe2fa7..1aa6c89 100644 --- a/src/process/windows.rs +++ b/src/process/windows.rs @@ -20,24 +20,19 @@ use crate::process::kill::Kill; use crate::process::SpawnedChild; use crate::sync::oneshot; -use mio_named_pipes::NamedPipe; +use mio::windows::NamedPipe; use std::fmt; use std::future::Future; use std::io; -use std::os::windows::prelude::*; -use std::os::windows::process::ExitStatusExt; +use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle}; use std::pin::Pin; use std::process::{Child as StdChild, Command as StdCommand, ExitStatus}; use std::ptr; use std::task::Context; use std::task::Poll; -use winapi::shared::minwindef::FALSE; -use winapi::shared::winerror::WAIT_TIMEOUT; use winapi::um::handleapi::INVALID_HANDLE_VALUE; -use winapi::um::processthreadsapi::GetExitCodeProcess; -use winapi::um::synchapi::WaitForSingleObject; use winapi::um::threadpoollegacyapiset::UnregisterWaitEx; -use winapi::um::winbase::{RegisterWaitForSingleObject, INFINITE, WAIT_OBJECT_0}; +use winapi::um::winbase::{RegisterWaitForSingleObject, INFINITE}; use winapi::um::winnt::{BOOLEAN, HANDLE, PVOID, WT_EXECUTEINWAITTHREAD, WT_EXECUTEONLYONCE}; #[must_use = "futures do nothing unless polled"] @@ -86,6 +81,10 @@ impl Child { pub(crate) fn id(&self) -> u32 { self.child.id() } + + pub(crate) fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> { + self.child.try_wait() + } } impl Kill for Child { @@ -106,11 +105,11 @@ impl Future for Child { Poll::Ready(Err(_)) => panic!("should not be canceled"), Poll::Pending => return Poll::Pending, } - let status = try_wait(&inner.child)?.expect("not ready yet"); + let status = inner.try_wait()?.expect("not ready yet"); return Poll::Ready(Ok(status)); } - if let Some(e) = try_wait(&inner.child)? { + if let Some(e) = inner.try_wait()? { return Poll::Ready(Ok(e)); } let (tx, rx) = oneshot::channel(); @@ -157,23 +156,6 @@ unsafe extern "system" fn callback(ptr: PVOID, _timer_fired: BOOLEAN) { let _ = complete.take().unwrap().send(()); } -pub(crate) fn try_wait(child: &StdChild) -> io::Result<Option<ExitStatus>> { - unsafe { - match WaitForSingleObject(child.as_raw_handle(), 0) { - WAIT_OBJECT_0 => {} - WAIT_TIMEOUT => return Ok(None), - _ => return Err(io::Error::last_os_error()), - } - let mut status = 0; - let rc = GetExitCodeProcess(child.as_raw_handle(), &mut status); - if rc == FALSE { - Err(io::Error::last_os_error()) - } else { - Ok(Some(ExitStatus::from_raw(status))) - } - } -} - pub(crate) type ChildStdin = PollEvented<NamedPipe>; pub(crate) type ChildStdout = PollEvented<NamedPipe>; pub(crate) type ChildStderr = PollEvented<NamedPipe>; diff --git a/src/runtime/basic_scheduler.rs b/src/runtime/basic_scheduler.rs index 7e1c257..5ca8467 100644 --- a/src/runtime/basic_scheduler.rs +++ b/src/runtime/basic_scheduler.rs @@ -1,22 +1,35 @@ +use crate::future::poll_fn; +use crate::loom::sync::Mutex; use crate::park::{Park, Unpark}; -use crate::runtime; use crate::runtime::task::{self, JoinHandle, Schedule, Task}; -use crate::util::linked_list::LinkedList; -use crate::util::{waker_ref, Wake}; +use crate::sync::notify::Notify; +use crate::util::linked_list::{Link, LinkedList}; +use crate::util::{waker_ref, Wake, WakerRef}; use std::cell::RefCell; use std::collections::VecDeque; use std::fmt; use std::future::Future; -use std::sync::{Arc, Mutex}; -use std::task::Poll::Ready; +use std::sync::Arc; +use std::task::Poll::{Pending, Ready}; use std::time::Duration; /// Executes tasks on the current thread -pub(crate) struct BasicScheduler<P> -where - P: Park, -{ +pub(crate) struct BasicScheduler<P: Park> { + /// Inner state guarded by a mutex that is shared + /// between all `block_on` calls. + inner: Mutex<Option<Inner<P>>>, + + /// Notifier for waking up other threads to steal the + /// parker. + notify: Notify, + + /// Sendable task spawner + spawner: Spawner, +} + +/// The inner scheduler that owns the task queue and the main parker P. +struct Inner<P: Park> { /// Scheduler run queue /// /// When the scheduler is executed, the queue is removed from `self` and @@ -42,7 +55,7 @@ pub(crate) struct Spawner { struct Tasks { /// Collection of all active tasks spawned onto this executor. - owned: LinkedList<Task<Arc<Shared>>>, + owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>, /// Local run queue. /// @@ -59,7 +72,7 @@ struct Shared { unpark: Box<dyn Unpark>, } -/// Thread-local context +/// Thread-local context. struct Context { /// Shared scheduler state shared: Arc<Shared>, @@ -68,38 +81,43 @@ struct Context { tasks: RefCell<Tasks>, } -/// Initial queue capacity +/// Initial queue capacity. const INITIAL_CAPACITY: usize = 64; /// Max number of tasks to poll per tick. const MAX_TASKS_PER_TICK: usize = 61; -/// How often ot check the remote queue first +/// How often to check the remote queue first. const REMOTE_FIRST_INTERVAL: u8 = 31; -// Tracks the current BasicScheduler +// Tracks the current BasicScheduler. scoped_thread_local!(static CURRENT: Context); -impl<P> BasicScheduler<P> -where - P: Park, -{ +impl<P: Park> BasicScheduler<P> { pub(crate) fn new(park: P) -> BasicScheduler<P> { let unpark = Box::new(park.unpark()); - BasicScheduler { + let spawner = Spawner { + shared: Arc::new(Shared { + queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + unpark: unpark as Box<dyn Unpark>, + }), + }; + + let inner = Mutex::new(Some(Inner { tasks: Some(Tasks { owned: LinkedList::new(), queue: VecDeque::with_capacity(INITIAL_CAPACITY), }), - spawner: Spawner { - shared: Arc::new(Shared { - queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), - unpark: unpark as Box<dyn Unpark>, - }), - }, + spawner: spawner.clone(), tick: 0, park, + })); + + BasicScheduler { + inner, + notify: Notify::new(), + spawner, } } @@ -116,13 +134,57 @@ where self.spawner.spawn(future) } - pub(crate) fn block_on<F>(&mut self, future: F) -> F::Output - where - F: Future, - { + pub(crate) fn block_on<F: Future>(&self, future: F) -> F::Output { + pin!(future); + + // Attempt to steal the dedicated parker and block_on the future if we can there, + // othwerwise, lets select on a notification that the parker is available + // or the future is complete. + loop { + if let Some(inner) = &mut self.take_inner() { + return inner.block_on(future); + } else { + let mut enter = crate::runtime::enter(false); + + let notified = self.notify.notified(); + pin!(notified); + + if let Some(out) = enter + .block_on(poll_fn(|cx| { + if notified.as_mut().poll(cx).is_ready() { + return Ready(None); + } + + if let Ready(out) = future.as_mut().poll(cx) { + return Ready(Some(out)); + } + + Pending + })) + .expect("Failed to `Enter::block_on`") + { + return out; + } + } + } + } + + fn take_inner(&self) -> Option<InnerGuard<'_, P>> { + let inner = self.inner.lock().take()?; + + Some(InnerGuard { + inner: Some(inner), + basic_scheduler: &self, + }) + } +} + +impl<P: Park> Inner<P> { + /// Block on the future provided and drive the runtime's driver. + fn block_on<F: Future>(&mut self, future: F) -> F::Output { enter(self, |scheduler, context| { - let _enter = runtime::enter(false); - let waker = waker_ref(&scheduler.spawner.shared); + let _enter = crate::runtime::enter(false); + let waker = scheduler.spawner.waker_ref(); let mut cx = std::task::Context::from_waker(&waker); pin!(future); @@ -177,16 +239,16 @@ where /// Enter the scheduler context. This sets the queue and other necessary /// scheduler state in the thread-local -fn enter<F, R, P>(scheduler: &mut BasicScheduler<P>, f: F) -> R +fn enter<F, R, P>(scheduler: &mut Inner<P>, f: F) -> R where - F: FnOnce(&mut BasicScheduler<P>, &Context) -> R, + F: FnOnce(&mut Inner<P>, &Context) -> R, P: Park, { // Ensures the run queue is placed back in the `BasicScheduler` instance // once `block_on` returns.` struct Guard<'a, P: Park> { context: Option<Context>, - scheduler: &'a mut BasicScheduler<P>, + scheduler: &'a mut Inner<P>, } impl<P: Park> Drop for Guard<'_, P> { @@ -213,12 +275,18 @@ where CURRENT.set(context, || f(scheduler, context)) } -impl<P> Drop for BasicScheduler<P> -where - P: Park, -{ +impl<P: Park> Drop for BasicScheduler<P> { fn drop(&mut self) { - enter(self, |scheduler, context| { + // Avoid a double panic if we are currently panicking and + // the lock may be poisoned. + + let mut inner = match self.inner.lock().take() { + Some(inner) => inner, + None if std::thread::panicking() => return, + None => panic!("Oh no! We never placed the Inner state back, this is a bug!"), + }; + + enter(&mut inner, |scheduler, context| { // Loop required here to ensure borrow is dropped between iterations #[allow(clippy::while_let_loop)] loop { @@ -236,7 +304,7 @@ where } // Drain remote queue - for task in scheduler.spawner.shared.queue.lock().unwrap().drain(..) { + for task in scheduler.spawner.shared.queue.lock().drain(..) { task.shutdown(); } @@ -266,7 +334,11 @@ impl Spawner { } fn pop(&self) -> Option<task::Notified<Arc<Shared>>> { - self.shared.queue.lock().unwrap().pop_front() + self.shared.queue.lock().pop_front() + } + + fn waker_ref(&self) -> WakerRef<'_> { + waker_ref(&self.shared) } } @@ -307,7 +379,7 @@ impl Schedule for Arc<Shared> { cx.tasks.borrow_mut().queue.push_back(task); } _ => { - self.queue.lock().unwrap().push_back(task); + self.queue.lock().push_back(task); self.unpark.unpark(); } }); @@ -324,3 +396,37 @@ impl Wake for Shared { arc_self.unpark.unpark(); } } + +// ===== InnerGuard ===== + +/// Used to ensure we always place the Inner value +/// back into its slot in `BasicScheduler`, even if the +/// future panics. +struct InnerGuard<'a, P: Park> { + inner: Option<Inner<P>>, + basic_scheduler: &'a BasicScheduler<P>, +} + +impl<P: Park> InnerGuard<'_, P> { + fn block_on<F: Future>(&mut self, future: F) -> F::Output { + // The only time inner gets set to `None` is if we have dropped + // already so this unwrap is safe. + self.inner.as_mut().unwrap().block_on(future) + } +} + +impl<P: Park> Drop for InnerGuard<'_, P> { + fn drop(&mut self) { + if let Some(scheduler) = self.inner.take() { + let mut lock = self.basic_scheduler.inner.lock(); + + // Replace old scheduler back into the state to allow + // other threads to pick it up and drive it. + lock.replace(scheduler); + + // Wake up other possible threads that could steal + // the dedicated parker P. + self.basic_scheduler.notify.notify_one() + } + } +} diff --git a/src/runtime/blocking/mod.rs b/src/runtime/blocking/mod.rs index 0b36a75..fece3c2 100644 --- a/src/runtime/blocking/mod.rs +++ b/src/runtime/blocking/mod.rs @@ -3,22 +3,20 @@ //! shells. This isolates the complexity of dealing with conditional //! compilation. -cfg_blocking_impl! { - mod pool; - pub(crate) use pool::{spawn_blocking, try_spawn_blocking, BlockingPool, Spawner}; +mod pool; +pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner}; - mod schedule; - mod shutdown; - pub(crate) mod task; +mod schedule; +mod shutdown; +pub(crate) mod task; - use crate::runtime::Builder; - - pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool { - BlockingPool::new(builder, thread_cap) +use crate::runtime::Builder; - } +pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool { + BlockingPool::new(builder, thread_cap) } +/* cfg_not_blocking_impl! { use crate::runtime::Builder; use std::time::Duration; @@ -41,3 +39,4 @@ cfg_not_blocking_impl! { } } } +*/ diff --git a/src/runtime/blocking/pool.rs b/src/runtime/blocking/pool.rs index 40d417b..2967a10 100644 --- a/src/runtime/blocking/pool.rs +++ b/src/runtime/blocking/pool.rs @@ -5,9 +5,13 @@ use crate::loom::thread; use crate::runtime::blocking::schedule::NoopSchedule; use crate::runtime::blocking::shutdown; use crate::runtime::blocking::task::BlockingTask; +use crate::runtime::builder::ThreadNameFn; +use crate::runtime::context; use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{Builder, Callback, Handle}; +use slab::Slab; + use std::collections::VecDeque; use std::fmt; use std::time::Duration; @@ -30,7 +34,7 @@ struct Inner { condvar: Condvar, /// Spawned threads use this name - thread_name: String, + thread_name: ThreadNameFn, /// Spawned thread stack size stack_size: Option<usize>, @@ -41,7 +45,11 @@ struct Inner { /// Call before a thread stops before_stop: Option<Callback>, + // Maximum number of threads thread_cap: usize, + + // Customizable wait timeout + keep_alive: Duration, } struct Shared { @@ -51,6 +59,7 @@ struct Shared { num_notify: u32, shutdown: bool, shutdown_tx: Option<shutdown::Sender>, + worker_threads: Slab<thread::JoinHandle<()>>, } type Task = task::Notified<NoopSchedule>; @@ -62,11 +71,8 @@ pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, { - let rt = Handle::current(); - - let (task, handle) = task::joinable(BlockingTask::new(func)); - let _ = rt.blocking_spawner.spawn(task, &rt); - handle + let rt = context::current().expect("not currently running on the Tokio runtime."); + rt.spawn_blocking(func) } #[allow(dead_code)] @@ -74,7 +80,7 @@ pub(crate) fn try_spawn_blocking<F, R>(func: F) -> Result<(), ()> where F: FnOnce() -> R + Send + 'static, { - let rt = Handle::current(); + let rt = context::current().expect("not currently running on the Tokio runtime."); let (task, _handle) = task::joinable(BlockingTask::new(func)); rt.blocking_spawner.spawn(task, &rt) @@ -85,6 +91,7 @@ where impl BlockingPool { pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool { let (shutdown_tx, shutdown_rx) = shutdown::channel(); + let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE); BlockingPool { spawner: Spawner { @@ -96,6 +103,7 @@ impl BlockingPool { num_notify: 0, shutdown: false, shutdown_tx: Some(shutdown_tx), + worker_threads: Slab::new(), }), condvar: Condvar::new(), thread_name: builder.thread_name.clone(), @@ -103,6 +111,7 @@ impl BlockingPool { after_start: builder.after_start.clone(), before_stop: builder.before_stop.clone(), thread_cap, + keep_alive, }), }, shutdown_rx, @@ -114,7 +123,7 @@ impl BlockingPool { } pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) { - let mut shared = self.spawner.inner.shared.lock().unwrap(); + let mut shared = self.spawner.inner.shared.lock(); // The function can be called multiple times. First, by explicitly // calling `shutdown` then by the drop handler calling `shutdown`. This @@ -126,10 +135,15 @@ impl BlockingPool { shared.shutdown = true; shared.shutdown_tx = None; self.spawner.inner.condvar.notify_all(); + let mut workers = std::mem::replace(&mut shared.worker_threads, Slab::new()); drop(shared); - self.shutdown_rx.wait(timeout); + if self.shutdown_rx.wait(timeout) { + for handle in workers.drain() { + let _ = handle.join(); + } + } } } @@ -150,7 +164,7 @@ impl fmt::Debug for BlockingPool { impl Spawner { pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { let shutdown_tx = { - let mut shared = self.inner.shared.lock().unwrap(); + let mut shared = self.inner.shared.lock(); if shared.shutdown { // Shutdown the task @@ -187,14 +201,24 @@ impl Spawner { }; if let Some(shutdown_tx) = shutdown_tx { - self.spawn_thread(shutdown_tx, rt); + let mut shared = self.inner.shared.lock(); + let entry = shared.worker_threads.vacant_entry(); + + let handle = self.spawn_thread(shutdown_tx, rt, entry.key()); + + entry.insert(handle); } Ok(()) } - fn spawn_thread(&self, shutdown_tx: shutdown::Sender, rt: &Handle) { - let mut builder = thread::Builder::new().name(self.inner.thread_name.clone()); + fn spawn_thread( + &self, + shutdown_tx: shutdown::Sender, + rt: &Handle, + worker_id: usize, + ) -> thread::JoinHandle<()> { + let mut builder = thread::Builder::new().name((self.inner.thread_name)()); if let Some(stack_size) = self.inner.stack_size { builder = builder.stack_size(stack_size); @@ -205,23 +229,21 @@ impl Spawner { builder .spawn(move || { // Only the reference should be moved into the closure - let rt = &rt; - rt.enter(move || { - rt.blocking_spawner.inner.run(); - drop(shutdown_tx); - }) + let _enter = crate::runtime::context::enter(rt.clone()); + rt.blocking_spawner.inner.run(worker_id); + drop(shutdown_tx); }) - .unwrap(); + .unwrap() } } impl Inner { - fn run(&self) { + fn run(&self, worker_id: usize) { if let Some(f) = &self.after_start { f() } - let mut shared = self.shared.lock().unwrap(); + let mut shared = self.shared.lock(); 'main: loop { // BUSY @@ -229,14 +251,14 @@ impl Inner { drop(shared); task.run(); - shared = self.shared.lock().unwrap(); + shared = self.shared.lock(); } // IDLE shared.num_idle += 1; while !shared.shutdown { - let lock_result = self.condvar.wait_timeout(shared, KEEP_ALIVE).unwrap(); + let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap(); shared = lock_result.0; let timeout_result = lock_result.1; @@ -252,6 +274,8 @@ impl Inner { // Even if the condvar "timed out", if the pool is entering the // shutdown phase, we want to perform the cleanup logic. if !shared.shutdown && timeout_result.timed_out() { + shared.worker_threads.remove(worker_id); + break 'main; } @@ -264,7 +288,7 @@ impl Inner { drop(shared); task.shutdown(); - shared = self.shared.lock().unwrap(); + shared = self.shared.lock(); } // Work was produced, and we "took" it (by decrementing num_notify). diff --git a/src/runtime/blocking/shutdown.rs b/src/runtime/blocking/shutdown.rs index e76a701..3b6cc59 100644 --- a/src/runtime/blocking/shutdown.rs +++ b/src/runtime/blocking/shutdown.rs @@ -32,11 +32,13 @@ impl Receiver { /// If `timeout` is `Some`, the thread is blocked for **at most** `timeout` /// duration. If `timeout` is `None`, then the thread is blocked until the /// shutdown signal is received. - pub(crate) fn wait(&mut self, timeout: Option<Duration>) { + /// + /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`. + pub(crate) fn wait(&mut self, timeout: Option<Duration>) -> bool { use crate::runtime::enter::try_enter; if timeout == Some(Duration::from_nanos(0)) { - return; + return true; } let mut e = match try_enter(false) { @@ -44,7 +46,7 @@ impl Receiver { _ => { if std::thread::panicking() { // Don't panic in a panic - return; + return false; } else { panic!( "Cannot drop a runtime in a context where blocking is not allowed. \ @@ -60,9 +62,10 @@ impl Receiver { // current thread (usually, shutting down a runtime stored in a // thread-local). if let Some(timeout) = timeout { - let _ = e.block_on_timeout(&mut self.rx, timeout); + e.block_on_timeout(&mut self.rx, timeout).is_ok() } else { let _ = e.block_on(&mut self.rx); + true } } } diff --git a/src/runtime/builder.rs b/src/runtime/builder.rs index fad72c7..e792c7d 100644 --- a/src/runtime/builder.rs +++ b/src/runtime/builder.rs @@ -1,23 +1,24 @@ use crate::runtime::handle::Handle; -use crate::runtime::shell::Shell; -use crate::runtime::{blocking, io, time, Callback, Runtime, Spawner}; +use crate::runtime::{blocking, driver, Callback, Runtime, Spawner}; use std::fmt; -#[cfg(not(loom))] -use std::sync::Arc; +use std::io; +use std::time::Duration; /// Builds Tokio Runtime with custom configuration values. /// /// Methods can be chained in order to set the configuration values. The /// Runtime is constructed by calling [`build`]. /// -/// New instances of `Builder` are obtained via [`Builder::new`]. +/// New instances of `Builder` are obtained via [`Builder::new_multi_thread`] +/// or [`Builder::new_current_thread`]. /// /// See function level documentation for details on the various configuration /// settings. /// /// [`build`]: method@Self::build -/// [`Builder::new`]: method@Self::new +/// [`Builder::new_multi_thread`]: method@Self::new_multi_thread +/// [`Builder::new_current_thread`]: method@Self::new_current_thread /// /// # Examples /// @@ -26,9 +27,8 @@ use std::sync::Arc; /// /// fn main() { /// // build runtime -/// let runtime = Builder::new() -/// .threaded_scheduler() -/// .core_threads(4) +/// let runtime = Builder::new_multi_thread() +/// .worker_threads(4) /// .thread_name("my-custom-name") /// .thread_stack_size(3 * 1024 * 1024) /// .build() @@ -38,7 +38,7 @@ use std::sync::Arc; /// } /// ``` pub struct Builder { - /// The task execution model to use. + /// Runtime type kind: Kind, /// Whether or not to enable the I/O driver @@ -50,13 +50,13 @@ pub struct Builder { /// The number of worker threads, used by Runtime. /// /// Only used when not using the current-thread executor. - core_threads: Option<usize>, + worker_threads: Option<usize>, /// Cap on thread usage. max_threads: usize, - /// Name used for threads spawned by the runtime. - pub(super) thread_name: String, + /// Name fn used for threads spawned by the runtime. + pub(super) thread_name: ThreadNameFn, /// Stack size used for threads spawned by the runtime. pub(super) thread_stack_size: Option<usize>, @@ -66,26 +66,43 @@ pub struct Builder { /// To run before each worker thread stops pub(super) before_stop: Option<Callback>, + + /// Customizable keep alive timeout for BlockingPool + pub(super) keep_alive: Option<Duration>, } -#[derive(Debug, Clone, Copy)] -enum Kind { - Shell, - #[cfg(feature = "rt-core")] - Basic, - #[cfg(feature = "rt-threaded")] - ThreadPool, +pub(crate) type ThreadNameFn = std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>; + +pub(crate) enum Kind { + CurrentThread, + #[cfg(feature = "rt-multi-thread")] + MultiThread, } impl Builder { + /// Returns a new builder with the current thread scheduler selected. + /// + /// Configuration methods can be chained on the return value. + pub fn new_current_thread() -> Builder { + Builder::new(Kind::CurrentThread) + } + + /// Returns a new builder with the multi thread scheduler selected. + /// + /// Configuration methods can be chained on the return value. + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + pub fn new_multi_thread() -> Builder { + Builder::new(Kind::MultiThread) + } + /// Returns a new runtime builder initialized with default configuration /// values. /// /// Configuration methods can be chained on the return value. - pub fn new() -> Builder { + pub(crate) fn new(kind: Kind) -> Builder { Builder { - // No task execution by default - kind: Kind::Shell, + kind, // I/O defaults to "off" enable_io: false, @@ -94,12 +111,12 @@ impl Builder { enable_time: false, // Default to lazy auto-detection (one thread per CPU core) - core_threads: None, + worker_threads: None, max_threads: 512, // Default thread name - thread_name: "tokio-runtime-worker".into(), + thread_name: std::sync::Arc::new(|| "tokio-runtime-worker".into()), // Do not set a stack size by default thread_stack_size: None, @@ -107,6 +124,8 @@ impl Builder { // No worker thread callbacks after_start: None, before_stop: None, + + keep_alive: None, } } @@ -121,14 +140,13 @@ impl Builder { /// ``` /// use tokio::runtime; /// - /// let rt = runtime::Builder::new() - /// .threaded_scheduler() + /// let rt = runtime::Builder::new_multi_thread() /// .enable_all() /// .build() /// .unwrap(); /// ``` pub fn enable_all(&mut self) -> &mut Self { - #[cfg(feature = "io-driver")] + #[cfg(any(feature = "net", feature = "process", all(unix, feature = "signal")))] self.enable_io(); #[cfg(feature = "time")] self.enable_time(); @@ -136,51 +154,68 @@ impl Builder { self } - #[deprecated(note = "In future will be replaced by core_threads method")] - /// Sets the maximum number of worker threads for the `Runtime`'s thread pool. + /// Sets the number of worker threads the `Runtime` will use. + /// + /// This should be a number between 0 and 32,768 though it is advised to + /// keep this value on the smaller side. /// - /// This must be a number between 1 and 32,768 though it is advised to keep - /// this value on the smaller side. + /// # Default /// /// The default value is the number of cores available to the system. - pub fn num_threads(&mut self, val: usize) -> &mut Self { - self.core_threads = Some(val); - self - } - - /// Sets the core number of worker threads for the `Runtime`'s thread pool. /// - /// This should be a number between 1 and 32,768 though it is advised to keep - /// this value on the smaller side. + /// # Panic /// - /// The default value is the number of cores available to the system. + /// When using the `current_thread` runtime this method will panic, since + /// those variants do not allow setting worker thread counts. /// - /// These threads will be always active and running. /// /// # Examples /// + /// ## Multi threaded runtime with 4 threads + /// + /// ``` + /// use tokio::runtime; + /// + /// // This will spawn a work-stealing runtime with 4 worker threads. + /// let rt = runtime::Builder::new_multi_thread() + /// .worker_threads(4) + /// .build() + /// .unwrap(); + /// + /// rt.spawn(async move {}); + /// ``` + /// + /// ## Current thread runtime (will only run on the current thread via `Runtime::block_on`) + /// /// ``` /// use tokio::runtime; /// - /// let rt = runtime::Builder::new() - /// .threaded_scheduler() - /// .core_threads(4) + /// // Create a runtime that _must_ be driven from a call + /// // to `Runtime::block_on`. + /// let rt = runtime::Builder::new_current_thread() /// .build() /// .unwrap(); + /// + /// // This will run the runtime and future on the current thread + /// rt.block_on(async move {}); /// ``` - pub fn core_threads(&mut self, val: usize) -> &mut Self { - assert_ne!(val, 0, "Core threads cannot be zero"); - self.core_threads = Some(val); + /// + /// # Panic + /// + /// This will panic if `val` is not larger than `0`. + pub fn worker_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Worker threads cannot be set to 0"); + self.worker_threads = Some(val); self } /// Specifies limit for threads, spawned by the Runtime. /// /// This is number of threads to be used by Runtime, including `core_threads` - /// Having `max_threads` less than `core_threads` results in invalid configuration + /// Having `max_threads` less than `worker_threads` results in invalid configuration /// when building multi-threaded `Runtime`, which would cause a panic. /// - /// Similarly to the `core_threads`, this number should be between 1 and 32,768. + /// Similarly to the `worker_threads`, this number should be between 0 and 32,768. /// /// The default value is 512. /// @@ -189,7 +224,6 @@ impl Builder { /// Otherwise as `core_threads` are always active, it limits additional threads (e.g. for /// blocking annotations) as `max_threads - core_threads`. pub fn max_threads(&mut self, val: usize) -> &mut Self { - assert_ne!(val, 0, "Thread limit cannot be zero"); self.max_threads = val; self } @@ -204,13 +238,42 @@ impl Builder { /// # use tokio::runtime; /// /// # pub fn main() { - /// let rt = runtime::Builder::new() + /// let rt = runtime::Builder::new_multi_thread() /// .thread_name("my-pool") /// .build(); /// # } /// ``` pub fn thread_name(&mut self, val: impl Into<String>) -> &mut Self { - self.thread_name = val.into(); + let val = val.into(); + self.thread_name = std::sync::Arc::new(move || val.clone()); + self + } + + /// Sets a function used to generate the name of threads spawned by the `Runtime`'s thread pool. + /// + /// The default name fn is `|| "tokio-runtime-worker".into()`. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// # use std::sync::atomic::{AtomicUsize, Ordering}; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_name_fn(|| { + /// static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + /// let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + /// format!("my-pool-{}", id) + /// }) + /// .build(); + /// # } + /// ``` + pub fn thread_name_fn<F>(&mut self, f: F) -> &mut Self + where + F: Fn() -> String + Send + Sync + 'static, + { + self.thread_name = std::sync::Arc::new(f); self } @@ -228,8 +291,7 @@ impl Builder { /// # use tokio::runtime; /// /// # pub fn main() { - /// let rt = runtime::Builder::new() - /// .threaded_scheduler() + /// let rt = runtime::Builder::new_multi_thread() /// .thread_stack_size(32 * 1024) /// .build(); /// # } @@ -250,8 +312,7 @@ impl Builder { /// # use tokio::runtime; /// /// # pub fn main() { - /// let runtime = runtime::Builder::new() - /// .threaded_scheduler() + /// let runtime = runtime::Builder::new_multi_thread() /// .on_thread_start(|| { /// println!("thread started"); /// }) @@ -263,7 +324,7 @@ impl Builder { where F: Fn() + Send + Sync + 'static, { - self.after_start = Some(Arc::new(f)); + self.after_start = Some(std::sync::Arc::new(f)); self } @@ -277,8 +338,7 @@ impl Builder { /// # use tokio::runtime; /// /// # pub fn main() { - /// let runtime = runtime::Builder::new() - /// .threaded_scheduler() + /// let runtime = runtime::Builder::new_multi_thread() /// .on_thread_stop(|| { /// println!("thread stopping"); /// }) @@ -290,56 +350,86 @@ impl Builder { where F: Fn() + Send + Sync + 'static, { - self.before_stop = Some(Arc::new(f)); + self.before_stop = Some(std::sync::Arc::new(f)); self } /// Creates the configured `Runtime`. /// - /// The returned `ThreadPool` instance is ready to spawn tasks. + /// The returned `Runtime` instance is ready to spawn tasks. /// /// # Examples /// /// ``` /// use tokio::runtime::Builder; /// - /// let mut rt = Builder::new().build().unwrap(); + /// let rt = Builder::new_multi_thread().build().unwrap(); /// /// rt.block_on(async { /// println!("Hello from the Tokio runtime"); /// }); /// ``` pub fn build(&mut self) -> io::Result<Runtime> { - match self.kind { - Kind::Shell => self.build_shell_runtime(), - #[cfg(feature = "rt-core")] - Kind::Basic => self.build_basic_runtime(), - #[cfg(feature = "rt-threaded")] - Kind::ThreadPool => self.build_threaded_runtime(), + match &self.kind { + Kind::CurrentThread => self.build_basic_runtime(), + #[cfg(feature = "rt-multi-thread")] + Kind::MultiThread => self.build_threaded_runtime(), } } - fn build_shell_runtime(&mut self) -> io::Result<Runtime> { - use crate::runtime::Kind; + fn get_cfg(&self) -> driver::Cfg { + driver::Cfg { + enable_io: self.enable_io, + enable_time: self.enable_time, + } + } - let clock = time::create_clock(); + /// Sets a custom timeout for a thread in the blocking pool. + /// + /// By default, the timeout for a thread is set to 10 seconds. This can + /// be overriden using .thread_keep_alive(). + /// + /// # Example + /// + /// ``` + /// # use tokio::runtime; + /// # use std::time::Duration; + /// + /// # pub fn main() { + /// let rt = runtime::Builder::new_multi_thread() + /// .thread_keep_alive(Duration::from_millis(100)) + /// .build(); + /// # } + /// ``` + pub fn thread_keep_alive(&mut self, duration: Duration) -> &mut Self { + self.keep_alive = Some(duration); + self + } + + fn build_basic_runtime(&mut self) -> io::Result<Runtime> { + use crate::runtime::{BasicScheduler, Kind}; - // Create I/O driver - let (io_driver, io_handle) = io::create_driver(self.enable_io)?; - let (driver, time_handle) = time::create_driver(self.enable_time, io_driver, clock.clone()); + let (driver, resources) = driver::Driver::new(self.get_cfg())?; - let spawner = Spawner::Shell; + // And now put a single-threaded scheduler on top of the timer. When + // there are no futures ready to do something, it'll let the timer or + // the reactor to generate some new stimuli for the futures to continue + // in their life. + let scheduler = BasicScheduler::new(driver); + let spawner = Spawner::Basic(scheduler.spawner().clone()); + // Blocking pool let blocking_pool = blocking::create_blocking_pool(self, self.max_threads); let blocking_spawner = blocking_pool.spawner().clone(); Ok(Runtime { - kind: Kind::Shell(Shell::new(driver)), + kind: Kind::CurrentThread(scheduler), handle: Handle { spawner, - io_handle, - time_handle, - clock, + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, blocking_spawner, }, blocking_pool, @@ -359,7 +449,7 @@ cfg_io_driver! { /// ``` /// use tokio::runtime; /// - /// let rt = runtime::Builder::new() + /// let rt = runtime::Builder::new_multi_thread() /// .enable_io() /// .build() /// .unwrap(); @@ -382,7 +472,7 @@ cfg_time! { /// ``` /// use tokio::runtime; /// - /// let rt = runtime::Builder::new() + /// let rt = runtime::Builder::new_multi_thread() /// .enable_time() /// .build() /// .unwrap(); @@ -394,85 +484,19 @@ cfg_time! { } } -cfg_rt_core! { +cfg_rt_multi_thread! { impl Builder { - /// Sets runtime to use a simpler scheduler that runs all tasks on the current-thread. - /// - /// The executor and all necessary drivers will all be run on the current - /// thread during [`block_on`] calls. - /// - /// See also [the module level documentation][1], which has a section on scheduler - /// types. - /// - /// [1]: index.html#runtime-configurations - /// [`block_on`]: Runtime::block_on - pub fn basic_scheduler(&mut self) -> &mut Self { - self.kind = Kind::Basic; - self - } - - fn build_basic_runtime(&mut self) -> io::Result<Runtime> { - use crate::runtime::{BasicScheduler, Kind}; - - let clock = time::create_clock(); - - // Create I/O driver - let (io_driver, io_handle) = io::create_driver(self.enable_io)?; - - let (driver, time_handle) = time::create_driver(self.enable_time, io_driver, clock.clone()); - - // And now put a single-threaded scheduler on top of the timer. When - // there are no futures ready to do something, it'll let the timer or - // the reactor to generate some new stimuli for the futures to continue - // in their life. - let scheduler = BasicScheduler::new(driver); - let spawner = Spawner::Basic(scheduler.spawner().clone()); - - // Blocking pool - let blocking_pool = blocking::create_blocking_pool(self, self.max_threads); - let blocking_spawner = blocking_pool.spawner().clone(); - - Ok(Runtime { - kind: Kind::Basic(scheduler), - handle: Handle { - spawner, - io_handle, - time_handle, - clock, - blocking_spawner, - }, - blocking_pool, - }) - } - } -} - -cfg_rt_threaded! { - impl Builder { - /// Sets runtime to use a multi-threaded scheduler for executing tasks. - /// - /// See also [the module level documentation][1], which has a section on scheduler - /// types. - /// - /// [1]: index.html#runtime-configurations - pub fn threaded_scheduler(&mut self) -> &mut Self { - self.kind = Kind::ThreadPool; - self - } - fn build_threaded_runtime(&mut self) -> io::Result<Runtime> { use crate::loom::sys::num_cpus; use crate::runtime::{Kind, ThreadPool}; use crate::runtime::park::Parker; use std::cmp; - let core_threads = self.core_threads.unwrap_or_else(|| cmp::min(self.max_threads, num_cpus())); + let core_threads = self.worker_threads.unwrap_or_else(|| cmp::min(self.max_threads, num_cpus())); assert!(core_threads <= self.max_threads, "Core threads number cannot be above max limit"); - let clock = time::create_clock(); + let (driver, resources) = driver::Driver::new(self.get_cfg())?; - let (io_driver, io_handle) = io::create_driver(self.enable_io)?; - let (driver, time_handle) = time::create_driver(self.enable_time, io_driver, clock.clone()); let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver)); let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); @@ -483,14 +507,16 @@ cfg_rt_threaded! { // Create the runtime handle let handle = Handle { spawner, - io_handle, - time_handle, - clock, + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, blocking_spawner, }; // Spawn the thread pool workers - handle.enter(|| launch.launch()); + let _enter = crate::runtime::context::enter(handle.clone()); + launch.launch(); Ok(Runtime { kind: Kind::ThreadPool(scheduler), @@ -501,19 +527,15 @@ cfg_rt_threaded! { } } -impl Default for Builder { - fn default() -> Self { - Self::new() - } -} - impl fmt::Debug for Builder { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Builder") - .field("kind", &self.kind) - .field("core_threads", &self.core_threads) + .field("worker_threads", &self.worker_threads) .field("max_threads", &self.max_threads) - .field("thread_name", &self.thread_name) + .field( + "thread_name", + &"<dyn Fn() -> String + Send + Sync + 'static>", + ) .field("thread_stack_size", &self.thread_stack_size) .field("after_start", &self.after_start.as_ref().map(|_| "...")) .field("before_stop", &self.after_start.as_ref().map(|_| "...")) diff --git a/src/runtime/context.rs b/src/runtime/context.rs index 1b267f4..0817019 100644 --- a/src/runtime/context.rs +++ b/src/runtime/context.rs @@ -12,7 +12,7 @@ pub(crate) fn current() -> Option<Handle> { } cfg_io_driver! { - pub(crate) fn io_handle() -> crate::runtime::io::Handle { + pub(crate) fn io_handle() -> crate::runtime::driver::IoHandle { CONTEXT.with(|ctx| match *ctx.borrow() { Some(ref ctx) => ctx.io_handle.clone(), None => Default::default(), @@ -20,8 +20,18 @@ cfg_io_driver! { } } +cfg_signal_internal! { + #[cfg(unix)] + pub(crate) fn signal_handle() -> crate::runtime::driver::SignalHandle { + CONTEXT.with(|ctx| match *ctx.borrow() { + Some(ref ctx) => ctx.signal_handle.clone(), + None => Default::default(), + }) + } +} + cfg_time! { - pub(crate) fn time_handle() -> crate::runtime::time::Handle { + pub(crate) fn time_handle() -> crate::runtime::driver::TimeHandle { CONTEXT.with(|ctx| match *ctx.borrow() { Some(ref ctx) => ctx.time_handle.clone(), None => Default::default(), @@ -29,7 +39,7 @@ cfg_time! { } cfg_test_util! { - pub(crate) fn clock() -> Option<crate::runtime::time::Clock> { + pub(crate) fn clock() -> Option<crate::runtime::driver::Clock> { CONTEXT.with(|ctx| match *ctx.borrow() { Some(ref ctx) => Some(ctx.clock.clone()), None => None, @@ -38,7 +48,7 @@ cfg_time! { } } -cfg_rt_core! { +cfg_rt! { pub(crate) fn spawn_handle() -> Option<crate::runtime::Spawner> { CONTEXT.with(|ctx| match *ctx.borrow() { Some(ref ctx) => Some(ctx.spawner.clone()), @@ -50,24 +60,20 @@ cfg_rt_core! { /// Set this [`Handle`] as the current active [`Handle`]. /// /// [`Handle`]: Handle -pub(crate) fn enter<F, R>(new: Handle, f: F) -> R -where - F: FnOnce() -> R, -{ - struct DropGuard(Option<Handle>); - - impl Drop for DropGuard { - fn drop(&mut self) { - CONTEXT.with(|ctx| { - *ctx.borrow_mut() = self.0.take(); - }); - } - } - - let _guard = CONTEXT.with(|ctx| { +pub(crate) fn enter(new: Handle) -> EnterGuard { + CONTEXT.with(|ctx| { let old = ctx.borrow_mut().replace(new); - DropGuard(old) - }); + EnterGuard(old) + }) +} + +#[derive(Debug)] +pub(crate) struct EnterGuard(Option<Handle>); - f() +impl Drop for EnterGuard { + fn drop(&mut self) { + CONTEXT.with(|ctx| { + *ctx.borrow_mut() = self.0.take(); + }); + } } diff --git a/src/runtime/driver.rs b/src/runtime/driver.rs new file mode 100644 index 0000000..e89de9d --- /dev/null +++ b/src/runtime/driver.rs @@ -0,0 +1,205 @@ +//! Abstracts out the entire chain of runtime sub-drivers into common types. +use crate::park::thread::ParkThread; +use crate::park::Park; + +use std::io; +use std::time::Duration; + +// ===== io driver ===== + +cfg_io_driver! { + type IoDriver = crate::io::driver::Driver; + type IoStack = crate::park::either::Either<ProcessDriver, ParkThread>; + pub(crate) type IoHandle = Option<crate::io::driver::Handle>; + + fn create_io_stack(enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + use crate::park::either::Either; + + #[cfg(loom)] + assert!(!enabled); + + let ret = if enabled { + let io_driver = crate::io::driver::Driver::new()?; + let io_handle = io_driver.handle(); + + let (signal_driver, signal_handle) = create_signal_driver(io_driver)?; + let process_driver = create_process_driver(signal_driver)?; + + (Either::A(process_driver), Some(io_handle), signal_handle) + } else { + (Either::B(ParkThread::new()), Default::default(), Default::default()) + }; + + Ok(ret) + } +} + +cfg_not_io_driver! { + pub(crate) type IoHandle = (); + type IoStack = ParkThread; + + fn create_io_stack(_enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + Ok((ParkThread::new(), Default::default(), Default::default())) + } +} + +// ===== signal driver ===== + +macro_rules! cfg_signal_internal_and_unix { + ($($item:item)*) => { + #[cfg(unix)] + cfg_signal_internal! { $($item)* } + } +} + +cfg_signal_internal_and_unix! { + type SignalDriver = crate::signal::unix::driver::Driver; + pub(crate) type SignalHandle = Option<crate::signal::unix::driver::Handle>; + + fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> { + let driver = crate::signal::unix::driver::Driver::new(io_driver)?; + let handle = driver.handle(); + Ok((driver, Some(handle))) + } +} + +cfg_not_signal_internal! { + pub(crate) type SignalHandle = (); + + cfg_io_driver! { + type SignalDriver = IoDriver; + + fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> { + Ok((io_driver, ())) + } + } +} + +// ===== process driver ===== + +cfg_process_driver! { + type ProcessDriver = crate::process::unix::driver::Driver; + + fn create_process_driver(signal_driver: SignalDriver) -> io::Result<ProcessDriver> { + crate::process::unix::driver::Driver::new(signal_driver) + } +} + +cfg_not_process_driver! { + cfg_io_driver! { + type ProcessDriver = SignalDriver; + + fn create_process_driver(signal_driver: SignalDriver) -> io::Result<ProcessDriver> { + Ok(signal_driver) + } + } +} + +// ===== time driver ===== + +cfg_time! { + type TimeDriver = crate::park::either::Either<crate::time::driver::Driver<IoStack>, IoStack>; + + pub(crate) type Clock = crate::time::Clock; + pub(crate) type TimeHandle = Option<crate::time::driver::Handle>; + + fn create_clock() -> Clock { + crate::time::Clock::new() + } + + fn create_time_driver( + enable: bool, + io_stack: IoStack, + clock: Clock, + ) -> (TimeDriver, TimeHandle) { + use crate::park::either::Either; + + if enable { + let driver = crate::time::driver::Driver::new(io_stack, clock); + let handle = driver.handle(); + + (Either::A(driver), Some(handle)) + } else { + (Either::B(io_stack), None) + } + } +} + +cfg_not_time! { + type TimeDriver = IoStack; + + pub(crate) type Clock = (); + pub(crate) type TimeHandle = (); + + fn create_clock() -> Clock { + () + } + + fn create_time_driver( + _enable: bool, + io_stack: IoStack, + _clock: Clock, + ) -> (TimeDriver, TimeHandle) { + (io_stack, ()) + } +} + +// ===== runtime driver ===== + +#[derive(Debug)] +pub(crate) struct Driver { + inner: TimeDriver, +} + +pub(crate) struct Resources { + pub(crate) io_handle: IoHandle, + pub(crate) signal_handle: SignalHandle, + pub(crate) time_handle: TimeHandle, + pub(crate) clock: Clock, +} + +pub(crate) struct Cfg { + pub(crate) enable_io: bool, + pub(crate) enable_time: bool, +} + +impl Driver { + pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Resources)> { + let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io)?; + + let clock = create_clock(); + let (time_driver, time_handle) = + create_time_driver(cfg.enable_time, io_stack, clock.clone()); + + Ok(( + Self { inner: time_driver }, + Resources { + io_handle, + signal_handle, + time_handle, + clock, + }, + )) + } +} + +impl Park for Driver { + type Unpark = <TimeDriver as Park>::Unpark; + type Error = <TimeDriver as Park>::Error; + + fn unpark(&self) -> Self::Unpark { + self.inner.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.inner.park() + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.inner.park_timeout(duration) + } + + fn shutdown(&mut self) { + self.inner.shutdown() + } +} diff --git a/src/runtime/enter.rs b/src/runtime/enter.rs index 56a7c57..4dd8dd0 100644 --- a/src/runtime/enter.rs +++ b/src/runtime/enter.rs @@ -4,8 +4,8 @@ use std::marker::PhantomData; #[derive(Debug, Clone, Copy)] pub(crate) enum EnterContext { + #[cfg_attr(not(feature = "rt"), allow(dead_code))] Entered { - #[allow(dead_code)] allow_blocking: bool, }, NotEntered, @@ -13,11 +13,7 @@ pub(crate) enum EnterContext { impl EnterContext { pub(crate) fn is_entered(self) -> bool { - if let EnterContext::Entered { .. } = self { - true - } else { - false - } + matches!(self, EnterContext::Entered { .. }) } } @@ -28,32 +24,38 @@ pub(crate) struct Enter { _p: PhantomData<RefCell<()>>, } -/// Marks the current thread as being within the dynamic extent of an -/// executor. -pub(crate) fn enter(allow_blocking: bool) -> Enter { - if let Some(enter) = try_enter(allow_blocking) { - return enter; - } +cfg_rt! { + use crate::park::thread::ParkError; - panic!( - "Cannot start a runtime from within a runtime. This happens \ - because a function (like `block_on`) attempted to block the \ - current thread while the thread is being used to drive \ - asynchronous tasks." - ); -} + use std::time::Duration; -/// Tries to enter a runtime context, returns `None` if already in a runtime -/// context. -pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> { - ENTERED.with(|c| { - if c.get().is_entered() { - None - } else { - c.set(EnterContext::Entered { allow_blocking }); - Some(Enter { _p: PhantomData }) + /// Marks the current thread as being within the dynamic extent of an + /// executor. + pub(crate) fn enter(allow_blocking: bool) -> Enter { + if let Some(enter) = try_enter(allow_blocking) { + return enter; } - }) + + panic!( + "Cannot start a runtime from within a runtime. This happens \ + because a function (like `block_on`) attempted to block the \ + current thread while the thread is being used to drive \ + asynchronous tasks." + ); + } + + /// Tries to enter a runtime context, returns `None` if already in a runtime + /// context. + pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> { + ENTERED.with(|c| { + if c.get().is_entered() { + None + } else { + c.set(EnterContext::Entered { allow_blocking }); + Some(Enter { _p: PhantomData }) + } + }) + } } // Forces the current "entered" state to be cleared while the closure @@ -63,115 +65,92 @@ pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> { // // This is hidden for a reason. Do not use without fully understanding // executors. Misuing can easily cause your program to deadlock. -#[cfg(all(feature = "rt-threaded", feature = "blocking"))] -pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R { - // Reset in case the closure panics - struct Reset(EnterContext); - impl Drop for Reset { - fn drop(&mut self) { - ENTERED.with(|c| { - assert!(!c.get().is_entered(), "closure claimed permanent executor"); - c.set(self.0); - }); +cfg_rt_multi_thread! { + pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R { + // Reset in case the closure panics + struct Reset(EnterContext); + impl Drop for Reset { + fn drop(&mut self) { + ENTERED.with(|c| { + assert!(!c.get().is_entered(), "closure claimed permanent executor"); + c.set(self.0); + }); + } } - } - let was = ENTERED.with(|c| { - let e = c.get(); - assert!(e.is_entered(), "asked to exit when not entered"); - c.set(EnterContext::NotEntered); - e - }); + let was = ENTERED.with(|c| { + let e = c.get(); + assert!(e.is_entered(), "asked to exit when not entered"); + c.set(EnterContext::NotEntered); + e + }); - let _reset = Reset(was); - // dropping _reset after f() will reset ENTERED - f() + let _reset = Reset(was); + // dropping _reset after f() will reset ENTERED + f() + } } -cfg_rt_core! { - cfg_rt_util! { - /// Disallow blocking in the current runtime context until the guard is dropped. - pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { - let reset = ENTERED.with(|c| { - if let EnterContext::Entered { - allow_blocking: true, - } = c.get() - { - c.set(EnterContext::Entered { - allow_blocking: false, - }); - true - } else { - false - } - }); - DisallowBlockingGuard(reset) - } +cfg_rt! { + /// Disallow blocking in the current runtime context until the guard is dropped. + pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { + let reset = ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: true, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: false, + }); + true + } else { + false + } + }); + DisallowBlockingGuard(reset) + } - pub(crate) struct DisallowBlockingGuard(bool); - impl Drop for DisallowBlockingGuard { - fn drop(&mut self) { - if self.0 { - // XXX: Do we want some kind of assertion here, or is "best effort" okay? - ENTERED.with(|c| { - if let EnterContext::Entered { - allow_blocking: false, - } = c.get() - { - c.set(EnterContext::Entered { - allow_blocking: true, - }); - } - }) - } + pub(crate) struct DisallowBlockingGuard(bool); + impl Drop for DisallowBlockingGuard { + fn drop(&mut self) { + if self.0 { + // XXX: Do we want some kind of assertion here, or is "best effort" okay? + ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: false, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: true, + }); + } + }) } } } } -cfg_rt_threaded! { - cfg_blocking! { - /// Returns true if in a runtime context. - pub(crate) fn context() -> EnterContext { - ENTERED.with(|c| c.get()) - } +cfg_rt_multi_thread! { + /// Returns true if in a runtime context. + pub(crate) fn context() -> EnterContext { + ENTERED.with(|c| c.get()) } } -cfg_block_on! { +cfg_rt! { impl Enter { /// Blocks the thread on the specified future, returning the value with /// which that future completes. - pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, crate::park::ParkError> + pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, ParkError> where F: std::future::Future, { - use crate::park::{CachedParkThread, Park}; - use std::task::Context; - use std::task::Poll::Ready; + use crate::park::thread::CachedParkThread; let mut park = CachedParkThread::new(); - let waker = park.get_unpark()?.into_waker(); - let mut cx = Context::from_waker(&waker); - - pin!(f); - - loop { - if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { - return Ok(v); - } - - park.park()?; - } + park.block_on(f) } - } -} -cfg_blocking_impl! { - use crate::park::ParkError; - use std::time::Duration; - - impl Enter { /// Blocks the thread on the specified future for **at most** `timeout` /// /// If the future completes before `timeout`, the result is returned. If @@ -180,7 +159,8 @@ cfg_blocking_impl! { where F: std::future::Future, { - use crate::park::{CachedParkThread, Park}; + use crate::park::Park; + use crate::park::thread::CachedParkThread; use std::task::Context; use std::task::Poll::Ready; use std::time::Instant; diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs index 0716a7f..b1e8d8f 100644 --- a/src/runtime/handle.rs +++ b/src/runtime/handle.rs @@ -1,16 +1,6 @@ -use crate::runtime::{blocking, context, io, time, Spawner}; -use std::{error, fmt}; - -cfg_blocking! { - use crate::runtime::task; - use crate::runtime::blocking::task::BlockingTask; -} - -cfg_rt_core! { - use crate::task::JoinHandle; - - use std::future::Future; -} +use crate::runtime::blocking::task::BlockingTask; +use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::{blocking, driver, Spawner}; /// Handle to the runtime. /// @@ -19,353 +9,56 @@ cfg_rt_core! { /// /// [`Runtime::handle`]: crate::runtime::Runtime::handle() #[derive(Debug, Clone)] -pub struct Handle { +pub(crate) struct Handle { pub(super) spawner: Spawner, /// Handles to the I/O drivers - pub(super) io_handle: io::Handle, + pub(super) io_handle: driver::IoHandle, + + /// Handles to the signal drivers + pub(super) signal_handle: driver::SignalHandle, /// Handles to the time drivers - pub(super) time_handle: time::Handle, + pub(super) time_handle: driver::TimeHandle, /// Source of `Instant::now()` - pub(super) clock: time::Clock, + pub(super) clock: driver::Clock, /// Blocking pool spawner pub(super) blocking_spawner: blocking::Spawner, } impl Handle { - /// Enter the runtime context. This allows you to construct types that must - /// have an executor available on creation such as [`Delay`] or [`TcpStream`]. - /// It will also allow you to call methods such as [`tokio::spawn`]. - /// - /// This function is also available as [`Runtime::enter`]. - /// - /// [`Delay`]: struct@crate::time::Delay - /// [`TcpStream`]: struct@crate::net::TcpStream - /// [`Runtime::enter`]: fn@crate::runtime::Runtime::enter - /// [`tokio::spawn`]: fn@crate::spawn - /// - /// # Example - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// fn function_that_spawns(msg: String) { - /// // Had we not used `handle.enter` below, this would panic. - /// tokio::spawn(async move { - /// println!("{}", msg); - /// }); - /// } - /// - /// fn main() { - /// let rt = Runtime::new().unwrap(); - /// let handle = rt.handle().clone(); - /// - /// let s = "Hello World!".to_string(); - /// - /// // By entering the context, we tie `tokio::spawn` to this executor. - /// handle.enter(|| function_that_spawns(s)); - /// } - /// ``` - pub fn enter<F, R>(&self, f: F) -> R + // /// Enter the runtime context. This allows you to construct types that must + // /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. + // /// It will also allow you to call methods such as [`tokio::spawn`]. + // pub(crate) fn enter<F, R>(&self, f: F) -> R + // where + // F: FnOnce() -> R, + // { + // context::enter(self.clone(), f) + // } + + /// Run the provided function on an executor dedicated to blocking operations. + pub(crate) fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> where - F: FnOnce() -> R, + F: FnOnce() -> R + Send + 'static, { - context::enter(self.clone(), f) - } - - /// Returns a `Handle` view over the currently running `Runtime` - /// - /// # Panic - /// - /// This will panic if called outside the context of a Tokio runtime. That means that you must - /// call this on one of the threads **being run by the runtime**. Calling this from within a - /// thread created by `std::thread::spawn` (for example) will cause a panic. - /// - /// # Examples - /// - /// This can be used to obtain the handle of the surrounding runtime from an async - /// block or function running on that runtime. - /// - /// ``` - /// # use std::thread; - /// # use tokio::runtime::Runtime; - /// # fn dox() { - /// # let rt = Runtime::new().unwrap(); - /// # rt.spawn(async { - /// use tokio::runtime::Handle; - /// - /// // Inside an async block or function. - /// let handle = Handle::current(); - /// handle.spawn(async { - /// println!("now running in the existing Runtime"); - /// }); - /// - /// # let handle = - /// thread::spawn(move || { - /// // Notice that the handle is created outside of this thread and then moved in - /// handle.block_on(async { /* ... */ }) - /// // This next line would cause a panic - /// // let handle2 = Handle::current(); - /// }); - /// # handle.join().unwrap(); - /// # }); - /// # } - /// ``` - pub fn current() -> Self { - context::current().expect("not currently running on the Tokio runtime.") - } - - /// Returns a Handle view over the currently running Runtime - /// - /// Returns an error if no Runtime has been started - /// - /// Contrary to `current`, this never panics - pub fn try_current() -> Result<Self, TryCurrentError> { - context::current().ok_or(TryCurrentError(())) + #[cfg(feature = "tracing")] + let func = { + let span = tracing::trace_span!( + target: "tokio::task", + "task", + kind = %"blocking", + function = %std::any::type_name::<F>(), + ); + move || { + let _g = span.enter(); + func() + } + }; + let (task, handle) = task::joinable(BlockingTask::new(func)); + let _ = self.blocking_spawner.spawn(task, &self); + handle } } - -cfg_rt_core! { - impl Handle { - /// Spawns a future onto the Tokio runtime. - /// - /// This spawns the given future onto the runtime's executor, usually a - /// thread pool. The thread pool is then responsible for polling the future - /// until it completes. - /// - /// See [module level][mod] documentation for more details. - /// - /// [mod]: index.html - /// - /// # Examples - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// # fn dox() { - /// // Create the runtime - /// let rt = Runtime::new().unwrap(); - /// let handle = rt.handle(); - /// - /// // Spawn a future onto the runtime - /// handle.spawn(async { - /// println!("now running on a worker thread"); - /// }); - /// # } - /// ``` - /// - /// # Panics - /// - /// This function will not panic unless task execution is disabled on the - /// executor. This can only happen if the runtime was built using - /// [`Builder`] without picking either [`basic_scheduler`] or - /// [`threaded_scheduler`]. - /// - /// [`Builder`]: struct@crate::runtime::Builder - /// [`threaded_scheduler`]: fn@crate::runtime::Builder::threaded_scheduler - /// [`basic_scheduler`]: fn@crate::runtime::Builder::basic_scheduler - pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - self.spawner.spawn(future) - } - - /// Run a future to completion on the Tokio runtime from a synchronous - /// context. - /// - /// This runs the given future on the runtime, blocking until it is - /// complete, and yielding its resolved result. Any tasks or timers which - /// the future spawns internally will be executed on the runtime. - /// - /// If the provided executor currently has no active core thread, this - /// function might hang until a core thread is added. This is not a - /// concern when using the [threaded scheduler], as it always has active - /// core threads, but if you use the [basic scheduler], some other - /// thread must currently be inside a call to [`Runtime::block_on`]. - /// See also [the module level documentation][1], which has a section on - /// scheduler types. - /// - /// This method may not be called from an asynchronous context. - /// - /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler - /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler - /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on - /// [1]: index.html#runtime-configurations - /// - /// # Panics - /// - /// This function panics if the provided future panics, or if called - /// within an asynchronous execution context. - /// - /// # Examples - /// - /// Using `block_on` with the [threaded scheduler]. - /// - /// ``` - /// use tokio::runtime::Runtime; - /// use std::thread; - /// - /// // Create the runtime. - /// // - /// // If the rt-threaded feature is enabled, this creates a threaded - /// // scheduler by default. - /// let rt = Runtime::new().unwrap(); - /// let handle = rt.handle().clone(); - /// - /// // Use the runtime from another thread. - /// let th = thread::spawn(move || { - /// // Execute the future, blocking the current thread until completion. - /// // - /// // This example uses the threaded scheduler, so no concurrent call to - /// // `rt.block_on` is required. - /// handle.block_on(async { - /// println!("hello"); - /// }); - /// }); - /// - /// th.join().unwrap(); - /// ``` - /// - /// Using the [basic scheduler] requires a concurrent call to - /// [`Runtime::block_on`]: - /// - /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler - /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler - /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on - /// - /// ``` - /// use tokio::runtime::Builder; - /// use tokio::sync::oneshot; - /// use std::thread; - /// - /// // Create the runtime. - /// let mut rt = Builder::new() - /// .enable_all() - /// .basic_scheduler() - /// .build() - /// .unwrap(); - /// - /// let handle = rt.handle().clone(); - /// - /// // Signal main thread when task has finished. - /// let (send, recv) = oneshot::channel(); - /// - /// // Use the runtime from another thread. - /// let th = thread::spawn(move || { - /// // Execute the future, blocking the current thread until completion. - /// handle.block_on(async { - /// send.send("done").unwrap(); - /// }); - /// }); - /// - /// // The basic scheduler is used, so the thread above might hang if we - /// // didn't call block_on on the rt too. - /// rt.block_on(async { - /// assert_eq!(recv.await.unwrap(), "done"); - /// }); - /// # th.join().unwrap(); - /// ``` - /// - pub fn block_on<F: Future>(&self, future: F) -> F::Output { - self.enter(|| { - let mut enter = crate::runtime::enter(true); - enter.block_on(future).expect("failed to park thread") - }) - } - } -} - -cfg_blocking! { - impl Handle { - /// Runs the provided closure on a thread where blocking is acceptable. - /// - /// In general, issuing a blocking call or performing a lot of compute in a - /// future without yielding is not okay, as it may prevent the executor from - /// driving other futures forward. This function runs the provided closure - /// on a thread dedicated to blocking operations. See the [CPU-bound tasks - /// and blocking code][blocking] section for more information. - /// - /// Tokio will spawn more blocking threads when they are requested through - /// this function until the upper limit configured on the [`Builder`] is - /// reached. This limit is very large by default, because `spawn_blocking` is - /// often used for various kinds of IO operations that cannot be performed - /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you - /// should keep this large upper limit in mind; to run your CPU-bound - /// computations on only a few threads, you should use a separate thread - /// pool such as [rayon] rather than configuring the number of blocking - /// threads. - /// - /// This function is intended for non-async operations that eventually - /// finish on their own. If you want to spawn an ordinary thread, you should - /// use [`thread::spawn`] instead. - /// - /// Closures spawned using `spawn_blocking` cannot be cancelled. When you - /// shut down the executor, it will wait indefinitely for all blocking - /// operations to finish. You can use [`shutdown_timeout`] to stop waiting - /// for them after a certain timeout. Be aware that this will still not - /// cancel the tasks — they are simply allowed to keep running after the - /// method returns. - /// - /// Note that if you are using the [basic scheduler], this function will - /// still spawn additional threads for blocking operations. The basic - /// scheduler's single thread is only used for asynchronous code. - /// - /// [`Builder`]: struct@crate::runtime::Builder - /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code - /// [rayon]: https://docs.rs/rayon - /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler - /// [`thread::spawn`]: fn@std::thread::spawn - /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout - /// - /// # Examples - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ - /// // Create the runtime - /// let rt = Runtime::new().unwrap(); - /// let handle = rt.handle(); - /// - /// let res = handle.spawn_blocking(move || { - /// // do some compute-heavy work or call synchronous code - /// "done computing" - /// }).await?; - /// - /// assert_eq!(res, "done computing"); - /// # Ok(()) - /// # } - /// ``` - pub fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R> - where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, - { - let (task, handle) = task::joinable(BlockingTask::new(f)); - let _ = self.blocking_spawner.spawn(task, self); - handle - } - } -} - -/// Error returned by `try_current` when no Runtime has been started -pub struct TryCurrentError(()); - -impl fmt::Debug for TryCurrentError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TryCurrentError").finish() - } -} - -impl fmt::Display for TryCurrentError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("no tokio Runtime has been initialized") - } -} - -impl error::Error for TryCurrentError {} diff --git a/src/runtime/io.rs b/src/runtime/io.rs deleted file mode 100644 index 6a0953a..0000000 --- a/src/runtime/io.rs +++ /dev/null @@ -1,63 +0,0 @@ -//! Abstracts out the APIs necessary to `Runtime` for integrating the I/O -//! driver. When the `time` feature flag is **not** enabled. These APIs are -//! shells. This isolates the complexity of dealing with conditional -//! compilation. - -/// Re-exported for convenience. -pub(crate) use std::io::Result; - -pub(crate) use variant::*; - -#[cfg(feature = "io-driver")] -mod variant { - use crate::io::driver; - use crate::park::{Either, ParkThread}; - - use std::io; - - /// The driver value the runtime passes to the `timer` layer. - /// - /// When the `io-driver` feature is enabled, this is the "real" I/O driver - /// backed by Mio. Without the `io-driver` feature, this is a thread parker - /// backed by a condition variable. - pub(crate) type Driver = Either<driver::Driver, ParkThread>; - - /// The handle the runtime stores for future use. - /// - /// When the `io-driver` feature is **not** enabled, this is `()`. - pub(crate) type Handle = Option<driver::Handle>; - - pub(crate) fn create_driver(enable: bool) -> io::Result<(Driver, Handle)> { - #[cfg(loom)] - assert!(!enable); - - if enable { - let driver = driver::Driver::new()?; - let handle = driver.handle(); - - Ok((Either::A(driver), Some(handle))) - } else { - let driver = ParkThread::new(); - Ok((Either::B(driver), None)) - } - } -} - -#[cfg(not(feature = "io-driver"))] -mod variant { - use crate::park::ParkThread; - - use std::io; - - /// I/O is not enabled, use a condition variable based parker - pub(crate) type Driver = ParkThread; - - /// There is no handle - pub(crate) type Handle = (); - - pub(crate) fn create_driver(_enable: bool) -> io::Result<(Driver, Handle)> { - let driver = ParkThread::new(); - - Ok((driver, ())) - } -} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 300a146..be4aa38 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -1,8 +1,7 @@ //! The Tokio runtime. //! -//! Unlike other Rust programs, asynchronous applications require -//! runtime support. In particular, the following runtime services are -//! necessary: +//! Unlike other Rust programs, asynchronous applications require runtime +//! support. In particular, the following runtime services are necessary: //! //! * An **I/O event loop**, called the driver, which drives I/O resources and //! dispatches I/O events to tasks that depend on them. @@ -10,14 +9,14 @@ //! * A **timer** for scheduling work to run after a set period of time. //! //! Tokio's [`Runtime`] bundles all of these services as a single type, allowing -//! them to be started, shut down, and configured together. However, most -//! applications won't need to use [`Runtime`] directly. Instead, they can -//! use the [`tokio::main`] attribute macro, which creates a [`Runtime`] under -//! the hood. +//! them to be started, shut down, and configured together. However, often it is +//! not required to configure a [`Runtime`] manually, and user may just use the +//! [`tokio::main`] attribute macro, which creates a [`Runtime`] under the hood. //! //! # Usage //! -//! Most applications will use the [`tokio::main`] attribute macro. +//! When no fine tuning is required, the [`tokio::main`] attribute macro can be +//! used. //! //! ```no_run //! use tokio::net::TcpListener; @@ -25,7 +24,7 @@ //! //! #[tokio::main] //! async fn main() -> Result<(), Box<dyn std::error::Error>> { -//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?; +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; //! //! loop { //! let (mut socket, _) = listener.accept().await?; @@ -69,11 +68,11 @@ //! //! fn main() -> Result<(), Box<dyn std::error::Error>> { //! // Create the runtime -//! let mut rt = Runtime::new()?; +//! let rt = Runtime::new()?; //! //! // Spawn the root task //! rt.block_on(async { -//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?; +//! let listener = TcpListener::bind("127.0.0.1:8080").await?; //! //! loop { //! let (mut socket, _) = listener.accept().await?; @@ -111,48 +110,38 @@ //! applications. The [runtime builder] or `#[tokio::main]` attribute may be //! used to select which scheduler to use. //! -//! #### Basic Scheduler +//! #### Multi-Thread Scheduler //! -//! The basic scheduler provides a _single-threaded_ future executor. All tasks -//! will be created and executed on the current thread. The basic scheduler -//! requires the `rt-core` feature flag, and can be selected using the -//! [`Builder::basic_scheduler`] method: +//! The multi-thread scheduler executes futures on a _thread pool_, using a +//! work-stealing strategy. By default, it will start a worker thread for each +//! CPU core available on the system. This tends to be the ideal configurations +//! for most applications. The multi-thread scheduler requires the `rt-multi-thread` +//! feature flag, and is selected by default: //! ``` //! use tokio::runtime; //! //! # fn main() -> Result<(), Box<dyn std::error::Error>> { -//! let basic_rt = runtime::Builder::new() -//! .basic_scheduler() -//! .build()?; +//! let threaded_rt = runtime::Runtime::new()?; //! # Ok(()) } //! ``` //! -//! If the `rt-core` feature is enabled and `rt-threaded` is not, -//! [`Runtime::new`] will return a basic scheduler runtime by default. +//! Most applications should use the multi-thread scheduler, except in some +//! niche use-cases, such as when running only a single thread is required. //! -//! #### Threaded Scheduler +//! #### Current-Thread Scheduler //! -//! The threaded scheduler executes futures on a _thread pool_, using a -//! work-stealing strategy. By default, it will start a worker thread for each -//! CPU core available on the system. This tends to be the ideal configurations -//! for most applications. The threaded scheduler requires the `rt-threaded` feature -//! flag, and can be selected using the [`Builder::threaded_scheduler`] method: +//! The current-thread scheduler provides a _single-threaded_ future executor. +//! All tasks will be created and executed on the current thread. This requires +//! the `rt` feature flag. //! ``` //! use tokio::runtime; //! //! # fn main() -> Result<(), Box<dyn std::error::Error>> { -//! let threaded_rt = runtime::Builder::new() -//! .threaded_scheduler() +//! let basic_rt = runtime::Builder::new_current_thread() //! .build()?; //! # Ok(()) } //! ``` //! -//! If the `rt-threaded` feature flag is enabled, [`Runtime::new`] will return a -//! threaded scheduler runtime by default. -//! -//! Most applications should use the threaded scheduler, except in some niche -//! use-cases, such as when running only a single thread is required. -//! //! #### Resource drivers //! //! When configuring a runtime by hand, no resource drivers are enabled by @@ -164,8 +153,8 @@ //! ## Lifetime of spawned threads //! //! The runtime may spawn threads depending on its configuration and usage. The -//! threaded scheduler spawns threads to schedule tasks and calls to -//! `spawn_blocking` spawn threads to run blocking operations. +//! multi-thread scheduler spawns threads to schedule tasks and for `spawn_blocking` +//! calls. //! //! While the `Runtime` is active, threads may shutdown after periods of being //! idle. Once `Runtime` is dropped, all runtime threads are forcibly shutdown. @@ -188,394 +177,380 @@ #[macro_use] mod tests; -pub(crate) mod context; +pub(crate) mod enter; + +pub(crate) mod task; -cfg_rt_core! { +cfg_rt! { mod basic_scheduler; use basic_scheduler::BasicScheduler; - pub(crate) mod task; -} - -mod blocking; -use blocking::BlockingPool; + mod blocking; + use blocking::BlockingPool; + pub(crate) use blocking::spawn_blocking; -cfg_blocking_impl! { - #[allow(unused_imports)] - pub(crate) use blocking::{spawn_blocking, try_spawn_blocking}; -} + mod builder; + pub use self::builder::Builder; -mod builder; -pub use self::builder::Builder; + pub(crate) mod context; + pub(crate) mod driver; -pub(crate) mod enter; -use self::enter::enter; + use self::enter::enter; -mod handle; -pub use self::handle::{Handle, TryCurrentError}; + mod handle; + use handle::Handle; -mod io; + mod spawner; + use self::spawner::Spawner; +} -cfg_rt_threaded! { +cfg_rt_multi_thread! { mod park; use park::Parker; } -mod shell; -use self::shell::Shell; - -mod spawner; -use self::spawner::Spawner; - -mod time; - -cfg_rt_threaded! { +cfg_rt_multi_thread! { mod queue; pub(crate) mod thread_pool; use self::thread_pool::ThreadPool; } -cfg_rt_core! { +cfg_rt! { use crate::task::JoinHandle; -} - -use std::future::Future; -use std::time::Duration; - -/// The Tokio runtime. -/// -/// The runtime provides an I/O driver, task scheduler, [timer], and blocking -/// pool, necessary for running asynchronous tasks. -/// -/// Instances of `Runtime` can be created using [`new`] or [`Builder`]. However, -/// most users will use the `#[tokio::main]` annotation on their entry point instead. -/// -/// See [module level][mod] documentation for more details. -/// -/// # Shutdown -/// -/// Shutting down the runtime is done by dropping the value. The current thread -/// will block until the shut down operation has completed. -/// -/// * Drain any scheduled work queues. -/// * Drop any futures that have not yet completed. -/// * Drop the reactor. -/// -/// Once the reactor has dropped, any outstanding I/O resources bound to -/// that reactor will no longer function. Calling any method on them will -/// result in an error. -/// -/// [timer]: crate::time -/// [mod]: index.html -/// [`new`]: method@Self::new -/// [`Builder`]: struct@Builder -/// [`tokio::run`]: fn@run -#[derive(Debug)] -pub struct Runtime { - /// Task executor - kind: Kind, - - /// Handle to runtime, also contains driver handles - handle: Handle, - - /// Blocking pool handle, used to signal shutdown - blocking_pool: BlockingPool, -} - -/// The runtime executor is either a thread-pool or a current-thread executor. -#[derive(Debug)] -enum Kind { - /// Not able to execute concurrent tasks. This variant is mostly used to get - /// access to the driver handles. - Shell(Shell), - /// Execute all tasks on the current-thread. - #[cfg(feature = "rt-core")] - Basic(BasicScheduler<time::Driver>), + use std::future::Future; + use std::time::Duration; - /// Execute tasks across multiple threads. - #[cfg(feature = "rt-threaded")] - ThreadPool(ThreadPool), -} - -/// After thread starts / before thread stops -type Callback = std::sync::Arc<dyn Fn() + Send + Sync>; - -impl Runtime { - /// Create a new runtime instance with default configuration values. + /// The Tokio runtime. /// - /// This results in a scheduler, I/O driver, and time driver being - /// initialized. The type of scheduler used depends on what feature flags - /// are enabled: if the `rt-threaded` feature is enabled, the [threaded - /// scheduler] is used, while if only the `rt-core` feature is enabled, the - /// [basic scheduler] is used instead. + /// The runtime provides an I/O driver, task scheduler, [timer], and + /// blocking pool, necessary for running asynchronous tasks. /// - /// If the threaded scheduler is selected, it will not spawn - /// any worker threads until it needs to, i.e. tasks are scheduled to run. - /// - /// Most applications will not need to call this function directly. Instead, - /// they will use the [`#[tokio::main]` attribute][main]. When more complex - /// configuration is necessary, the [runtime builder] may be used. + /// Instances of `Runtime` can be created using [`new`], or [`Builder`]. + /// However, most users will use the `#[tokio::main]` annotation on their + /// entry point instead. /// /// See [module level][mod] documentation for more details. /// - /// # Examples - /// - /// Creating a new `Runtime` with default configuration values. + /// # Shutdown /// - /// ``` - /// use tokio::runtime::Runtime; + /// Shutting down the runtime is done by dropping the value. The current + /// thread will block until the shut down operation has completed. /// - /// let rt = Runtime::new() - /// .unwrap(); + /// * Drain any scheduled work queues. + /// * Drop any futures that have not yet completed. + /// * Drop the reactor. /// - /// // Use the runtime... - /// ``` + /// Once the reactor has dropped, any outstanding I/O resources bound to + /// that reactor will no longer function. Calling any method on them will + /// result in an error. /// - /// [mod]: index.html - /// [main]: ../attr.main.html - /// [threaded scheduler]: index.html#threaded-scheduler - /// [basic scheduler]: index.html#basic-scheduler - /// [runtime builder]: crate::runtime::Builder - pub fn new() -> io::Result<Runtime> { - #[cfg(feature = "rt-threaded")] - let ret = Builder::new().threaded_scheduler().enable_all().build(); - - #[cfg(all(not(feature = "rt-threaded"), feature = "rt-core"))] - let ret = Builder::new().basic_scheduler().enable_all().build(); - - #[cfg(not(feature = "rt-core"))] - let ret = Builder::new().enable_all().build(); - - ret - } - - /// Spawn a future onto the Tokio runtime. + /// # Sharing /// - /// This spawns the given future onto the runtime's executor, usually a - /// thread pool. The thread pool is then responsible for polling the future - /// until it completes. + /// The Tokio runtime implements `Sync` and `Send` to allow you to wrap it + /// in a `Arc`. Most fn take `&self` to allow you to call them concurrently + /// accross multiple threads. /// - /// See [module level][mod] documentation for more details. + /// Calls to `shutdown` and `shutdown_timeout` require exclusive ownership of + /// the runtime type and this can be achieved via `Arc::try_unwrap` when only + /// one strong count reference is left over. /// + /// [timer]: crate::time /// [mod]: index.html - /// - /// # Examples - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// # fn dox() { - /// // Create the runtime - /// let rt = Runtime::new().unwrap(); - /// - /// // Spawn a future onto the runtime - /// rt.spawn(async { - /// println!("now running on a worker thread"); - /// }); - /// # } - /// ``` - /// - /// # Panics - /// - /// This function will not panic unless task execution is disabled on the - /// executor. This can only happen if the runtime was built using - /// [`Builder`] without picking either [`basic_scheduler`] or - /// [`threaded_scheduler`]. - /// + /// [`new`]: method@Self::new /// [`Builder`]: struct@Builder - /// [`threaded_scheduler`]: fn@Builder::threaded_scheduler - /// [`basic_scheduler`]: fn@Builder::basic_scheduler - #[cfg(feature = "rt-core")] - pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - match &self.kind { - Kind::Shell(_) => panic!("task execution disabled"), - #[cfg(feature = "rt-threaded")] - Kind::ThreadPool(exec) => exec.spawn(future), - Kind::Basic(exec) => exec.spawn(future), - } - } + #[derive(Debug)] + pub struct Runtime { + /// Task executor + kind: Kind, - /// Run a future to completion on the Tokio runtime. This is the runtime's - /// entry point. - /// - /// This runs the given future on the runtime, blocking until it is - /// complete, and yielding its resolved result. Any tasks or timers which - /// the future spawns internally will be executed on the runtime. - /// - /// `&mut` is required as calling `block_on` **may** result in advancing the - /// state of the runtime. The details depend on how the runtime is - /// configured. [`runtime::Handle::block_on`][handle] provides a version - /// that takes `&self`. - /// - /// This method may not be called from an asynchronous context. - /// - /// # Panics - /// - /// This function panics if the provided future panics, or if called within an - /// asynchronous execution context. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::runtime::Runtime; - /// - /// // Create the runtime - /// let mut rt = Runtime::new().unwrap(); - /// - /// // Execute the future, blocking the current thread until completion - /// rt.block_on(async { - /// println!("hello"); - /// }); - /// ``` - /// - /// [handle]: fn@Handle::block_on - pub fn block_on<F: Future>(&mut self, future: F) -> F::Output { - let kind = &mut self.kind; + /// Handle to runtime, also contains driver handles + handle: Handle, - self.handle.enter(|| match kind { - Kind::Shell(exec) => exec.block_on(future), - #[cfg(feature = "rt-core")] - Kind::Basic(exec) => exec.block_on(future), - #[cfg(feature = "rt-threaded")] - Kind::ThreadPool(exec) => exec.block_on(future), - }) + /// Blocking pool handle, used to signal shutdown + blocking_pool: BlockingPool, } - /// Enter the runtime context. This allows you to construct types that must - /// have an executor available on creation such as [`Delay`] or [`TcpStream`]. - /// It will also allow you to call methods such as [`tokio::spawn`]. - /// - /// This function is also available as [`Handle::enter`]. - /// - /// [`Delay`]: struct@crate::time::Delay - /// [`TcpStream`]: struct@crate::net::TcpStream - /// [`Handle::enter`]: fn@crate::runtime::Handle::enter - /// [`tokio::spawn`]: fn@crate::spawn - /// - /// # Example - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// fn function_that_spawns(msg: String) { - /// // Had we not used `rt.enter` below, this would panic. - /// tokio::spawn(async move { - /// println!("{}", msg); - /// }); - /// } + /// Runtime context guard. /// - /// fn main() { - /// let rt = Runtime::new().unwrap(); - /// - /// let s = "Hello World!".to_string(); - /// - /// // By entering the context, we tie `tokio::spawn` to this executor. - /// rt.enter(|| function_that_spawns(s)); - /// } - /// ``` - pub fn enter<F, R>(&self, f: F) -> R - where - F: FnOnce() -> R, - { - self.handle.enter(f) + /// Returned by [`Runtime::enter`], the context guard exits the runtime + /// context on drop. + #[derive(Debug)] + pub struct EnterGuard<'a> { + rt: &'a Runtime, + guard: context::EnterGuard, } - /// Return a handle to the runtime's spawner. - /// - /// The returned handle can be used to spawn tasks that run on this runtime, and can - /// be cloned to allow moving the `Handle` to other threads. - /// - /// # Examples - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// let rt = Runtime::new() - /// .unwrap(); - /// - /// let handle = rt.handle(); - /// - /// handle.spawn(async { println!("hello"); }); - /// ``` - pub fn handle(&self) -> &Handle { - &self.handle - } + /// The runtime executor is either a thread-pool or a current-thread executor. + #[derive(Debug)] + enum Kind { + /// Execute all tasks on the current-thread. + CurrentThread(BasicScheduler<driver::Driver>), - /// Shutdown the runtime, waiting for at most `duration` for all spawned - /// task to shutdown. - /// - /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to - /// shutdown in a timely fashion. However, dropping a `Runtime` will wait - /// indefinitely for all tasks to terminate, and there are cases where a long - /// blocking task has been spawned, which can block dropping `Runtime`. - /// - /// In this case, calling `shutdown_timeout` with an explicit wait timeout - /// can work. The `shutdown_timeout` will signal all tasks to shutdown and - /// will wait for at most `duration` for all spawned tasks to terminate. If - /// `timeout` elapses before all tasks are dropped, the function returns and - /// outstanding tasks are potentially leaked. - /// - /// # Examples - /// - /// ``` - /// use tokio::runtime::Runtime; - /// use tokio::task; - /// - /// use std::thread; - /// use std::time::Duration; - /// - /// fn main() { - /// let mut runtime = Runtime::new().unwrap(); - /// - /// runtime.block_on(async move { - /// task::spawn_blocking(move || { - /// thread::sleep(Duration::from_secs(10_000)); - /// }); - /// }); - /// - /// runtime.shutdown_timeout(Duration::from_millis(100)); - /// } - /// ``` - pub fn shutdown_timeout(self, duration: Duration) { - let Runtime { - mut blocking_pool, .. - } = self; - blocking_pool.shutdown(Some(duration)); + /// Execute tasks across multiple threads. + #[cfg(feature = "rt-multi-thread")] + ThreadPool(ThreadPool), } - /// Shutdown the runtime, without waiting for any spawned tasks to shutdown. - /// - /// This can be useful if you want to drop a runtime from within another runtime. - /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks - /// to complete, which would normally not be permitted within an asynchronous context. - /// By calling `shutdown_background()`, you can drop the runtime from such a context. - /// - /// Note however, that because we do not wait for any blocking tasks to complete, this - /// may result in a resource leak (in that any blocking tasks are still running until they - /// return. - /// - /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`. - /// - /// ``` - /// use tokio::runtime::Runtime; - /// - /// fn main() { - /// let mut runtime = Runtime::new().unwrap(); - /// - /// runtime.block_on(async move { - /// let inner_runtime = Runtime::new().unwrap(); - /// // ... - /// inner_runtime.shutdown_background(); - /// }); - /// } - /// ``` - pub fn shutdown_background(self) { - self.shutdown_timeout(Duration::from_nanos(0)) + /// After thread starts / before thread stops + type Callback = std::sync::Arc<dyn Fn() + Send + Sync>; + + impl Runtime { + /// Create a new runtime instance with default configuration values. + /// + /// This results in the multi threaded scheduler, I/O driver, and time driver being + /// initialized. + /// + /// Most applications will not need to call this function directly. Instead, + /// they will use the [`#[tokio::main]` attribute][main]. When a more complex + /// configuration is necessary, the [runtime builder] may be used. + /// + /// See [module level][mod] documentation for more details. + /// + /// # Examples + /// + /// Creating a new `Runtime` with default configuration values. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// let rt = Runtime::new() + /// .unwrap(); + /// + /// // Use the runtime... + /// ``` + /// + /// [mod]: index.html + /// [main]: ../attr.main.html + /// [threaded scheduler]: index.html#threaded-scheduler + /// [basic scheduler]: index.html#basic-scheduler + /// [runtime builder]: crate::runtime::Builder + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))] + pub fn new() -> std::io::Result<Runtime> { + Builder::new_multi_thread().enable_all().build() + } + + /// Spawn a future onto the Tokio runtime. + /// + /// This spawns the given future onto the runtime's executor, usually a + /// thread pool. The thread pool is then responsible for polling the future + /// until it completes. + /// + /// See [module level][mod] documentation for more details. + /// + /// [mod]: index.html + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Spawn a future onto the runtime + /// rt.spawn(async { + /// println!("now running on a worker thread"); + /// }); + /// # } + /// ``` + pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + match &self.kind { + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(exec) => exec.spawn(future), + Kind::CurrentThread(exec) => exec.spawn(future), + } + } + + /// Run the provided function on an executor dedicated to blocking operations. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// # fn dox() { + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Spawn a blocking function onto the runtime + /// rt.spawn_blocking(|| { + /// println!("now running on a worker thread"); + /// }); + /// # } + pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + { + self.handle.spawn_blocking(func) + } + + /// Run a future to completion on the Tokio runtime. This is the + /// runtime's entry point. + /// + /// This runs the given future on the runtime, blocking until it is + /// complete, and yielding its resolved result. Any tasks or timers + /// which the future spawns internally will be executed on the runtime. + /// + /// # Multi thread scheduler + /// + /// When the multi thread scheduler is used this will allow futures + /// to run within the io driver and timer context of the overall runtime. + /// + /// # Current thread scheduler + /// + /// When the current thread scheduler is enabled `block_on` + /// can be called concurrently from multiple threads. The first call + /// will take ownership of the io and timer drivers. This means + /// other threads which do not own the drivers will hook into that one. + /// When the first `block_on` completes, other threads will be able to + /// "steal" the driver to allow continued execution of their futures. + /// + /// # Panics + /// + /// This function panics if the provided future panics, or if not called within an + /// asynchronous execution context. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::runtime::Runtime; + /// + /// // Create the runtime + /// let rt = Runtime::new().unwrap(); + /// + /// // Execute the future, blocking the current thread until completion + /// rt.block_on(async { + /// println!("hello"); + /// }); + /// ``` + /// + /// [handle]: fn@Handle::block_on + pub fn block_on<F: Future>(&self, future: F) -> F::Output { + let _enter = self.enter(); + + match &self.kind { + Kind::CurrentThread(exec) => exec.block_on(future), + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(exec) => exec.block_on(future), + } + } + + /// Enter the runtime context. + /// + /// This allows you to construct types that must have an executor + /// available on creation such as [`Sleep`] or [`TcpStream`]. It will + /// also allow you to call methods such as [`tokio::spawn`]. + /// + /// [`Sleep`]: struct@crate::time::Sleep + /// [`TcpStream`]: struct@crate::net::TcpStream + /// [`tokio::spawn`]: fn@crate::spawn + /// + /// # Example + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn function_that_spawns(msg: String) { + /// // Had we not used `rt.enter` below, this would panic. + /// tokio::spawn(async move { + /// println!("{}", msg); + /// }); + /// } + /// + /// fn main() { + /// let rt = Runtime::new().unwrap(); + /// + /// let s = "Hello World!".to_string(); + /// + /// // By entering the context, we tie `tokio::spawn` to this executor. + /// let _guard = rt.enter(); + /// function_that_spawns(s); + /// } + /// ``` + pub fn enter(&self) -> EnterGuard<'_> { + EnterGuard { + rt: self, + guard: context::enter(self.handle.clone()), + } + } + + /// Shutdown the runtime, waiting for at most `duration` for all spawned + /// task to shutdown. + /// + /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to + /// shutdown in a timely fashion. However, dropping a `Runtime` will wait + /// indefinitely for all tasks to terminate, and there are cases where a long + /// blocking task has been spawned, which can block dropping `Runtime`. + /// + /// In this case, calling `shutdown_timeout` with an explicit wait timeout + /// can work. The `shutdown_timeout` will signal all tasks to shutdown and + /// will wait for at most `duration` for all spawned tasks to terminate. If + /// `timeout` elapses before all tasks are dropped, the function returns and + /// outstanding tasks are potentially leaked. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Runtime; + /// use tokio::task; + /// + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// task::spawn_blocking(move || { + /// thread::sleep(Duration::from_secs(10_000)); + /// }); + /// }); + /// + /// runtime.shutdown_timeout(Duration::from_millis(100)); + /// } + /// ``` + pub fn shutdown_timeout(mut self, duration: Duration) { + // Wakeup and shutdown all the worker threads + self.handle.spawner.shutdown(); + self.blocking_pool.shutdown(Some(duration)); + } + + /// Shutdown the runtime, without waiting for any spawned tasks to shutdown. + /// + /// This can be useful if you want to drop a runtime from within another runtime. + /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks + /// to complete, which would normally not be permitted within an asynchronous context. + /// By calling `shutdown_background()`, you can drop the runtime from such a context. + /// + /// Note however, that because we do not wait for any blocking tasks to complete, this + /// may result in a resource leak (in that any blocking tasks are still running until they + /// return. + /// + /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`. + /// + /// ``` + /// use tokio::runtime::Runtime; + /// + /// fn main() { + /// let runtime = Runtime::new().unwrap(); + /// + /// runtime.block_on(async move { + /// let inner_runtime = Runtime::new().unwrap(); + /// // ... + /// inner_runtime.shutdown_background(); + /// }); + /// } + /// ``` + pub fn shutdown_background(self) { + self.shutdown_timeout(Duration::from_nanos(0)) + } } } diff --git a/src/runtime/park.rs b/src/runtime/park.rs index ee437d1..033b9f2 100644 --- a/src/runtime/park.rs +++ b/src/runtime/park.rs @@ -6,7 +6,7 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Arc, Condvar, Mutex}; use crate::loom::thread; use crate::park::{Park, Unpark}; -use crate::runtime::time; +use crate::runtime::driver::Driver; use crate::util::TryLock; use std::sync::atomic::Ordering::SeqCst; @@ -42,14 +42,14 @@ const NOTIFIED: usize = 3; /// Shared across multiple Parker handles struct Shared { /// Shared driver. Only one thread at a time can use this - driver: TryLock<time::Driver>, + driver: TryLock<Driver>, /// Unpark handle - handle: <time::Driver as Park>::Unpark, + handle: <Driver as Park>::Unpark, } impl Parker { - pub(crate) fn new(driver: time::Driver) -> Parker { + pub(crate) fn new(driver: Driver) -> Parker { let handle = driver.unpark(); Parker { @@ -104,6 +104,10 @@ impl Park for Parker { Ok(()) } } + + fn shutdown(&mut self) { + self.inner.shutdown(); + } } impl Unpark for Unparker { @@ -138,7 +142,7 @@ impl Inner { fn park_condvar(&self) { // Otherwise we need to coordinate going to sleep - let mut m = self.mutex.lock().unwrap(); + let mut m = self.mutex.lock(); match self .state @@ -176,7 +180,7 @@ impl Inner { } } - fn park_driver(&self, driver: &mut time::Driver) { + fn park_driver(&self, driver: &mut Driver) { match self .state .compare_exchange(EMPTY, PARKED_DRIVER, SeqCst, SeqCst) @@ -234,7 +238,7 @@ impl Inner { // Releasing `lock` before the call to `notify_one` means that when the // parked thread wakes it doesn't get woken only to have to wait for us // to release `lock`. - drop(self.mutex.lock().unwrap()); + drop(self.mutex.lock()); self.condvar.notify_one() } @@ -242,4 +246,12 @@ impl Inner { fn unpark_driver(&self) { self.shared.handle.unpark(); } + + fn shutdown(&self) { + if let Some(mut driver) = self.shared.driver.try_lock() { + driver.shutdown(); + } + + self.condvar.notify_all(); + } } diff --git a/src/runtime/queue.rs b/src/runtime/queue.rs index c654514..cdf4009 100644 --- a/src/runtime/queue.rs +++ b/src/runtime/queue.rs @@ -481,7 +481,7 @@ impl<T: 'static> Inject<T> { /// Close the injection queue, returns `true` if the queue is open when the /// transition is made. pub(super) fn close(&self) -> bool { - let mut p = self.pointers.lock().unwrap(); + let mut p = self.pointers.lock(); if p.is_closed { return false; @@ -492,7 +492,7 @@ impl<T: 'static> Inject<T> { } pub(super) fn is_closed(&self) -> bool { - self.pointers.lock().unwrap().is_closed + self.pointers.lock().is_closed } pub(super) fn len(&self) -> usize { @@ -502,7 +502,7 @@ impl<T: 'static> Inject<T> { /// Pushes a value into the queue. pub(super) fn push(&self, task: task::Notified<T>) { // Acquire queue lock - let mut p = self.pointers.lock().unwrap(); + let mut p = self.pointers.lock(); if p.is_closed { // Drop the mutex to avoid a potential deadlock when @@ -541,7 +541,7 @@ impl<T: 'static> Inject<T> { debug_assert!(get_next(batch_tail).is_none()); - let mut p = self.pointers.lock().unwrap(); + let mut p = self.pointers.lock(); if let Some(tail) = p.tail { set_next(tail, Some(batch_head)); @@ -566,7 +566,7 @@ impl<T: 'static> Inject<T> { return None; } - let mut p = self.pointers.lock().unwrap(); + let mut p = self.pointers.lock(); // It is possible to hit null here if another thread poped the last // task between us checking `len` and acquiring the lock. diff --git a/src/runtime/shell.rs b/src/runtime/shell.rs index a65869d..486d4fa 100644 --- a/src/runtime/shell.rs +++ b/src/runtime/shell.rs @@ -1,52 +1,84 @@ #![allow(clippy::redundant_clone)] +use crate::future::poll_fn; use crate::park::{Park, Unpark}; -use crate::runtime::enter; -use crate::runtime::time; +use crate::runtime::driver::Driver; +use crate::sync::Notify; use crate::util::{waker_ref, Wake}; -use std::future::Future; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::task::Context; -use std::task::Poll::Ready; +use std::task::Poll::{Pending, Ready}; +use std::{future::Future, sync::PoisonError}; #[derive(Debug)] pub(super) struct Shell { - driver: time::Driver, + driver: Mutex<Option<Driver>>, + + notify: Notify, /// TODO: don't store this unpark: Arc<Handle>, } #[derive(Debug)] -struct Handle(<time::Driver as Park>::Unpark); +struct Handle(<Driver as Park>::Unpark); impl Shell { - pub(super) fn new(driver: time::Driver) -> Shell { + pub(super) fn new(driver: Driver) -> Shell { let unpark = Arc::new(Handle(driver.unpark())); - Shell { driver, unpark } + Shell { + driver: Mutex::new(Some(driver)), + notify: Notify::new(), + unpark, + } } - pub(super) fn block_on<F>(&mut self, f: F) -> F::Output + pub(super) fn block_on<F>(&self, f: F) -> F::Output where F: Future, { - let _e = enter(true); + let mut enter = crate::runtime::enter(true); pin!(f); - let waker = waker_ref(&self.unpark); - let mut cx = Context::from_waker(&waker); - loop { - if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { - return v; - } + if let Some(driver) = &mut self.take_driver() { + return driver.block_on(f); + } else { + let notified = self.notify.notified(); + pin!(notified); + + if let Some(out) = enter + .block_on(poll_fn(|cx| { + if notified.as_mut().poll(cx).is_ready() { + return Ready(None); + } + + if let Ready(out) = f.as_mut().poll(cx) { + return Ready(Some(out)); + } - self.driver.park().unwrap(); + Pending + })) + .expect("Failed to `Enter::block_on`") + { + return out; + } + } } } + + fn take_driver(&self) -> Option<DriverGuard<'_>> { + let mut lock = self.driver.lock().unwrap(); + let driver = lock.take()?; + + Some(DriverGuard { + inner: Some(driver), + shell: &self, + }) + } } impl Wake for Handle { @@ -60,3 +92,41 @@ impl Wake for Handle { arc_self.0.unpark(); } } + +struct DriverGuard<'a> { + inner: Option<Driver>, + shell: &'a Shell, +} + +impl DriverGuard<'_> { + fn block_on<F: Future>(&mut self, f: F) -> F::Output { + let driver = self.inner.as_mut().unwrap(); + + pin!(f); + + let waker = waker_ref(&self.shell.unpark); + let mut cx = Context::from_waker(&waker); + + loop { + if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { + return v; + } + + driver.park().unwrap(); + } + } +} + +impl Drop for DriverGuard<'_> { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + self.shell + .driver + .lock() + .unwrap_or_else(PoisonError::into_inner) + .replace(inner); + + self.shell.notify.notify_one(); + } + } +} diff --git a/src/runtime/spawner.rs b/src/runtime/spawner.rs index d136945..a37c667 100644 --- a/src/runtime/spawner.rs +++ b/src/runtime/spawner.rs @@ -1,24 +1,34 @@ -cfg_rt_core! { +cfg_rt! { use crate::runtime::basic_scheduler; use crate::task::JoinHandle; use std::future::Future; } -cfg_rt_threaded! { +cfg_rt_multi_thread! { use crate::runtime::thread_pool; } #[derive(Debug, Clone)] pub(crate) enum Spawner { - Shell, - #[cfg(feature = "rt-core")] + #[cfg(feature = "rt")] Basic(basic_scheduler::Spawner), - #[cfg(feature = "rt-threaded")] + #[cfg(feature = "rt-multi-thread")] ThreadPool(thread_pool::Spawner), } -cfg_rt_core! { +impl Spawner { + pub(crate) fn shutdown(&mut self) { + #[cfg(feature = "rt-multi-thread")] + { + if let Spawner::ThreadPool(spawner) = self { + spawner.shutdown(); + } + } + } +} + +cfg_rt! { impl Spawner { pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> where @@ -26,10 +36,9 @@ cfg_rt_core! { F::Output: Send + 'static, { match self { - Spawner::Shell => panic!("spawning not enabled for runtime"), - #[cfg(feature = "rt-core")] + #[cfg(feature = "rt")] Spawner::Basic(spawner) => spawner.spawn(future), - #[cfg(feature = "rt-threaded")] + #[cfg(feature = "rt-multi-thread")] Spawner::ThreadPool(spawner) => spawner.spawn(future), } } diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs index f4756c2..dfa8764 100644 --- a/src/runtime/task/core.rs +++ b/src/runtime/task/core.rs @@ -269,7 +269,7 @@ impl<T: Future, S: Schedule> Core<T, S> { } } -cfg_rt_threaded! { +cfg_rt_multi_thread! { impl Header { pub(crate) fn shutdown(&self) { use crate::runtime::task::RawTask; diff --git a/src/runtime/task/error.rs b/src/runtime/task/error.rs index d5f65a4..177fe65 100644 --- a/src/runtime/task/error.rs +++ b/src/runtime/task/error.rs @@ -3,7 +3,7 @@ use std::fmt; use std::io; use std::sync::Mutex; -doc_rt_core! { +cfg_rt! { /// Task failed to execute to completion. pub struct JoinError { repr: Repr, @@ -16,25 +16,13 @@ enum Repr { } impl JoinError { - #[doc(hidden)] - #[deprecated] - pub fn cancelled() -> JoinError { - Self::cancelled2() - } - - pub(crate) fn cancelled2() -> JoinError { + pub(crate) fn cancelled() -> JoinError { JoinError { repr: Repr::Cancelled, } } - #[doc(hidden)] - #[deprecated] - pub fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { - Self::panic2(err) - } - - pub(crate) fn panic2(err: Box<dyn Any + Send + 'static>) -> JoinError { + pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { JoinError { repr: Repr::Panic(Mutex::new(err)), } @@ -42,10 +30,7 @@ impl JoinError { /// Returns true if the error was caused by the task being cancelled pub fn is_cancelled(&self) -> bool { - match &self.repr { - Repr::Cancelled => true, - _ => false, - } + matches!(&self.repr, Repr::Cancelled) } /// Returns true if the error was caused by the task panicking @@ -65,10 +50,7 @@ impl JoinError { /// } /// ``` pub fn is_panic(&self) -> bool { - match &self.repr { - Repr::Panic(_) => true, - _ => false, - } + matches!(&self.repr, Repr::Panic(_)) } /// Consumes the join error, returning the object with which the task panicked. diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs index e86b29e..208d48c 100644 --- a/src/runtime/task/harness.rs +++ b/src/runtime/task/harness.rs @@ -102,7 +102,7 @@ where // If the task is cancelled, avoid polling it, instead signalling it // is complete. if snapshot.is_cancelled() { - Poll::Ready(Err(JoinError::cancelled2())) + Poll::Ready(Err(JoinError::cancelled())) } else { let res = guard.core.poll(self.header()); @@ -132,7 +132,7 @@ where } } Err(err) => { - self.complete(Err(JoinError::panic2(err)), snapshot.is_join_interested()); + self.complete(Err(JoinError::panic(err)), snapshot.is_join_interested()); } } } @@ -297,9 +297,9 @@ where // Dropping the future panicked, complete the join // handle with the panic to avoid dropping the panic // on the ground. - self.complete(Err(JoinError::panic2(err)), true); + self.complete(Err(JoinError::panic(err)), true); } else { - self.complete(Err(JoinError::cancelled2()), true); + self.complete(Err(JoinError::cancelled()), true); } } diff --git a/src/runtime/task/join.rs b/src/runtime/task/join.rs index 3c4aabb..dedfb38 100644 --- a/src/runtime/task/join.rs +++ b/src/runtime/task/join.rs @@ -6,7 +6,7 @@ use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -doc_rt_core! { +cfg_rt! { /// An owned permission to join on a task (await its termination). /// /// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for @@ -45,6 +45,71 @@ doc_rt_core! { /// # } /// ``` /// + /// The generic parameter `T` in `JoinHandle<T>` is the return type of the spawned task. + /// If the return value is an i32, the join handle has type `JoinHandle<i32>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<i32> = task::spawn(async { + /// 5 + 3 + /// }); + /// # } + /// + /// ``` + /// + /// If the task does not have a return value, the join handle has type `JoinHandle<()>`: + /// + /// ``` + /// use tokio::task; + /// + /// # async fn doc() { + /// let join_handle: task::JoinHandle<()> = task::spawn(async { + /// println!("I return nothing."); + /// }); + /// # } + /// ``` + /// + /// Note that `handle.await` doesn't give you the return type directly. It is wrapped in a + /// `Result` because panics in the spawned task are caught by Tokio. The `?` operator has + /// to be double chained to extract the returned value: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// Ok(5 + 3) + /// }); + /// + /// let result = join_handle.await??; + /// assert_eq!(result, 8); + /// Ok(()) + /// } + /// ``` + /// + /// If the task panics, the error is a [`JoinError`] that contains the panic: + /// + /// ``` + /// use tokio::task; + /// use std::io; + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async { + /// panic!("boom"); + /// }); + /// + /// let err = join_handle.await.unwrap_err(); + /// assert!(err.is_panic()); + /// Ok(()) + /// } + /// + /// ``` /// Child being detached and outliving its parent: /// /// ```no_run @@ -56,7 +121,7 @@ doc_rt_core! { /// let original_task = task::spawn(async { /// let _detached_task = task::spawn(async { /// // Here we sleep to make sure that the first task returns before. - /// time::delay_for(Duration::from_millis(10)).await; + /// time::sleep(Duration::from_millis(10)).await; /// // This will be called, even though the JoinHandle is dropped. /// println!("♫ Still alive ♫"); /// }); @@ -68,13 +133,14 @@ doc_rt_core! { /// // We make sure that the new task has time to run, before the main /// // task returns. /// - /// time::delay_for(Duration::from_millis(1000)).await; + /// time::sleep(Duration::from_millis(1000)).await; /// # } /// ``` /// /// [`task::spawn`]: crate::task::spawn() /// [`task::spawn_blocking`]: crate::task::spawn_blocking /// [`std::thread::JoinHandle`]: std::thread::JoinHandle + /// [`JoinError`]: crate::task::JoinError pub struct JoinHandle<T> { raw: Option<RawTask>, _p: PhantomData<T>, @@ -91,6 +157,44 @@ impl<T> JoinHandle<T> { _p: PhantomData, } } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will complete with a `Err(JoinError::Cancelled)`. + /// + /// ```rust + /// use tokio::time; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut handles = Vec::new(); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// true + /// })); + /// + /// handles.push(tokio::spawn(async { + /// time::sleep(time::Duration::from_secs(10)).await; + /// false + /// })); + /// + /// for handle in &handles { + /// handle.abort(); + /// } + /// + /// for handle in handles { + /// assert!(handle.await.unwrap_err().is_cancelled()); + /// } + /// } + /// ``` + pub fn abort(&self) { + if let Some(raw) = self.raw { + raw.shutdown(); + } + } } impl<T> Unpin for JoinHandle<T> {} diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs index 17b5157..7b49e95 100644 --- a/src/runtime/task/mod.rs +++ b/src/runtime/task/mod.rs @@ -21,7 +21,7 @@ use self::state::State; mod waker; -cfg_rt_threaded! { +cfg_rt_multi_thread! { mod stack; pub(crate) use self::stack::TransferStack; } @@ -79,25 +79,27 @@ pub(crate) trait Schedule: Sync + Sized + 'static { } } -/// Create a new task with an associated join handle -pub(crate) fn joinable<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) -where - T: Future + Send + 'static, - S: Schedule, -{ - let raw = RawTask::new::<_, S>(task); +cfg_rt! { + /// Create a new task with an associated join handle + pub(crate) fn joinable<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) + where + T: Future + Send + 'static, + S: Schedule, + { + let raw = RawTask::new::<_, S>(task); - let task = Task { - raw, - _p: PhantomData, - }; + let task = Task { + raw, + _p: PhantomData, + }; - let join = JoinHandle::new(raw); + let join = JoinHandle::new(raw); - (Notified(task), join) + (Notified(task), join) + } } -cfg_rt_util! { +cfg_rt! { /// Create a new `!Send` task with an associated join handle pub(crate) unsafe fn joinable_local<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) where @@ -130,7 +132,7 @@ impl<S: 'static> Task<S> { } } -cfg_rt_threaded! { +cfg_rt_multi_thread! { impl<S: 'static> Notified<S> { pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { Notified(Task::from_raw(ptr)) diff --git a/src/runtime/tests/loom_blocking.rs b/src/runtime/tests/loom_blocking.rs index db7048e..8fb54c5 100644 --- a/src/runtime/tests/loom_blocking.rs +++ b/src/runtime/tests/loom_blocking.rs @@ -8,14 +8,15 @@ fn blocking_shutdown() { let v = Arc::new(()); let rt = mk_runtime(1); - rt.enter(|| { + { + let _enter = rt.enter(); for _ in 0..2 { let v = v.clone(); crate::task::spawn_blocking(move || { assert!(1 < Arc::strong_count(&v)); }); } - }); + } drop(rt); assert_eq!(1, Arc::strong_count(&v)); @@ -23,9 +24,8 @@ fn blocking_shutdown() { } fn mk_runtime(num_threads: usize) -> Runtime { - runtime::Builder::new() - .threaded_scheduler() - .core_threads(num_threads) + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) .build() .unwrap() } diff --git a/src/runtime/tests/loom_pool.rs b/src/runtime/tests/loom_pool.rs index c08658c..06ad641 100644 --- a/src/runtime/tests/loom_pool.rs +++ b/src/runtime/tests/loom_pool.rs @@ -178,7 +178,7 @@ mod group_b { #[test] fn join_output() { loom::model(|| { - let mut rt = mk_pool(1); + let rt = mk_pool(1); rt.block_on(async { let t = crate::spawn(track(async { "hello" })); @@ -192,7 +192,7 @@ mod group_b { #[test] fn poll_drop_handle_then_drop() { loom::model(|| { - let mut rt = mk_pool(1); + let rt = mk_pool(1); rt.block_on(async move { let mut t = crate::spawn(track(async { "hello" })); @@ -209,7 +209,7 @@ mod group_b { #[test] fn complete_block_on_under_load() { loom::model(|| { - let mut pool = mk_pool(1); + let pool = mk_pool(1); pool.block_on(async { // Trigger a re-schedule @@ -296,9 +296,8 @@ mod group_d { } fn mk_pool(num_threads: usize) -> Runtime { - runtime::Builder::new() - .threaded_scheduler() - .core_threads(num_threads) + runtime::Builder::new_multi_thread() + .worker_threads(num_threads) .build() .unwrap() } diff --git a/src/runtime/tests/task.rs b/src/runtime/tests/task.rs index 82315a0..a34526f 100644 --- a/src/runtime/tests/task.rs +++ b/src/runtime/tests/task.rs @@ -1,5 +1,5 @@ use crate::runtime::task::{self, Schedule, Task}; -use crate::util::linked_list::LinkedList; +use crate::util::linked_list::{Link, LinkedList}; use crate::util::TryLock; use std::collections::VecDeque; @@ -72,7 +72,7 @@ struct Inner { struct Core { queue: VecDeque<task::Notified<Runtime>>, - tasks: LinkedList<Task<Runtime>>, + tasks: LinkedList<Task<Runtime>, <Task<Runtime> as Link>::Target>, } static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); diff --git a/src/runtime/thread_pool/atomic_cell.rs b/src/runtime/thread_pool/atomic_cell.rs index 2bda0fc..98847e6 100644 --- a/src/runtime/thread_pool/atomic_cell.rs +++ b/src/runtime/thread_pool/atomic_cell.rs @@ -22,7 +22,6 @@ impl<T> AtomicCell<T> { from_raw(old) } - #[cfg(feature = "blocking")] pub(super) fn set(&self, val: Box<T>) { let _ = self.swap(Some(val)); } diff --git a/src/runtime/thread_pool/idle.rs b/src/runtime/thread_pool/idle.rs index ae87ca4..6e692fd 100644 --- a/src/runtime/thread_pool/idle.rs +++ b/src/runtime/thread_pool/idle.rs @@ -55,7 +55,7 @@ impl Idle { } // Acquire the lock - let mut sleepers = self.sleepers.lock().unwrap(); + let mut sleepers = self.sleepers.lock(); // Check again, now that the lock is acquired if !self.notify_should_wakeup() { @@ -77,7 +77,7 @@ impl Idle { /// work. pub(super) fn transition_worker_to_parked(&self, worker: usize, is_searching: bool) -> bool { // Acquire the lock - let mut sleepers = self.sleepers.lock().unwrap(); + let mut sleepers = self.sleepers.lock(); // Decrement the number of unparked threads let ret = State::dec_num_unparked(&self.state, is_searching); @@ -112,7 +112,7 @@ impl Idle { /// Unpark a specific worker. This happens if tasks are submitted from /// within the worker's park routine. pub(super) fn unpark_worker_by_id(&self, worker_id: usize) { - let mut sleepers = self.sleepers.lock().unwrap(); + let mut sleepers = self.sleepers.lock(); for index in 0..sleepers.len() { if sleepers[index] == worker_id { @@ -128,7 +128,7 @@ impl Idle { /// Returns `true` if `worker_id` is contained in the sleep set pub(super) fn is_parked(&self, worker_id: usize) -> bool { - let sleepers = self.sleepers.lock().unwrap(); + let sleepers = self.sleepers.lock(); sleepers.contains(&worker_id) } diff --git a/src/runtime/thread_pool/mod.rs b/src/runtime/thread_pool/mod.rs index ced9712..e39695a 100644 --- a/src/runtime/thread_pool/mod.rs +++ b/src/runtime/thread_pool/mod.rs @@ -9,9 +9,7 @@ use self::idle::Idle; mod worker; pub(crate) use worker::Launch; -cfg_blocking! { - pub(crate) use worker::block_in_place; -} +pub(crate) use worker::block_in_place; use crate::loom::sync::Arc; use crate::runtime::task::{self, JoinHandle}; @@ -91,7 +89,7 @@ impl fmt::Debug for ThreadPool { impl Drop for ThreadPool { fn drop(&mut self) { - self.spawner.shared.close(); + self.spawner.shutdown(); } } @@ -108,6 +106,10 @@ impl Spawner { self.shared.schedule(task, false); handle } + + pub(crate) fn shutdown(&mut self) { + self.shared.close(); + } } impl fmt::Debug for Spawner { diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs index abe20da..bc544c9 100644 --- a/src/runtime/thread_pool/worker.rs +++ b/src/runtime/thread_pool/worker.rs @@ -9,10 +9,11 @@ use crate::loom::rand::seed; use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; use crate::runtime; +use crate::runtime::enter::EnterContext; use crate::runtime::park::{Parker, Unparker}; use crate::runtime::thread_pool::{AtomicCell, Idle}; use crate::runtime::{queue, task}; -use crate::util::linked_list::LinkedList; +use crate::util::linked_list::{Link, LinkedList}; use crate::util::FastRand; use std::cell::RefCell; @@ -53,7 +54,7 @@ struct Core { is_shutdown: bool, /// Tasks owned by the core - tasks: LinkedList<Task>, + tasks: LinkedList<Task, <Task as Link>::Target>, /// Parker /// @@ -172,104 +173,96 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc<Shared>, Launch) { (shared, launch) } -cfg_blocking! { - use crate::runtime::enter::EnterContext; - - pub(crate) fn block_in_place<F, R>(f: F) -> R - where - F: FnOnce() -> R, - { - // Try to steal the worker core back - struct Reset(coop::Budget); - - impl Drop for Reset { - fn drop(&mut self) { - CURRENT.with(|maybe_cx| { - if let Some(cx) = maybe_cx { - let core = cx.worker.core.take(); - let mut cx_core = cx.core.borrow_mut(); - assert!(cx_core.is_none()); - *cx_core = core; - - // Reset the task budget as we are re-entering the - // runtime. - coop::set(self.0); - } - }); - } +pub(crate) fn block_in_place<F, R>(f: F) -> R +where + F: FnOnce() -> R, +{ + // Try to steal the worker core back + struct Reset(coop::Budget); + + impl Drop for Reset { + fn drop(&mut self) { + CURRENT.with(|maybe_cx| { + if let Some(cx) = maybe_cx { + let core = cx.worker.core.take(); + let mut cx_core = cx.core.borrow_mut(); + assert!(cx_core.is_none()); + *cx_core = core; + + // Reset the task budget as we are re-entering the + // runtime. + coop::set(self.0); + } + }); } + } - let mut had_core = false; - let mut had_entered = false; + let mut had_entered = false; - CURRENT.with(|maybe_cx| { - match (crate::runtime::enter::context(), maybe_cx.is_some()) { - (EnterContext::Entered { .. }, true) => { - // We are on a thread pool runtime thread, so we just need to set up blocking. + CURRENT.with(|maybe_cx| { + match (crate::runtime::enter::context(), maybe_cx.is_some()) { + (EnterContext::Entered { .. }, true) => { + // We are on a thread pool runtime thread, so we just need to set up blocking. + had_entered = true; + } + (EnterContext::Entered { allow_blocking }, false) => { + // We are on an executor, but _not_ on the thread pool. + // That is _only_ okay if we are in a thread pool runtime's block_on method: + if allow_blocking { had_entered = true; - } - (EnterContext::Entered { allow_blocking }, false) => { - // We are on an executor, but _not_ on the thread pool. - // That is _only_ okay if we are in a thread pool runtime's block_on method: - if allow_blocking { - had_entered = true; - return; - } else { - // This probably means we are on the basic_scheduler or in a LocalSet, - // where it is _not_ okay to block. - panic!("can call blocking only when running on the multi-threaded runtime"); - } - } - (EnterContext::NotEntered, true) => { - // This is a nested call to block_in_place (we already exited). - // All the necessary setup has already been done. - return; - } - (EnterContext::NotEntered, false) => { - // We are outside of the tokio runtime, so blocking is fine. - // We can also skip all of the thread pool blocking setup steps. return; + } else { + // This probably means we are on the basic_scheduler or in a LocalSet, + // where it is _not_ okay to block. + panic!("can call blocking only when running on the multi-threaded runtime"); } } + (EnterContext::NotEntered, true) => { + // This is a nested call to block_in_place (we already exited). + // All the necessary setup has already been done. + return; + } + (EnterContext::NotEntered, false) => { + // We are outside of the tokio runtime, so blocking is fine. + // We can also skip all of the thread pool blocking setup steps. + return; + } + } - let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); - - // Get the worker core. If none is set, then blocking is fine! - let core = match cx.core.borrow_mut().take() { - Some(core) => core, - None => return, - }; - - // The parker should be set here - assert!(core.park.is_some()); + let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); - // In order to block, the core must be sent to another thread for - // execution. - // - // First, move the core back into the worker's shared core slot. - cx.worker.core.set(core); - had_core = true; + // Get the worker core. If none is set, then blocking is fine! + let core = match cx.core.borrow_mut().take() { + Some(core) => core, + None => return, + }; - // Next, clone the worker handle and send it to a new thread for - // processing. - // - // Once the blocking task is done executing, we will attempt to - // steal the core back. - let worker = cx.worker.clone(); - runtime::spawn_blocking(move || run(worker)); - }); + // The parker should be set here + assert!(core.park.is_some()); + + // In order to block, the core must be sent to another thread for + // execution. + // + // First, move the core back into the worker's shared core slot. + cx.worker.core.set(core); + + // Next, clone the worker handle and send it to a new thread for + // processing. + // + // Once the blocking task is done executing, we will attempt to + // steal the core back. + let worker = cx.worker.clone(); + runtime::spawn_blocking(move || run(worker)); + }); - if had_core { - // Unset the current task's budget. Blocking sections are not - // constrained by task budgets. - let _reset = Reset(coop::stop()); + if had_entered { + // Unset the current task's budget. Blocking sections are not + // constrained by task budgets. + let _reset = Reset(coop::stop()); - crate::runtime::enter::exit(f) - } else if had_entered { - crate::runtime::enter::exit(f) - } else { - f() - } + crate::runtime::enter::exit(f) + } else { + f() } } @@ -576,6 +569,8 @@ impl Core { // Drain the queue while self.next_local_task().is_some() {} + + park.shutdown(); } fn drain_pending_drop(&mut self, worker: &Worker) { @@ -785,7 +780,7 @@ impl Shared { /// /// If all workers have reached this point, the final cleanup is performed. fn shutdown(&self, core: Box<Core>, worker: Arc<Worker>) { - let mut workers = self.shutdown_workers.lock().unwrap(); + let mut workers = self.shutdown_workers.lock(); workers.push((core, worker)); if workers.len() != self.remotes.len() { diff --git a/src/runtime/time.rs b/src/runtime/time.rs deleted file mode 100644 index c623d96..0000000 --- a/src/runtime/time.rs +++ /dev/null @@ -1,59 +0,0 @@ -//! Abstracts out the APIs necessary to `Runtime` for integrating the time -//! driver. When the `time` feature flag is **not** enabled. These APIs are -//! shells. This isolates the complexity of dealing with conditional -//! compilation. - -pub(crate) use variant::*; - -#[cfg(feature = "time")] -mod variant { - use crate::park::Either; - use crate::runtime::io; - use crate::time::{self, driver}; - - pub(crate) type Clock = time::Clock; - pub(crate) type Driver = Either<driver::Driver<io::Driver>, io::Driver>; - pub(crate) type Handle = Option<driver::Handle>; - - pub(crate) fn create_clock() -> Clock { - Clock::new() - } - - /// Create a new timer driver / handle pair - pub(crate) fn create_driver( - enable: bool, - io_driver: io::Driver, - clock: Clock, - ) -> (Driver, Handle) { - if enable { - let driver = driver::Driver::new(io_driver, clock); - let handle = driver.handle(); - - (Either::A(driver), Some(handle)) - } else { - (Either::B(io_driver), None) - } - } -} - -#[cfg(not(feature = "time"))] -mod variant { - use crate::runtime::io; - - pub(crate) type Clock = (); - pub(crate) type Driver = io::Driver; - pub(crate) type Handle = (); - - pub(crate) fn create_clock() -> Clock { - () - } - - /// Create a new timer driver / handle pair - pub(crate) fn create_driver( - _enable: bool, - io_driver: io::Driver, - _clock: Clock, - ) -> (Driver, Handle) { - (io_driver, ()) - } -} diff --git a/src/signal/registry.rs b/src/signal/registry.rs index 50edd2b..5d6f608 100644 --- a/src/signal/registry.rs +++ b/src/signal/registry.rs @@ -185,7 +185,7 @@ mod tests { #[test] fn smoke() { - let mut rt = rt(); + let rt = rt(); rt.block_on(async move { let registry = Registry::new(vec![ EventInfo::default(), @@ -247,7 +247,7 @@ mod tests { #[test] fn broadcast_cleans_up_disconnected_listeners() { - let mut rt = Runtime::new().unwrap(); + let rt = Runtime::new().unwrap(); rt.block_on(async { let registry = Registry::new(vec![EventInfo::default()]); @@ -306,7 +306,7 @@ mod tests { } fn rt() -> Runtime { - runtime::Builder::new().basic_scheduler().build().unwrap() + runtime::Builder::new_current_thread().build().unwrap() } async fn collect(mut rx: crate::sync::mpsc::Receiver<()>) -> Vec<()> { diff --git a/src/signal/unix.rs b/src/signal/unix.rs index b46b15c..aaaa75e 100644 --- a/src/signal/unix.rs +++ b/src/signal/unix.rs @@ -5,18 +5,21 @@ #![cfg(unix)] -use crate::io::{AsyncRead, PollEvented}; use crate::signal::registry::{globals, EventId, EventInfo, Globals, Init, Storage}; +use crate::sync::mpsc::error::TryRecvError; use crate::sync::mpsc::{channel, Receiver}; use libc::c_int; -use mio_uds::UnixStream; +use mio::net::UnixStream; use std::io::{self, Error, ErrorKind, Write}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Once; use std::task::{Context, Poll}; +pub(crate) mod driver; +use self::driver::Handle; + pub(crate) type OsStorage = Vec<SignalInfo>; // Number of different unix signals @@ -202,9 +205,9 @@ impl Default for SignalInfo { /// The purpose of this signal handler is to primarily: /// /// 1. Flag that our specific signal was received (e.g. store an atomic flag) -/// 2. Wake up driver tasks by writing a byte to a pipe +/// 2. Wake up the driver by writing a byte to a pipe /// -/// Those two operations shoudl both be async-signal safe. +/// Those two operations should both be async-signal safe. fn action(globals: Pin<&'static Globals>, signal: c_int) { globals.record_event(signal as EventId); @@ -219,7 +222,7 @@ fn action(globals: Pin<&'static Globals>, signal: c_int) { /// /// This will register the signal handler if it hasn't already been registered, /// returning any error along the way if that fails. -fn signal_enable(signal: c_int) -> io::Result<()> { +fn signal_enable(signal: c_int, handle: Handle) -> io::Result<()> { if signal < 0 || signal_hook_registry::FORBIDDEN.contains(&signal) { return Err(Error::new( ErrorKind::Other, @@ -227,6 +230,9 @@ fn signal_enable(signal: c_int) -> io::Result<()> { )); } + // Check that we have a signal driver running + handle.check_inner()?; + let globals = globals(); let siginfo = match globals.storage().get(signal as EventId) { Some(slot) => slot, @@ -254,63 +260,6 @@ fn signal_enable(signal: c_int) -> io::Result<()> { } } -#[derive(Debug)] -struct Driver { - wakeup: PollEvented<UnixStream>, -} - -impl Driver { - fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { - // Drain the data from the pipe and maintain interest in getting more - self.drain(cx); - // Broadcast any signals which were received - globals().broadcast(); - - Poll::Pending - } -} - -impl Driver { - fn new() -> io::Result<Driver> { - // NB: We give each driver a "fresh" reciever file descriptor to avoid - // the issues described in alexcrichton/tokio-process#42. - // - // In the past we would reuse the actual receiver file descriptor and - // swallow any errors around double registration of the same descriptor. - // I'm not sure if the second (failed) registration simply doesn't end up - // receiving wake up notifications, or there could be some race condition - // when consuming readiness events, but having distinct descriptors for - // distinct PollEvented instances appears to mitigate this. - // - // Unfortunately we cannot just use a single global PollEvented instance - // either, since we can't compare Handles or assume they will always - // point to the exact same reactor. - let stream = globals().receiver.try_clone()?; - let wakeup = PollEvented::new(stream)?; - - Ok(Driver { wakeup }) - } - - /// Drain all data in the global receiver, ensuring we'll get woken up when - /// there is a write on the other end. - /// - /// We do *NOT* use the existence of any read bytes as evidence a signal was - /// received since the `pending` flags would have already been set if that - /// was the case. See - /// [#38](https://github.com/alexcrichton/tokio-signal/issues/38) for more - /// info. - fn drain(&mut self, cx: &mut Context<'_>) { - loop { - match Pin::new(&mut self.wakeup).poll_read(cx, &mut [0; 128]) { - Poll::Ready(Ok(0)) => panic!("EOF on self-pipe"), - Poll::Ready(Ok(_)) => {} - Poll::Ready(Err(e)) => panic!("Bad read on self-pipe: {}", e), - Poll::Pending => break, - } - } - } -} - /// A stream of events for receiving a particular type of OS signal. /// /// In general signal handling on Unix is a pretty tricky topic, and this @@ -376,7 +325,6 @@ impl Driver { #[must_use = "streams do nothing unless polled"] #[derive(Debug)] pub struct Signal { - driver: Driver, rx: Receiver<()>, } @@ -403,21 +351,21 @@ pub struct Signal { /// * If the signal is one of /// [`signal_hook::FORBIDDEN`](fn@signal_hook_registry::register#panics) pub fn signal(kind: SignalKind) -> io::Result<Signal> { + signal_with_handle(kind, Handle::current()) +} + +pub(crate) fn signal_with_handle(kind: SignalKind, handle: Handle) -> io::Result<Signal> { let signal = kind.0; // Turn the signal delivery on once we are ready for it - signal_enable(signal)?; - - // Ensure there's a driver for our associated event loop processing - // signals. - let driver = Driver::new()?; + signal_enable(signal, handle)?; // One wakeup in a queue is enough, no need for us to buffer up any // more. let (tx, rx) = channel(1); globals().register_listener(signal as EventId, tx); - Ok(Signal { driver, rx }) + Ok(Signal { rx }) } impl Signal { @@ -449,38 +397,14 @@ impl Signal { poll_fn(|cx| self.poll_recv(cx)).await } - /// Polls to receive the next signal notification event, outside of an - /// `async` context. - /// - /// `None` is returned if no more events can be received by this stream. - /// - /// # Examples - /// - /// Polling from a manually implemented future - /// - /// ```rust,no_run - /// use std::pin::Pin; - /// use std::future::Future; - /// use std::task::{Context, Poll}; - /// use tokio::signal::unix::Signal; - /// - /// struct MyFuture { - /// signal: Signal, - /// } - /// - /// impl Future for MyFuture { - /// type Output = Option<()>; - /// - /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - /// println!("polling MyFuture"); - /// self.signal.poll_recv(cx) - /// } - /// } - /// ``` - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { - let _ = self.driver.poll(cx); + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { self.rx.poll_recv(cx) } + + /// Try to receive a signal notification without blocking or registering a waker. + pub(crate) fn try_recv(&mut self) -> Result<(), TryRecvError> { + self.rx.try_recv() + } } cfg_stream! { @@ -493,6 +417,22 @@ cfg_stream! { } } +// Work around for abstracting streams internally +pub(crate) trait InternalStream: Unpin { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>; + fn try_recv(&mut self) -> Result<(), TryRecvError>; +} + +impl InternalStream for Signal { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + self.poll_recv(cx) + } + + fn try_recv(&mut self) -> Result<(), TryRecvError> { + self.try_recv() + } +} + pub(crate) fn ctrl_c() -> io::Result<Signal> { signal(SignalKind::interrupt()) } @@ -503,11 +443,11 @@ mod tests { #[test] fn signal_enable_error_on_invalid_input() { - signal_enable(-1).unwrap_err(); + signal_enable(-1, Handle::default()).unwrap_err(); } #[test] fn signal_enable_error_on_forbidden_input() { - signal_enable(signal_hook_registry::FORBIDDEN[0]).unwrap_err(); + signal_enable(signal_hook_registry::FORBIDDEN[0], Handle::default()).unwrap_err(); } } diff --git a/src/signal/unix/driver.rs b/src/signal/unix/driver.rs new file mode 100644 index 0000000..8e5ed7d --- /dev/null +++ b/src/signal/unix/driver.rs @@ -0,0 +1,207 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +//! Signal driver + +use crate::io::driver::Driver as IoDriver; +use crate::io::PollEvented; +use crate::park::Park; +use crate::signal::registry::globals; + +use mio::net::UnixStream; +use std::io::{self, Read}; +use std::ptr; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; +use std::time::Duration; + +/// Responsible for registering wakeups when an OS signal is received, and +/// subsequently dispatching notifications to any signal listeners as appropriate. +/// +/// Note: this driver relies on having an enabled IO driver in order to listen to +/// pipe write wakeups. +#[derive(Debug)] +pub(crate) struct Driver { + /// Thread parker. The `Driver` park implementation delegates to this. + park: IoDriver, + + /// A pipe for receiving wake events from the signal handler + receiver: PollEvented<UnixStream>, + + /// Shared state + inner: Arc<Inner>, +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct Handle { + inner: Weak<Inner>, +} + +#[derive(Debug)] +pub(super) struct Inner(()); + +// ===== impl Driver ===== + +impl Driver { + /// Creates a new signal `Driver` instance that delegates wakeups to `park`. + pub(crate) fn new(park: IoDriver) -> io::Result<Self> { + use std::mem::ManuallyDrop; + use std::os::unix::io::{AsRawFd, FromRawFd}; + + // NB: We give each driver a "fresh" reciever file descriptor to avoid + // the issues described in alexcrichton/tokio-process#42. + // + // In the past we would reuse the actual receiver file descriptor and + // swallow any errors around double registration of the same descriptor. + // I'm not sure if the second (failed) registration simply doesn't end + // up receiving wake up notifications, or there could be some race + // condition when consuming readiness events, but having distinct + // descriptors for distinct PollEvented instances appears to mitigate + // this. + // + // Unfortunately we cannot just use a single global PollEvented instance + // either, since we can't compare Handles or assume they will always + // point to the exact same reactor. + // + // Mio 0.7 removed `try_clone()` as an API due to unexpected behavior + // with registering dups with the same reactor. In this case, duping is + // safe as each dup is registered with separate reactors **and** we + // only expect at least one dup to receive the notification. + + // Manually drop as we don't actually own this instance of UnixStream. + let receiver_fd = globals().receiver.as_raw_fd(); + + // safety: there is nothing unsafe about this, but the `from_raw_fd` fn is marked as unsafe. + let original = + ManuallyDrop::new(unsafe { std::os::unix::net::UnixStream::from_raw_fd(receiver_fd) }); + let receiver = UnixStream::from_std(original.try_clone()?); + let receiver = PollEvented::new_with_interest_and_handle( + receiver, + mio::Interest::READABLE | mio::Interest::WRITABLE, + park.handle(), + )?; + + Ok(Self { + park, + receiver, + inner: Arc::new(Inner(())), + }) + } + + /// Returns a handle to this event loop which can be sent across threads + /// and can be used as a proxy to the event loop itself. + pub(crate) fn handle(&self) -> Handle { + Handle { + inner: Arc::downgrade(&self.inner), + } + } + + fn process(&self) { + // Check if the pipe is ready to read and therefore has "woken" us up + // + // To do so, we will `poll_read_ready` with a noop waker, since we don't + // need to actually be notified when read ready... + let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let ev = match self.receiver.poll_read_ready(&mut cx) { + Poll::Ready(Ok(ev)) => ev, + Poll::Ready(Err(e)) => panic!("reactor gone: {}", e), + Poll::Pending => return, // No wake has arrived, bail + }; + + // Drain the pipe completely so we can receive a new readiness event + // if another signal has come in. + let mut buf = [0; 128]; + loop { + match self.receiver.get_ref().read(&mut buf) { + Ok(0) => panic!("EOF on self-pipe"), + Ok(_) => continue, // Keep reading + Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => panic!("Bad read on self-pipe: {}", e), + } + } + + self.receiver.clear_readiness(ev); + + // Broadcast any signals which were received + globals().broadcast(); + } +} + +const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + +unsafe fn noop_clone(_data: *const ()) -> RawWaker { + RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE) +} + +unsafe fn noop(_data: *const ()) {} + +// ===== impl Park for Driver ===== + +impl Park for Driver { + type Unpark = <IoDriver as Park>::Unpark; + type Error = io::Error; + + fn unpark(&self) -> Self::Unpark { + self.park.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.park.park()?; + self.process(); + Ok(()) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.park.park_timeout(duration)?; + self.process(); + Ok(()) + } + + fn shutdown(&mut self) { + self.park.shutdown() + } +} + +// ===== impl Handle ===== + +impl Handle { + pub(super) fn check_inner(&self) -> io::Result<()> { + if self.inner.strong_count() > 0 { + Ok(()) + } else { + Err(io::Error::new(io::ErrorKind::Other, "signal driver gone")) + } + } +} + +cfg_rt! { + impl Handle { + /// Returns a handle to the current driver + /// + /// # Panics + /// + /// This function panics if there is no current signal driver set. + pub(super) fn current() -> Self { + crate::runtime::context::signal_handle().expect( + "there is no signal driver running, must be called from the context of Tokio runtime", + ) + } + } +} + +cfg_not_rt! { + impl Handle { + /// Returns a handle to the current driver + /// + /// # Panics + /// + /// This function panics if there is no current signal driver set. + pub(super) fn current() -> Self { + panic!( + "there is no signal driver running, must be called from the context of Tokio runtime or with\ + `rt` enabled.", + ) + } + } +} diff --git a/src/signal/windows.rs b/src/signal/windows.rs index f55e504..1e78362 100644 --- a/src/signal/windows.rs +++ b/src/signal/windows.rs @@ -14,9 +14,9 @@ use std::convert::TryFrom; use std::io; use std::sync::Once; use std::task::{Context, Poll}; -use winapi::shared::minwindef::*; +use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE}; use winapi::um::consoleapi::SetConsoleCtrlHandler; -use winapi::um::wincon::*; +use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT}; #[derive(Debug)] pub(crate) struct OsStorage { @@ -253,26 +253,25 @@ mod tests { #[test] fn ctrl_c() { let rt = rt(); + let _enter = rt.enter(); - rt.enter(|| { - let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); + let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); - assert_pending!(ctrl_c.poll()); + assert_pending!(ctrl_c.poll()); - // Windows doesn't have a good programmatic way of sending events - // like sending signals on Unix, so we'll stub out the actual OS - // integration and test that our handling works. - unsafe { - super::handler(CTRL_C_EVENT); - } + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_C_EVENT); + } - assert_ready_ok!(ctrl_c.poll()); - }); + assert_ready_ok!(ctrl_c.poll()); } #[test] fn ctrl_break() { - let mut rt = rt(); + let rt = rt(); rt.block_on(async { let mut ctrl_break = assert_ok!(super::ctrl_break()); @@ -289,8 +288,7 @@ mod tests { } fn rt() -> Runtime { - crate::runtime::Builder::new() - .basic_scheduler() + crate::runtime::Builder::new_current_thread() .build() .unwrap() } diff --git a/src/stream/all.rs b/src/stream/all.rs index 615665d..353d61a 100644 --- a/src/stream/all.rs +++ b/src/stream/all.rs @@ -1,25 +1,34 @@ use crate::stream::Stream; use core::future::Future; +use core::marker::PhantomPinned; use core::pin::Pin; use core::task::{Context, Poll}; +use pin_project_lite::pin_project; -/// Future for the [`all`](super::StreamExt::all) method. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct AllFuture<'a, St: ?Sized, F> { - stream: &'a mut St, - f: F, +pin_project! { + /// Future for the [`all`](super::StreamExt::all) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct AllFuture<'a, St: ?Sized, F> { + stream: &'a mut St, + f: F, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } impl<'a, St: ?Sized, F> AllFuture<'a, St, F> { pub(super) fn new(stream: &'a mut St, f: F) -> Self { - Self { stream, f } + Self { + stream, + f, + _pin: PhantomPinned, + } } } -impl<St: ?Sized + Unpin, F> Unpin for AllFuture<'_, St, F> {} - impl<St, F> Future for AllFuture<'_, St, F> where St: ?Sized + Stream + Unpin, @@ -27,12 +36,13 @@ where { type Output = bool; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let next = futures_core::ready!(Pin::new(&mut self.stream).poll_next(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + let next = futures_core::ready!(Pin::new(me.stream).poll_next(cx)); match next { Some(v) => { - if !(&mut self.f)(v) { + if !(me.f)(v) { Poll::Ready(false) } else { cx.waker().wake_by_ref(); diff --git a/src/stream/any.rs b/src/stream/any.rs index f2ecad5..aac0ec7 100644 --- a/src/stream/any.rs +++ b/src/stream/any.rs @@ -1,25 +1,34 @@ use crate::stream::Stream; use core::future::Future; +use core::marker::PhantomPinned; use core::pin::Pin; use core::task::{Context, Poll}; +use pin_project_lite::pin_project; -/// Future for the [`any`](super::StreamExt::any) method. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct AnyFuture<'a, St: ?Sized, F> { - stream: &'a mut St, - f: F, +pin_project! { + /// Future for the [`any`](super::StreamExt::any) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct AnyFuture<'a, St: ?Sized, F> { + stream: &'a mut St, + f: F, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } impl<'a, St: ?Sized, F> AnyFuture<'a, St, F> { pub(super) fn new(stream: &'a mut St, f: F) -> Self { - Self { stream, f } + Self { + stream, + f, + _pin: PhantomPinned, + } } } -impl<St: ?Sized + Unpin, F> Unpin for AnyFuture<'_, St, F> {} - impl<St, F> Future for AnyFuture<'_, St, F> where St: ?Sized + Stream + Unpin, @@ -27,12 +36,13 @@ where { type Output = bool; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let next = futures_core::ready!(Pin::new(&mut self.stream).poll_next(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + let next = futures_core::ready!(Pin::new(me.stream).poll_next(cx)); match next { Some(v) => { - if (&mut self.f)(v) { + if (me.f)(v) { Poll::Ready(true) } else { cx.waker().wake_by_ref(); diff --git a/src/stream/collect.rs b/src/stream/collect.rs index 4649428..1aafc30 100644 --- a/src/stream/collect.rs +++ b/src/stream/collect.rs @@ -1,7 +1,7 @@ use crate::stream::Stream; -use bytes::{Buf, BufMut, Bytes, BytesMut}; use core::future::Future; +use core::marker::PhantomPinned; use core::mem; use core::pin::Pin; use core::task::{Context, Poll}; @@ -10,7 +10,7 @@ use pin_project_lite::pin_project; // Do not export this struct until `FromStream` can be unsealed. pin_project! { /// Future returned by the [`collect`](super::StreamExt::collect) method. - #[must_use = "streams do nothing unless polled"] + #[must_use = "futures do nothing unless you `.await` or poll them"] #[derive(Debug)] pub struct Collect<T, U> where @@ -19,7 +19,10 @@ pin_project! { { #[pin] stream: T, - collection: U::Collection, + collection: U::InternalCollection, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -42,9 +45,13 @@ where { pub(super) fn new(stream: T) -> Collect<T, U> { let (lower, upper) = stream.size_hint(); - let collection = U::initialize(lower, upper); + let collection = U::initialize(sealed::Internal, lower, upper); - Collect { stream, collection } + Collect { + stream, + collection, + _pin: PhantomPinned, + } } } @@ -64,12 +71,12 @@ where let item = match ready!(me.stream.poll_next(cx)) { Some(item) => item, None => { - return Ready(U::finalize(&mut me.collection)); + return Ready(U::finalize(sealed::Internal, &mut me.collection)); } }; - if !U::extend(&mut me.collection, item) { - return Ready(U::finalize(&mut me.collection)); + if !U::extend(sealed::Internal, &mut me.collection, item) { + return Ready(U::finalize(sealed::Internal, &mut me.collection)); } } } @@ -80,32 +87,32 @@ where impl FromStream<()> for () {} impl sealed::FromStreamPriv<()> for () { - type Collection = (); + type InternalCollection = (); - fn initialize(_lower: usize, _upper: Option<usize>) {} + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option<usize>) {} - fn extend(_collection: &mut (), _item: ()) -> bool { + fn extend(_: sealed::Internal, _collection: &mut (), _item: ()) -> bool { true } - fn finalize(_collection: &mut ()) {} + fn finalize(_: sealed::Internal, _collection: &mut ()) {} } impl<T: AsRef<str>> FromStream<T> for String {} impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String { - type Collection = String; + type InternalCollection = String; - fn initialize(_lower: usize, _upper: Option<usize>) -> String { + fn initialize(_: sealed::Internal, _lower: usize, _upper: Option<usize>) -> String { String::new() } - fn extend(collection: &mut String, item: T) -> bool { + fn extend(_: sealed::Internal, collection: &mut String, item: T) -> bool { collection.push_str(item.as_ref()); true } - fn finalize(collection: &mut String) -> String { + fn finalize(_: sealed::Internal, collection: &mut String) -> String { mem::replace(collection, String::new()) } } @@ -113,18 +120,18 @@ impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String { impl<T> FromStream<T> for Vec<T> {} impl<T> sealed::FromStreamPriv<T> for Vec<T> { - type Collection = Vec<T>; + type InternalCollection = Vec<T>; - fn initialize(lower: usize, _upper: Option<usize>) -> Vec<T> { + fn initialize(_: sealed::Internal, lower: usize, _upper: Option<usize>) -> Vec<T> { Vec::with_capacity(lower) } - fn extend(collection: &mut Vec<T>, item: T) -> bool { + fn extend(_: sealed::Internal, collection: &mut Vec<T>, item: T) -> bool { collection.push(item); true } - fn finalize(collection: &mut Vec<T>) -> Vec<T> { + fn finalize(_: sealed::Internal, collection: &mut Vec<T>) -> Vec<T> { mem::replace(collection, vec![]) } } @@ -132,18 +139,19 @@ impl<T> sealed::FromStreamPriv<T> for Vec<T> { impl<T> FromStream<T> for Box<[T]> {} impl<T> sealed::FromStreamPriv<T> for Box<[T]> { - type Collection = Vec<T>; + type InternalCollection = Vec<T>; - fn initialize(lower: usize, upper: Option<usize>) -> Vec<T> { - <Vec<T> as sealed::FromStreamPriv<T>>::initialize(lower, upper) + fn initialize(_: sealed::Internal, lower: usize, upper: Option<usize>) -> Vec<T> { + <Vec<T> as sealed::FromStreamPriv<T>>::initialize(sealed::Internal, lower, upper) } - fn extend(collection: &mut Vec<T>, item: T) -> bool { - <Vec<T> as sealed::FromStreamPriv<T>>::extend(collection, item) + fn extend(_: sealed::Internal, collection: &mut Vec<T>, item: T) -> bool { + <Vec<T> as sealed::FromStreamPriv<T>>::extend(sealed::Internal, collection, item) } - fn finalize(collection: &mut Vec<T>) -> Box<[T]> { - <Vec<T> as sealed::FromStreamPriv<T>>::finalize(collection).into_boxed_slice() + fn finalize(_: sealed::Internal, collection: &mut Vec<T>) -> Box<[T]> { + <Vec<T> as sealed::FromStreamPriv<T>>::finalize(sealed::Internal, collection) + .into_boxed_slice() } } @@ -153,18 +161,26 @@ impl<T, U, E> sealed::FromStreamPriv<Result<T, E>> for Result<U, E> where U: FromStream<T>, { - type Collection = Result<U::Collection, E>; + type InternalCollection = Result<U::InternalCollection, E>; - fn initialize(lower: usize, upper: Option<usize>) -> Result<U::Collection, E> { - Ok(U::initialize(lower, upper)) + fn initialize( + _: sealed::Internal, + lower: usize, + upper: Option<usize>, + ) -> Result<U::InternalCollection, E> { + Ok(U::initialize(sealed::Internal, lower, upper)) } - fn extend(collection: &mut Self::Collection, item: Result<T, E>) -> bool { + fn extend( + _: sealed::Internal, + collection: &mut Self::InternalCollection, + item: Result<T, E>, + ) -> bool { assert!(collection.is_ok()); match item { Ok(item) => { let collection = collection.as_mut().ok().expect("invalid state"); - U::extend(collection, item) + U::extend(sealed::Internal, collection, item) } Err(err) => { *collection = Err(err); @@ -173,11 +189,11 @@ where } } - fn finalize(collection: &mut Self::Collection) -> Result<U, E> { + fn finalize(_: sealed::Internal, collection: &mut Self::InternalCollection) -> Result<U, E> { if let Ok(collection) = collection.as_mut() { - Ok(U::finalize(collection)) + Ok(U::finalize(sealed::Internal, collection)) } else { - let res = mem::replace(collection, Ok(U::initialize(0, Some(0)))); + let res = mem::replace(collection, Ok(U::initialize(sealed::Internal, 0, Some(0)))); if let Err(err) = res { Err(err) @@ -188,59 +204,30 @@ where } } -impl<T: Buf> FromStream<T> for Bytes {} - -impl<T: Buf> sealed::FromStreamPriv<T> for Bytes { - type Collection = BytesMut; - - fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut { - BytesMut::new() - } - - fn extend(collection: &mut BytesMut, item: T) -> bool { - collection.put(item); - true - } - - fn finalize(collection: &mut BytesMut) -> Bytes { - mem::replace(collection, BytesMut::new()).freeze() - } -} - -impl<T: Buf> FromStream<T> for BytesMut {} - -impl<T: Buf> sealed::FromStreamPriv<T> for BytesMut { - type Collection = BytesMut; - - fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut { - BytesMut::new() - } - - fn extend(collection: &mut BytesMut, item: T) -> bool { - collection.put(item); - true - } - - fn finalize(collection: &mut BytesMut) -> BytesMut { - mem::replace(collection, BytesMut::new()) - } -} - pub(crate) mod sealed { #[doc(hidden)] pub trait FromStreamPriv<T> { /// Intermediate type used during collection process - type Collection; + /// + /// The name of this type is internal and cannot be relied upon. + type InternalCollection; /// Initialize the collection - fn initialize(lower: usize, upper: Option<usize>) -> Self::Collection; + fn initialize( + internal: Internal, + lower: usize, + upper: Option<usize>, + ) -> Self::InternalCollection; /// Extend the collection with the received item /// /// Return `true` to continue streaming, `false` complete collection. - fn extend(collection: &mut Self::Collection, item: T) -> bool; + fn extend(internal: Internal, collection: &mut Self::InternalCollection, item: T) -> bool; /// Finalize collection into target type. - fn finalize(collection: &mut Self::Collection) -> Self; + fn finalize(internal: Internal, collection: &mut Self::InternalCollection) -> Self; } + + #[allow(missing_debug_implementations)] + pub struct Internal; } diff --git a/src/stream/fold.rs b/src/stream/fold.rs index 7b9fead..5cf2bfa 100644 --- a/src/stream/fold.rs +++ b/src/stream/fold.rs @@ -1,6 +1,7 @@ use crate::stream::Stream; use core::future::Future; +use core::marker::PhantomPinned; use core::pin::Pin; use core::task::{Context, Poll}; use pin_project_lite::pin_project; @@ -8,11 +9,15 @@ use pin_project_lite::pin_project; pin_project! { /// Future returned by the [`fold`](super::StreamExt::fold) method. #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct FoldFuture<St, B, F> { #[pin] stream: St, acc: Option<B>, f: F, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -22,6 +27,7 @@ impl<St, B, F> FoldFuture<St, B, F> { stream, acc: Some(init), f, + _pin: PhantomPinned, } } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 7b061ef..6bf4232 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,8 +1,57 @@ //! Stream utilities for Tokio. //! -//! A `Stream` is an asynchronous sequence of values. It can be thought of as an asynchronous version of the standard library's `Iterator` trait. +//! A `Stream` is an asynchronous sequence of values. It can be thought of as +//! an asynchronous version of the standard library's `Iterator` trait. //! -//! This module provides helpers to work with them. +//! This module provides helpers to work with them. For examples of usage and a more in-depth +//! description of streams you can also refer to the [streams +//! tutorial](https://tokio.rs/tokio/tutorial/streams) on the tokio website. +//! +//! # Iterating over a Stream +//! +//! Due to similarities with the standard library's `Iterator` trait, some new +//! users may assume that they can use `for in` syntax to iterate over a +//! `Stream`, but this is unfortunately not possible. Instead, you can use a +//! `while let` loop as follows: +//! +//! ```rust +//! use tokio::stream::{self, StreamExt}; +//! +//! #[tokio::main] +//! async fn main() { +//! let mut stream = stream::iter(vec![0, 1, 2]); +//! +//! while let Some(value) = stream.next().await { +//! println!("Got {}", value); +//! } +//! } +//! ``` +//! +//! # Returning a Stream from a function +//! +//! A common way to stream values from a function is to pass in the sender +//! half of a channel and use the receiver as the stream. This requires awaiting +//! both futures to ensure progress is made. Another alternative is the +//! [async-stream] crate, which contains macros that provide a `yield` keyword +//! and allow you to return an `impl Stream`. +//! +//! [async-stream]: https://docs.rs/async-stream +//! +//! # Conversion to and from AsyncRead/AsyncWrite +//! +//! It is often desirable to convert a `Stream` into an [`AsyncRead`], +//! especially when dealing with plaintext formats streamed over the network. +//! The opposite conversion from an [`AsyncRead`] into a `Stream` is also +//! another commonly required feature. To enable these conversions, +//! [`tokio-util`] provides the [`StreamReader`] and [`ReaderStream`] +//! types when the io feature is enabled. +//! +//! [tokio-util]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html +//! [`tokio::io`]: crate::io +//! [`AsyncRead`]: crate::io::AsyncRead +//! [`AsyncWrite`]: crate::io::AsyncWrite +//! [`ReaderStream`]: https://docs.rs/tokio-util/0.4/tokio_util/io/struct.ReaderStream.html +//! [`StreamReader`]: https://docs.rs/tokio-util/0.4/tokio_util/io/struct.StreamReader.html mod all; use all::AllFuture; @@ -71,9 +120,12 @@ use take_while::TakeWhile; cfg_time! { mod timeout; use timeout::Timeout; - use std::time::Duration; + use crate::time::Duration; + mod throttle; + use crate::stream::throttle::{throttle, Throttle}; } +#[doc(no_inline)] pub use futures_core::Stream; /// An extension trait for `Stream`s that provides a variety of convenient @@ -215,11 +267,11 @@ pub trait StreamExt: Stream { /// # /* /// #[tokio::main] /// # */ - /// # #[tokio::main(basic_scheduler)] + /// # #[tokio::main(flavor = "current_thread")] /// async fn main() { /// # time::pause(); - /// let (mut tx1, rx1) = mpsc::channel(10); - /// let (mut tx2, rx2) = mpsc::channel(10); + /// let (tx1, rx1) = mpsc::channel(10); + /// let (tx2, rx2) = mpsc::channel(10); /// /// let mut rx = rx1.merge(rx2); /// @@ -229,18 +281,18 @@ pub trait StreamExt: Stream { /// tx1.send(2).await.unwrap(); /// /// // Let the other task send values - /// time::delay_for(Duration::from_millis(20)).await; + /// time::sleep(Duration::from_millis(20)).await; /// /// tx1.send(4).await.unwrap(); /// }); /// /// tokio::spawn(async move { /// // Wait for the first task to send values - /// time::delay_for(Duration::from_millis(5)).await; + /// time::sleep(Duration::from_millis(5)).await; /// /// tx2.send(3).await.unwrap(); /// - /// time::delay_for(Duration::from_millis(25)).await; + /// time::sleep(Duration::from_millis(25)).await; /// /// // Send the final value /// tx2.send(5).await.unwrap(); @@ -520,6 +572,12 @@ pub trait StreamExt: Stream { /// Tests if every element of the stream matches a predicate. /// + /// Equivalent to: + /// + /// ```ignore + /// async fn all<F>(&mut self, f: F) -> bool; + /// ``` + /// /// `all()` takes a closure that returns `true` or `false`. It applies /// this closure to each element of the stream, and if they all return /// `true`, then so does `all`. If any of them return `false`, it @@ -575,6 +633,12 @@ pub trait StreamExt: Stream { /// Tests if any element of the stream matches a predicate. /// + /// Equivalent to: + /// + /// ```ignore + /// async fn any<F>(&mut self, f: F) -> bool; + /// ``` + /// /// `any()` takes a closure that returns `true` or `false`. It applies /// this closure to each element of the stream, and if any of them return /// `true`, then so does `any()`. If they all return `false`, it @@ -664,6 +728,12 @@ pub trait StreamExt: Stream { /// A combinator that applies a function to every element in a stream /// producing a single, final value. /// + /// Equivalent to: + /// + /// ```ignore + /// async fn fold<B, F>(self, init: B, f: F) -> B; + /// ``` + /// /// # Examples /// Basic usage: /// ``` @@ -687,6 +757,12 @@ pub trait StreamExt: Stream { /// Drain stream pushing all emitted values into a collection. /// + /// Equivalent to: + /// + /// ```ignore + /// async fn collect<T>(self) -> T; + /// ``` + /// /// `collect` streams all values, awaiting as needed. Values are pushed into /// a collection. A number of different target collection types are /// supported, including [`Vec`](std::vec::Vec), @@ -819,6 +895,33 @@ pub trait StreamExt: Stream { { Timeout::new(self, duration) } + + /// Slows down a stream by enforcing a delay between items. + /// + /// # Example + /// + /// Create a throttled stream. + /// ```rust,no_run + /// use std::time::Duration; + /// use tokio::stream::StreamExt; + /// + /// # async fn dox() { + /// let mut item_stream = futures::stream::repeat("one").throttle(Duration::from_secs(2)); + /// + /// loop { + /// // The string will be produced at most every 2 seconds + /// println!("{:?}", item_stream.next().await); + /// } + /// # } + /// ``` + #[cfg(all(feature = "time"))] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + fn throttle(self, duration: Duration) -> Throttle<Self> + where + Self: Sized, + { + throttle(duration, self) + } } impl<St: ?Sized> StreamExt for St where St: Stream {} diff --git a/src/stream/next.rs b/src/stream/next.rs index 3909c0c..d9b1f92 100644 --- a/src/stream/next.rs +++ b/src/stream/next.rs @@ -1,28 +1,37 @@ use crate::stream::Stream; use core::future::Future; +use core::marker::PhantomPinned; use core::pin::Pin; use core::task::{Context, Poll}; +use pin_project_lite::pin_project; -/// Future for the [`next`](super::StreamExt::next) method. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Next<'a, St: ?Sized> { - stream: &'a mut St, +pin_project! { + /// Future for the [`next`](super::StreamExt::next) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Next<'a, St: ?Sized> { + stream: &'a mut St, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } -impl<St: ?Sized + Unpin> Unpin for Next<'_, St> {} - impl<'a, St: ?Sized> Next<'a, St> { pub(super) fn new(stream: &'a mut St) -> Self { - Next { stream } + Next { + stream, + _pin: PhantomPinned, + } } } impl<St: ?Sized + Stream + Unpin> Future for Next<'_, St> { type Output = Option<St::Item>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - Pin::new(&mut self.stream).poll_next(cx) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + Pin::new(me.stream).poll_next(cx) } } diff --git a/src/stream/stream_map.rs b/src/stream/stream_map.rs index 2f60ea4..8539e4d 100644 --- a/src/stream/stream_map.rs +++ b/src/stream/stream_map.rs @@ -57,8 +57,8 @@ use std::task::{Context, Poll}; /// /// #[tokio::main] /// async fn main() { -/// let (mut tx1, rx1) = mpsc::channel(10); -/// let (mut tx2, rx2) = mpsc::channel(10); +/// let (tx1, rx1) = mpsc::channel(10); +/// let (tx2, rx2) = mpsc::channel(10); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); @@ -163,6 +163,52 @@ pub struct StreamMap<K, V> { } impl<K, V> StreamMap<K, V> { + /// An iterator visiting all key-value pairs in arbitrary order. + /// + /// The iterator element type is &'a (K, V). + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::{StreamMap, pending}; + /// + /// let mut map = StreamMap::new(); + /// + /// map.insert("a", pending::<i32>()); + /// map.insert("b", pending()); + /// map.insert("c", pending()); + /// + /// for (key, stream) in map.iter() { + /// println!("({}, {:?})", key, stream); + /// } + /// ``` + pub fn iter(&self) -> impl Iterator<Item = &(K, V)> { + self.entries.iter() + } + + /// An iterator visiting all key-value pairs mutably in arbitrary order. + /// + /// The iterator element type is &'a mut (K, V). + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::{StreamMap, pending}; + /// + /// let mut map = StreamMap::new(); + /// + /// map.insert("a", pending::<i32>()); + /// map.insert("b", pending()); + /// map.insert("c", pending()); + /// + /// for (key, stream) in map.iter_mut() { + /// println!("({}, {:?})", key, stream); + /// } + /// ``` + pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (K, V)> { + self.entries.iter_mut() + } + /// Creates an empty `StreamMap`. /// /// The stream map is initially created with a capacity of `0`, so it will @@ -217,7 +263,7 @@ impl<K, V> StreamMap<K, V> { /// } /// ``` pub fn keys(&self) -> impl Iterator<Item = &K> { - self.entries.iter().map(|(k, _)| k) + self.iter().map(|(k, _)| k) } /// An iterator visiting all values in arbitrary order. @@ -240,7 +286,7 @@ impl<K, V> StreamMap<K, V> { /// } /// ``` pub fn values(&self) -> impl Iterator<Item = &V> { - self.entries.iter().map(|(_, v)| v) + self.iter().map(|(_, v)| v) } /// An iterator visiting all values mutably in arbitrary order. @@ -263,7 +309,7 @@ impl<K, V> StreamMap<K, V> { /// } /// ``` pub fn values_mut(&mut self) -> impl Iterator<Item = &mut V> { - self.entries.iter_mut().map(|(_, v)| v) + self.iter_mut().map(|(_, v)| v) } /// Returns the number of streams the map can hold without reallocating. diff --git a/src/time/throttle.rs b/src/stream/throttle.rs index d53a6f7..8f4a256 100644 --- a/src/time/throttle.rs +++ b/src/stream/throttle.rs @@ -1,7 +1,7 @@ //! Slow down a stream by enforcing a delay between items. use crate::stream::Stream; -use crate::time::{Delay, Duration, Instant}; +use crate::time::{Duration, Instant, Sleep}; use std::future::Future; use std::marker::Unpin; @@ -10,34 +10,14 @@ use std::task::{self, Poll}; use pin_project_lite::pin_project; -/// Slows down a stream by enforcing a delay between items. -/// They will be produced not more often than the specified interval. -/// -/// # Example -/// -/// Create a throttled stream. -/// ```rust,no_run -/// use std::time::Duration; -/// use tokio::stream::StreamExt; -/// use tokio::time::throttle; -/// -/// # async fn dox() { -/// let mut item_stream = throttle(Duration::from_secs(2), futures::stream::repeat("one")); -/// -/// loop { -/// // The string will be produced at most every 2 seconds -/// println!("{:?}", item_stream.next().await); -/// } -/// # } -/// ``` -pub fn throttle<T>(duration: Duration, stream: T) -> Throttle<T> +pub(super) fn throttle<T>(duration: Duration, stream: T) -> Throttle<T> where T: Stream, { let delay = if duration == Duration::from_millis(0) { None } else { - Some(Delay::new_timeout(Instant::now() + duration, duration)) + Some(Sleep::new_timeout(Instant::now() + duration, duration)) }; Throttle { @@ -54,7 +34,7 @@ pin_project! { #[must_use = "streams do nothing unless polled"] pub struct Throttle<T> { // `None` when duration is zero. - delay: Option<Delay>, + delay: Option<Sleep>, duration: Duration, // Set to true when `delay` has returned ready, but `stream` hasn't. diff --git a/src/stream/timeout.rs b/src/stream/timeout.rs index b8a2024..669973f 100644 --- a/src/stream/timeout.rs +++ b/src/stream/timeout.rs @@ -1,5 +1,5 @@ use crate::stream::{Fuse, Stream}; -use crate::time::{Delay, Elapsed, Instant}; +use crate::time::{error::Elapsed, Instant, Sleep}; use core::future::Future; use core::pin::Pin; @@ -14,7 +14,7 @@ pin_project! { pub struct Timeout<S> { #[pin] stream: Fuse<S>, - deadline: Delay, + deadline: Sleep, duration: Duration, poll_deadline: bool, } @@ -23,7 +23,7 @@ pin_project! { impl<S: Stream> Timeout<S> { pub(super) fn new(stream: S, duration: Duration) -> Self { let next = Instant::now() + duration; - let deadline = Delay::new_timeout(next, duration); + let deadline = Sleep::new_timeout(next, duration); Timeout { stream: Fuse::new(stream), diff --git a/src/stream/try_next.rs b/src/stream/try_next.rs index 59e0eb1..b21d279 100644 --- a/src/stream/try_next.rs +++ b/src/stream/try_next.rs @@ -1,22 +1,29 @@ use crate::stream::{Next, Stream}; use core::future::Future; +use core::marker::PhantomPinned; use core::pin::Pin; use core::task::{Context, Poll}; +use pin_project_lite::pin_project; -/// Future for the [`try_next`](super::StreamExt::try_next) method. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct TryNext<'a, St: ?Sized> { - inner: Next<'a, St>, +pin_project! { + /// Future for the [`try_next`](super::StreamExt::try_next) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct TryNext<'a, St: ?Sized> { + #[pin] + inner: Next<'a, St>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } -impl<St: ?Sized + Unpin> Unpin for TryNext<'_, St> {} - impl<'a, St: ?Sized> TryNext<'a, St> { pub(super) fn new(stream: &'a mut St) -> Self { Self { inner: Next::new(stream), + _pin: PhantomPinned, } } } @@ -24,7 +31,8 @@ impl<'a, St: ?Sized> TryNext<'a, St> { impl<T, E, St: ?Sized + Stream<Item = Result<T, E>> + Unpin> Future for TryNext<'_, St> { type Output = Result<Option<T>, E>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - Pin::new(&mut self.inner).poll(cx).map(Option::transpose) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + me.inner.poll(cx).map(Option::transpose) } } diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs index 6286334..fddb3a5 100644 --- a/src/sync/barrier.rs +++ b/src/sync/barrier.rs @@ -96,7 +96,7 @@ impl Barrier { // wake everyone, increment the generation, and return state .waker - .broadcast(state.generation) + .send(state.generation) .expect("there is at least one receiver"); state.arrived = 0; state.generation += 1; @@ -110,9 +110,11 @@ impl Barrier { let mut wait = self.wait.clone(); loop { + let _ = wait.changed().await; + // note that the first time through the loop, this _will_ yield a generation // immediately, since we cloned a receiver that has never seen any values. - if wait.recv().await.expect("sender hasn't been closed") >= generation { + if *wait.borrow() >= generation { break; } } diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs index 070cd20..0b50e4f 100644 --- a/src/sync/batch_semaphore.rs +++ b/src/sync/batch_semaphore.rs @@ -1,3 +1,4 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] //! # Implementation Details //! //! The semaphore is implemented using an intrusive linked list of waiters. An @@ -36,7 +37,7 @@ pub(crate) struct Semaphore { } struct Waitlist { - queue: LinkedList<Waiter>, + queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, closed: bool, } @@ -96,10 +97,13 @@ impl Semaphore { /// Note that this reserves three bits of flags in the permit counter, but /// we only actually use one of them. However, the previous semaphore /// implementation used three bits, so we will continue to reserve them to - /// avoid a breaking change if additional flags need to be aadded in the + /// avoid a breaking change if additional flags need to be added in the /// future. pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3; const CLOSED: usize = 1; + // The least-significant bit in the number of permits is reserved to use + // as a flag indicating that the semaphore has been closed. Consequently + // PERMIT_SHIFT is used to leave that bit for that purpose. const PERMIT_SHIFT: usize = 1; /// Creates a new semaphore with the initial number of permits @@ -120,6 +124,27 @@ impl Semaphore { } } + /// Creates a new semaphore with the initial number of permits + /// + /// Maximum number of permits on 32-bit platforms is `1<<29`. + /// + /// If the specified number of permits exceeds the maximum permit amount + /// Then the value will get clamped to the maximum number of permits. + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + pub(crate) const fn const_new(mut permits: usize) -> Self { + // NOTE: assertions and by extension panics are still being worked on: https://github.com/rust-lang/rust/issues/74925 + // currently we just clamp the permit count when it exceeds the max + permits &= Self::MAX_PERMITS; + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::const_new(Waitlist { + queue: LinkedList::new(), + closed: false, + }), + } + } + /// Returns the current number of available permits pub(crate) fn available_permits(&self) -> usize { self.permits.load(Acquire) >> Self::PERMIT_SHIFT @@ -134,16 +159,15 @@ impl Semaphore { } // Assign permits to the wait queue - self.add_permits_locked(added, self.waiters.lock().unwrap()); + self.add_permits_locked(added, self.waiters.lock()); } /// Closes the semaphore. This prevents the semaphore from issuing new /// permits and notifies all pending waiters. // This will be used once the bounded MPSC is updated to use the new // semaphore implementation. - #[allow(dead_code)] pub(crate) fn close(&self) { - let mut waiters = self.waiters.lock().unwrap(); + let mut waiters = self.waiters.lock(); // If the semaphore's permits counter has enough permits for an // unqueued waiter to acquire all the permits it needs immediately, // it won't touch the wait list. Therefore, we have to set a bit on @@ -161,6 +185,11 @@ impl Semaphore { } } + /// Returns true if the semaphore is closed + pub(crate) fn is_closed(&self) -> bool { + self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED + } + pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> { assert!( num_permits as usize <= Self::MAX_PERMITS, @@ -170,8 +199,8 @@ impl Semaphore { let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT; let mut curr = self.permits.load(Acquire); loop { - // Has the semaphore closed?git - if curr & Self::CLOSED > 0 { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { return Err(TryAcquireError::Closed); } @@ -203,7 +232,7 @@ impl Semaphore { let mut lock = Some(waiters); let mut is_empty = false; while rem > 0 { - let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap()); + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); 'inner: for slot in &mut wakers[..] { // Was the waiter assigned enough permits to wake it? match waiters.queue.last() { @@ -296,7 +325,7 @@ impl Semaphore { // counter. Otherwise, if we subtract the permits and then // acquire the lock, we might miss additional permits being // added while waiting for the lock. - lock = Some(self.waiters.lock().unwrap()); + lock = Some(self.waiters.lock()); } match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { @@ -306,7 +335,7 @@ impl Semaphore { if !queued { return Ready(Ok(())); } else if lock.is_none() { - break self.waiters.lock().unwrap(); + break self.waiters.lock(); } } break lock.expect("lock must be acquired before waiting"); @@ -357,7 +386,7 @@ impl Semaphore { impl fmt::Debug for Semaphore { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Semaphore") - .field("permits", &self.permits.load(Relaxed)) + .field("permits", &self.available_permits()) .finish() } } @@ -456,14 +485,7 @@ impl Drop for Acquire<'_> { // This is where we ensure safety. The future is being dropped, // which means we must ensure that the waiter entry is no longer stored // in the linked list. - let mut waiters = match self.semaphore.waiters.lock() { - Ok(lock) => lock, - // Removing the node from the linked list is necessary to ensure - // safety. Even if the lock was poisoned, we need to make sure it is - // removed from the linked list before dropping it --- otherwise, - // the list will contain a dangling pointer to this node. - Err(e) => e.into_inner(), - }; + let mut waiters = self.semaphore.waiters.lock(); // remove the entry from the list let node = NonNull::from(&mut self.node); @@ -506,20 +528,14 @@ impl TryAcquireError { /// Returns `true` if the error was caused by a closed semaphore. #[allow(dead_code)] // may be used later! pub(crate) fn is_closed(&self) -> bool { - match self { - TryAcquireError::Closed => true, - _ => false, - } + matches!(self, TryAcquireError::Closed) } /// Returns `true` if the error was caused by calling `try_acquire` on a /// semaphore with no available permits. #[allow(dead_code)] // may be used later! pub(crate) fn is_no_permits(&self) -> bool { - match self { - TryAcquireError::NoPermits => true, - _ => false, - } + matches!(self, TryAcquireError::NoPermits) } } diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index 0c8716f..ee9aba0 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -21,7 +21,7 @@ //! ## Lagging //! //! As sent messages must be retained until **all** [`Receiver`] handles receive -//! a clone, broadcast channels are suspectible to the "slow receiver" problem. +//! a clone, broadcast channels are susceptible to the "slow receiver" problem. //! In this case, all but one receiver are able to receive values at the rate //! they are sent. Because one receiver is stalled, the channel starts to fill //! up. @@ -55,8 +55,8 @@ //! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe //! [`Receiver`]: crate::sync::broadcast::Receiver //! [`channel`]: crate::sync::broadcast::channel -//! [`RecvError::Lagged`]: crate::sync::broadcast::RecvError::Lagged -//! [`RecvError::Closed`]: crate::sync::broadcast::RecvError::Closed +//! [`RecvError::Lagged`]: crate::sync::broadcast::error::RecvError::Lagged +//! [`RecvError::Closed`]: crate::sync::broadcast::error::RecvError::Closed //! [`recv`]: crate::sync::broadcast::Receiver::recv //! //! # Examples @@ -107,6 +107,7 @@ //! assert_eq!(20, rx.recv().await.unwrap()); //! assert_eq!(30, rx.recv().await.unwrap()); //! } +//! ``` use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; @@ -194,58 +195,99 @@ pub struct Receiver<T> { /// Next position to read from next: u64, - - /// Used to support the deprecated `poll_recv` fn - waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>, } -/// Error returned by [`Sender::send`][Sender::send]. -/// -/// A **send** operation can only fail if there are no active receivers, -/// implying that the message could never be received. The error contains the -/// message being sent as a payload so it can be recovered. -#[derive(Debug)] -pub struct SendError<T>(pub T); +pub mod error { + //! Broadcast error types -/// An error returned from the [`recv`] function on a [`Receiver`]. -/// -/// [`recv`]: crate::sync::broadcast::Receiver::recv -/// [`Receiver`]: crate::sync::broadcast::Receiver -#[derive(Debug, PartialEq)] -pub enum RecvError { - /// There are no more active senders implying no further messages will ever - /// be sent. - Closed, + use std::fmt; - /// The receiver lagged too far behind. Attempting to receive again will - /// return the oldest message still retained by the channel. + /// Error returned by from the [`send`] function on a [`Sender`]. /// - /// Includes the number of skipped messages. - Lagged(u64), -} + /// A **send** operation can only fail if there are no active receivers, + /// implying that the message could never be received. The error contains the + /// message being sent as a payload so it can be recovered. + /// + /// [`send`]: crate::sync::broadcast::Sender::send + /// [`Sender`]: crate::sync::broadcast::Sender + #[derive(Debug)] + pub struct SendError<T>(pub T); -/// An error returned from the [`try_recv`] function on a [`Receiver`]. -/// -/// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv -/// [`Receiver`]: crate::sync::broadcast::Receiver -#[derive(Debug, PartialEq)] -pub enum TryRecvError { - /// The channel is currently empty. There are still active - /// [`Sender`][Sender] handles, so data may yet become available. - Empty, - - /// There are no more active senders implying no further messages will ever - /// be sent. - Closed, - - /// The receiver lagged too far behind and has been forcibly disconnected. - /// Attempting to receive again will return the oldest message still - /// retained by the channel. - /// - /// Includes the number of skipped messages. - Lagged(u64), + impl<T> fmt::Display for SendError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "channel closed") + } + } + + impl<T: fmt::Debug> std::error::Error for SendError<T> {} + + /// An error returned from the [`recv`] function on a [`Receiver`]. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum RecvError { + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvError::Closed => write!(f, "channel closed"), + RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for RecvError {} + + /// An error returned from the [`try_recv`] function on a [`Receiver`]. + /// + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + #[derive(Debug, PartialEq)] + pub enum TryRecvError { + /// The channel is currently empty. There are still active + /// [`Sender`] handles, so data may yet become available. + /// + /// [`Sender`]: crate::sync::broadcast::Sender + Empty, + + /// There are no more active senders implying no further messages will ever + /// be sent. + Closed, + + /// The receiver lagged too far behind and has been forcibly disconnected. + /// Attempting to receive again will return the oldest message still + /// retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), + } + + impl fmt::Display for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TryRecvError::Empty => write!(f, "channel empty"), + TryRecvError::Closed => write!(f, "channel closed"), + TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), + } + } + } + + impl std::error::Error for TryRecvError {} } +use self::error::*; + /// Data shared between senders and receivers struct Shared<T> { /// slots in the channel @@ -273,7 +315,7 @@ struct Tail { closed: bool, /// Receivers waiting for a value - waiters: LinkedList<Waiter>, + waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, } /// Slot in the buffer @@ -373,8 +415,8 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe /// [`Receiver`]: crate::sync::broadcast::Receiver /// [`recv`]: crate::sync::broadcast::Receiver::recv -/// [`SendError`]: crate::sync::broadcast::SendError -/// [`RecvError`]: crate::sync::broadcast::RecvError +/// [`SendError`]: crate::sync::broadcast::error::SendError +/// [`RecvError`]: crate::sync::broadcast::error::RecvError /// /// # Examples /// @@ -400,7 +442,7 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// tx.send(20).unwrap(); /// } /// ``` -pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { +pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { assert!(capacity > 0, "capacity is empty"); assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); @@ -433,7 +475,6 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) { let rx = Receiver { shared: shared.clone(), next: 0, - waiter: None, }; let tx = Sender { shared }; @@ -528,23 +569,7 @@ impl<T> Sender<T> { /// ``` pub fn subscribe(&self) -> Receiver<T> { let shared = self.shared.clone(); - - let mut tail = shared.tail.lock().unwrap(); - - if tail.rx_cnt == MAX_RECEIVERS { - panic!("max receivers"); - } - - tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); - let next = tail.pos; - - drop(tail); - - Receiver { - shared, - next, - waiter: None, - } + new_receiver(shared) } /// Returns the number of active receivers @@ -584,12 +609,12 @@ impl<T> Sender<T> { /// } /// ``` pub fn receiver_count(&self) -> usize { - let tail = self.shared.tail.lock().unwrap(); + let tail = self.shared.tail.lock(); tail.rx_cnt } fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> { - let mut tail = self.shared.tail.lock().unwrap(); + let mut tail = self.shared.tail.lock(); if tail.rx_cnt == 0 { return Err(SendError(value)); @@ -634,6 +659,22 @@ impl<T> Sender<T> { } } +fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> { + let mut tail = shared.tail.lock(); + + if tail.rx_cnt == MAX_RECEIVERS { + panic!("max receivers"); + } + + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + + let next = tail.pos; + + drop(tail); + + Receiver { shared, next } +} + impl Tail { fn notify_rx(&mut self) { while let Some(mut waiter) = self.waiters.pop_back() { @@ -695,7 +736,7 @@ impl<T> Receiver<T> { // the slot lock. drop(slot); - let mut tail = self.shared.tail.lock().unwrap(); + let mut tail = self.shared.tail.lock(); // Acquire slot lock again slot = self.shared.buffer[idx].read().unwrap(); @@ -784,106 +825,7 @@ impl<T> Receiver<T> { } } -impl<T> Receiver<T> -where - T: Clone, -{ - /// Attempts to return a pending value on this receiver without awaiting. - /// - /// This is useful for a flavor of "optimistic check" before deciding to - /// await on a receiver. - /// - /// Compared with [`recv`], this function has three failure cases instead of one - /// (one for closed, one for an empty buffer, one for a lagging receiver). - /// - /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have - /// dropped, indicating that no further values can be sent on the channel. - /// - /// If the [`Receiver`] handle falls behind, once the channel is full, newly - /// sent values will overwrite old values. At this point, a call to [`recv`] - /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s - /// internal cursor is updated to point to the oldest value still held by - /// the channel. A subsequent call to [`try_recv`] will return this value - /// **unless** it has been since overwritten. If there are no values to - /// receive, `Err(TryRecvError::Empty)` is returned. - /// - /// [`recv`]: crate::sync::broadcast::Receiver::recv - /// [`Receiver`]: crate::sync::broadcast::Receiver - /// - /// # Examples - /// - /// ``` - /// use tokio::sync::broadcast; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, mut rx) = broadcast::channel(16); - /// - /// assert!(rx.try_recv().is_err()); - /// - /// tx.send(10).unwrap(); - /// - /// let value = rx.try_recv().unwrap(); - /// assert_eq!(10, value); - /// } - /// ``` - pub fn try_recv(&mut self) -> Result<T, TryRecvError> { - let guard = self.recv_ref(None)?; - guard.clone_value().ok_or(TryRecvError::Closed) - } - - #[doc(hidden)] - #[deprecated(since = "0.2.21", note = "use async fn recv()")] - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> { - use Poll::{Pending, Ready}; - - // The borrow checker prohibits calling `self.poll_ref` while passing in - // a mutable ref to a field (as it should). To work around this, - // `waiter` is first *removed* from `self` then `poll_recv` is called. - // - // However, for safety, we must ensure that `waiter` is **not** dropped. - // It could be contained in the intrusive linked list. The `Receiver` - // drop implementation handles cleanup. - // - // The guard pattern is used to ensure that, on return, even due to - // panic, the waiter node is replaced on `self`. - - struct Guard<'a, T> { - waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>, - receiver: &'a mut Receiver<T>, - } - - impl<'a, T> Drop for Guard<'a, T> { - fn drop(&mut self) { - self.receiver.waiter = self.waiter.take(); - } - } - - let waiter = self.waiter.take().or_else(|| { - Some(Box::pin(UnsafeCell::new(Waiter { - queued: false, - waker: None, - pointers: linked_list::Pointers::new(), - _p: PhantomPinned, - }))) - }); - - let guard = Guard { - waiter, - receiver: self, - }; - let res = guard - .receiver - .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker()))); - - match res { - Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)), - Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)), - Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))), - Err(TryRecvError::Empty) => Pending, - } - } - +impl<T: Clone> Receiver<T> { /// Receives the next value for this receiver. /// /// Each [`Receiver`] handle will receive a clone of all values sent @@ -948,54 +890,103 @@ where /// assert_eq!(20, rx.recv().await.unwrap()); /// assert_eq!(30, rx.recv().await.unwrap()); /// } + /// ``` pub async fn recv(&mut self) -> Result<T, RecvError> { let fut = Recv::<_, T>::new(Borrow(self)); fut.await } -} -#[cfg(feature = "stream")] -#[doc(hidden)] -#[deprecated(since = "0.2.21", note = "use `into_stream()`")] -impl<T> crate::stream::Stream for Receiver<T> -where - T: Clone, -{ - type Item = Result<T, RecvError>; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Option<Result<T, RecvError>>> { - #[allow(deprecated)] - self.poll_recv(cx).map(|v| match v { - Ok(v) => Some(Ok(v)), - lag @ Err(RecvError::Lagged(_)) => Some(lag), - Err(RecvError::Closed) => None, - }) + /// Attempts to return a pending value on this receiver without awaiting. + /// + /// This is useful for a flavor of "optimistic check" before deciding to + /// await on a receiver. + /// + /// Compared with [`recv`], this function has three failure cases instead of two + /// (one for closed, one for an empty buffer, one for a lagging receiver). + /// + /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have + /// dropped, indicating that no further values can be sent on the channel. + /// + /// If the [`Receiver`] handle falls behind, once the channel is full, newly + /// sent values will overwrite old values. At this point, a call to [`recv`] + /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s + /// internal cursor is updated to point to the oldest value still held by + /// the channel. A subsequent call to [`try_recv`] will return this value + /// **unless** it has been since overwritten. If there are no values to + /// receive, `Err(TryRecvError::Empty)` is returned. + /// + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = broadcast::channel(16); + /// + /// assert!(rx.try_recv().is_err()); + /// + /// tx.send(10).unwrap(); + /// + /// let value = rx.try_recv().unwrap(); + /// assert_eq!(10, value); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + let guard = self.recv_ref(None)?; + guard.clone_value().ok_or(TryRecvError::Closed) + } + + /// Convert the receiver into a `Stream`. + /// + /// The conversion allows using `Receiver` with APIs that require stream + /// values. + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::StreamExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = broadcast::channel(128); + /// + /// tokio::spawn(async move { + /// for i in 0..10_i32 { + /// tx.send(i).unwrap(); + /// } + /// }); + /// + /// // Streams must be pinned to iterate. + /// tokio::pin! { + /// let stream = rx + /// .into_stream() + /// .filter(Result::is_ok) + /// .map(Result::unwrap) + /// .filter(|v| v % 2 == 0) + /// .map(|v| v + 1); + /// } + /// + /// while let Some(i) = stream.next().await { + /// println!("{}", i); + /// } + /// } + /// ``` + #[cfg(feature = "stream")] + #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] + pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> { + Recv::new(Borrow(self)) } } impl<T> Drop for Receiver<T> { fn drop(&mut self) { - let mut tail = self.shared.tail.lock().unwrap(); - - if let Some(waiter) = &self.waiter { - // safety: tail lock is held - let queued = waiter.with(|ptr| unsafe { (*ptr).queued }); - - if queued { - // Remove the node - // - // safety: tail lock is held and the wait node is verified to be in - // the list. - unsafe { - waiter.with_mut(|ptr| { - tail.waiters.remove((&mut *ptr).into()); - }); - } - } - } + let mut tail = self.shared.tail.lock(); tail.rx_cnt -= 1; let until = tail.pos; @@ -1070,48 +1061,6 @@ where cfg_stream! { use futures_core::Stream; - impl<T: Clone> Receiver<T> { - /// Convert the receiver into a `Stream`. - /// - /// The conversion allows using `Receiver` with APIs that require stream - /// values. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::StreamExt; - /// use tokio::sync::broadcast; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, rx) = broadcast::channel(128); - /// - /// tokio::spawn(async move { - /// for i in 0..10_i32 { - /// tx.send(i).unwrap(); - /// } - /// }); - /// - /// // Streams must be pinned to iterate. - /// tokio::pin! { - /// let stream = rx - /// .into_stream() - /// .filter(Result::is_ok) - /// .map(Result::unwrap) - /// .filter(|v| v % 2 == 0) - /// .map(|v| v + 1); - /// } - /// - /// while let Some(i) = stream.next().await { - /// println!("{}", i); - /// } - /// } - /// ``` - pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> { - Recv::new(Borrow(self)) - } - } - impl<R, T: Clone> Stream for Recv<R, T> where R: AsMut<Receiver<T>>, @@ -1141,7 +1090,7 @@ where fn drop(&mut self) { // Acquire the tail lock. This is required for safety before accessing // the waiter node. - let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap(); + let mut tail = self.receiver.as_mut().shared.tail.lock(); // safety: tail lock is held let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued }); @@ -1211,27 +1160,4 @@ impl<'a, T> Drop for RecvGuard<'a, T> { } } -impl fmt::Display for RecvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - RecvError::Closed => write!(f, "channel closed"), - RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), - } - } -} - -impl std::error::Error for RecvError {} - -impl fmt::Display for TryRecvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TryRecvError::Empty => write!(f, "channel empty"), - TryRecvError::Closed => write!(f, "channel closed"), - TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt), - } - } -} - -impl std::error::Error for TryRecvError {} - fn is_unpin<T: Unpin>() {} diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs deleted file mode 100644 index d60d8e0..0000000 --- a/src/sync/cancellation_token.rs +++ /dev/null @@ -1,861 +0,0 @@ -//! An asynchronously awaitable `CancellationToken`. -//! The token allows to signal a cancellation request to one or more tasks. - -use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::Mutex; -use crate::util::intrusive_double_linked_list::{LinkedList, ListNode}; - -use core::future::Future; -use core::pin::Pin; -use core::ptr::NonNull; -use core::sync::atomic::Ordering; -use core::task::{Context, Poll, Waker}; - -/// A token which can be used to signal a cancellation request to one or more -/// tasks. -/// -/// Tasks can call [`CancellationToken::cancelled()`] in order to -/// obtain a Future which will be resolved when cancellation is requested. -/// -/// Cancellation can be requested through the [`CancellationToken::cancel`] method. -/// -/// # Examples -/// -/// ```ignore -/// use tokio::select; -/// use tokio::scope::CancellationToken; -/// -/// #[tokio::main] -/// async fn main() { -/// let token = CancellationToken::new(); -/// let cloned_token = token.clone(); -/// -/// let join_handle = tokio::spawn(async move { -/// // Wait for either cancellation or a very long time -/// select! { -/// _ = cloned_token.cancelled() => { -/// // The token was cancelled -/// 5 -/// } -/// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { -/// 99 -/// } -/// } -/// }); -/// -/// tokio::spawn(async move { -/// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; -/// token.cancel(); -/// }); -/// -/// assert_eq!(5, join_handle.await.unwrap()); -/// } -/// ``` -pub struct CancellationToken { - inner: NonNull<CancellationTokenState>, -} - -// Safety: The CancellationToken is thread-safe and can be moved between threads, -// since all methods are internally synchronized. -unsafe impl Send for CancellationToken {} -unsafe impl Sync for CancellationToken {} - -/// A Future that is resolved once the corresponding [`CancellationToken`] -/// was cancelled -#[must_use = "futures do nothing unless polled"] -pub struct WaitForCancellationFuture<'a> { - /// The CancellationToken that is associated with this WaitForCancellationFuture - cancellation_token: Option<&'a CancellationToken>, - /// Node for waiting at the cancellation_token - wait_node: ListNode<WaitQueueEntry>, - /// Whether this future was registered at the token yet as a waiter - is_registered: bool, -} - -// Safety: Futures can be sent between threads as long as the underlying -// cancellation_token is thread-safe (Sync), -// which allows to poll/register/unregister from a different thread. -unsafe impl<'a> Send for WaitForCancellationFuture<'a> {} - -// ===== impl CancellationToken ===== - -impl core::fmt::Debug for CancellationToken { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("CancellationToken") - .field("is_cancelled", &self.is_cancelled()) - .finish() - } -} - -impl Clone for CancellationToken { - fn clone(&self) -> Self { - // Safety: The state inside a `CancellationToken` is always valid, since - // is reference counted - let inner = self.state(); - - // Tokens are cloned by increasing their refcount - let current_state = inner.snapshot(); - inner.increment_refcount(current_state); - - CancellationToken { inner: self.inner } - } -} - -impl Drop for CancellationToken { - fn drop(&mut self) { - let token_state_pointer = self.inner; - - // Safety: The state inside a `CancellationToken` is always valid, since - // is reference counted - let inner = unsafe { &mut *self.inner.as_ptr() }; - - let mut current_state = inner.snapshot(); - - // We need to safe the parent, since the state might be released by the - // next call - let parent = inner.parent; - - // Drop our own refcount - current_state = inner.decrement_refcount(current_state); - - // If this was the last reference, unregister from the parent - if current_state.refcount == 0 { - if let Some(mut parent) = parent { - // Safety: Since we still retain a reference on the parent, it must be valid. - let parent = unsafe { parent.as_mut() }; - parent.unregister_child(token_state_pointer, current_state); - } - } - } -} - -impl CancellationToken { - /// Creates a new CancellationToken in the non-cancelled state. - pub fn new() -> CancellationToken { - let state = Box::new(CancellationTokenState::new( - None, - StateSnapshot { - cancel_state: CancellationState::NotCancelled, - has_parent_ref: false, - refcount: 1, - }, - )); - - // Safety: We just created the Box. The pointer is guaranteed to be - // not null - CancellationToken { - inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) }, - } - } - - /// Returns a reference to the utilized `CancellationTokenState`. - fn state(&self) -> &CancellationTokenState { - // Safety: The state inside a `CancellationToken` is always valid, since - // is reference counted - unsafe { &*self.inner.as_ptr() } - } - - /// Creates a `CancellationToken` which will get cancelled whenever the - /// current token gets cancelled. - /// - /// If the current token is already cancelled, the child token will get - /// returned in cancelled state. - /// - /// # Examples - /// - /// ```ignore - /// use tokio::select; - /// use tokio::scope::CancellationToken; - /// - /// #[tokio::main] - /// async fn main() { - /// let token = CancellationToken::new(); - /// let child_token = token.child_token(); - /// - /// let join_handle = tokio::spawn(async move { - /// // Wait for either cancellation or a very long time - /// select! { - /// _ = child_token.cancelled() => { - /// // The token was cancelled - /// 5 - /// } - /// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => { - /// 99 - /// } - /// } - /// }); - /// - /// tokio::spawn(async move { - /// tokio::time::delay_for(std::time::Duration::from_millis(10)).await; - /// token.cancel(); - /// }); - /// - /// assert_eq!(5, join_handle.await.unwrap()); - /// } - /// ``` - pub fn child_token(&self) -> CancellationToken { - let inner = self.state(); - - // Increment the refcount of this token. It will be referenced by the - // child, independent of whether the child is immediately cancelled or - // not. - let _current_state = inner.increment_refcount(inner.snapshot()); - - let mut unpacked_child_state = StateSnapshot { - has_parent_ref: true, - refcount: 1, - cancel_state: CancellationState::NotCancelled, - }; - let mut child_token_state = Box::new(CancellationTokenState::new( - Some(self.inner), - unpacked_child_state, - )); - - { - let mut guard = inner.synchronized.lock().unwrap(); - if guard.is_cancelled { - // This task was already cancelled. In this case we should not - // insert the child into the list, since it would never get removed - // from the list. - (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true; - unpacked_child_state.cancel_state = CancellationState::Cancelled; - // Since it's not in the list, the parent doesn't need to retain - // a reference to it. - unpacked_child_state.has_parent_ref = false; - child_token_state - .state - .store(unpacked_child_state.pack(), Ordering::SeqCst); - } else { - if let Some(mut first_child) = guard.first_child { - child_token_state.from_parent.next_peer = Some(first_child); - // Safety: We manipulate other child task inside the Mutex - // and retain a parent reference on it. The child token can't - // get invalidated while the Mutex is held. - unsafe { - first_child.as_mut().from_parent.prev_peer = - Some((&mut *child_token_state).into()) - }; - } - guard.first_child = Some((&mut *child_token_state).into()); - } - }; - - let child_token_ptr = Box::into_raw(child_token_state); - // Safety: We just created the pointer from a `Box` - CancellationToken { - inner: unsafe { NonNull::new_unchecked(child_token_ptr) }, - } - } - - /// Cancel the [`CancellationToken`] and all child tokens which had been - /// derived from it. - /// - /// This will wake up all tasks which are waiting for cancellation. - pub fn cancel(&self) { - self.state().cancel(); - } - - /// Returns `true` if the `CancellationToken` had been cancelled - pub fn is_cancelled(&self) -> bool { - self.state().is_cancelled() - } - - /// Returns a `Future` that gets fulfilled when cancellation is requested. - pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { - WaitForCancellationFuture { - cancellation_token: Some(self), - wait_node: ListNode::new(WaitQueueEntry::new()), - is_registered: false, - } - } - - unsafe fn register( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - self.state().register(wait_node, cx) - } - - fn check_for_cancellation( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - self.state().check_for_cancellation(wait_node, cx) - } - - fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) { - self.state().unregister(wait_node) - } -} - -// ===== impl WaitForCancellationFuture ===== - -impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("WaitForCancellationFuture").finish() - } -} - -impl<'a> Future for WaitForCancellationFuture<'a> { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - // Safety: We do not move anything out of `WaitForCancellationFuture` - let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; - - let cancellation_token = mut_self - .cancellation_token - .expect("polled WaitForCancellationFuture after completion"); - - let poll_res = if !mut_self.is_registered { - // Safety: The `ListNode` is pinned through the Future, - // and we will unregister it in `WaitForCancellationFuture::drop` - // before the Future is dropped and the memory reference is invalidated. - unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) } - } else { - cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx) - }; - - if let Poll::Ready(()) = poll_res { - // The cancellation_token was signalled - mut_self.cancellation_token = None; - // A signalled Token means the Waker won't be enqueued anymore - mut_self.is_registered = false; - mut_self.wait_node.task = None; - } else { - // This `Future` and its stored `Waker` stay registered at the - // `CancellationToken` - mut_self.is_registered = true; - } - - poll_res - } -} - -impl<'a> Drop for WaitForCancellationFuture<'a> { - fn drop(&mut self) { - // If this WaitForCancellationFuture has been polled and it was added to the - // wait queue at the cancellation_token, it must be removed before dropping. - // Otherwise the cancellation_token would access invalid memory. - if let Some(token) = self.cancellation_token { - if self.is_registered { - token.unregister(&mut self.wait_node); - } - } - } -} - -/// Tracks how the future had interacted with the [`CancellationToken`] -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -enum PollState { - /// The task has never interacted with the [`CancellationToken`]. - New, - /// The task was added to the wait queue at the [`CancellationToken`]. - Waiting, - /// The task has been polled to completion. - Done, -} - -/// Tracks the WaitForCancellationFuture waiting state. -/// Access to this struct is synchronized through the mutex in the CancellationToken. -struct WaitQueueEntry { - /// The task handle of the waiting task - task: Option<Waker>, - // Current polling state. This state is only updated inside the Mutex of - // the CancellationToken. - state: PollState, -} - -impl WaitQueueEntry { - /// Creates a new WaitQueueEntry - fn new() -> WaitQueueEntry { - WaitQueueEntry { - task: None, - state: PollState::New, - } - } -} - -struct SynchronizedState { - waiters: LinkedList<WaitQueueEntry>, - first_child: Option<NonNull<CancellationTokenState>>, - is_cancelled: bool, -} - -impl SynchronizedState { - fn new() -> Self { - Self { - waiters: LinkedList::new(), - first_child: None, - is_cancelled: false, - } - } -} - -/// Information embedded in child tokens which is synchronized through the Mutex -/// in their parent. -struct SynchronizedThroughParent { - next_peer: Option<NonNull<CancellationTokenState>>, - prev_peer: Option<NonNull<CancellationTokenState>>, -} - -/// Possible states of a `CancellationToken` -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CancellationState { - NotCancelled = 0, - Cancelling = 1, - Cancelled = 2, -} - -impl CancellationState { - fn pack(self) -> usize { - self as usize - } - - fn unpack(value: usize) -> Self { - match value { - 0 => CancellationState::NotCancelled, - 1 => CancellationState::Cancelling, - 2 => CancellationState::Cancelled, - _ => unreachable!("Invalid value"), - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -struct StateSnapshot { - /// The amount of references to this particular CancellationToken. - /// `CancellationToken` structs hold these references to a `CancellationTokenState`. - /// Also the state is referenced by the state of each child. - refcount: usize, - /// Whether the state is still referenced by it's parent and can therefore - /// not be freed. - has_parent_ref: bool, - /// Whether the token is cancelled - cancel_state: CancellationState, -} - -impl StateSnapshot { - /// Packs the snapshot into a `usize` - fn pack(self) -> usize { - self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack() - } - - /// Unpacks the snapshot from a `usize` - fn unpack(value: usize) -> Self { - let refcount = value >> 3; - let has_parent_ref = value & 4 != 0; - let cancel_state = CancellationState::unpack(value & 0x03); - - StateSnapshot { - refcount, - has_parent_ref, - cancel_state, - } - } - - /// Whether this `CancellationTokenState` is still referenced by any - /// `CancellationToken`. - fn has_refs(&self) -> bool { - self.refcount != 0 || self.has_parent_ref - } -} - -/// The maximum permitted amount of references to a CancellationToken. This -/// is derived from the intent to never use more than 32bit in the `Snapshot`. -const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3; - -/// Internal state of the `CancellationToken` pair above -struct CancellationTokenState { - state: AtomicUsize, - parent: Option<NonNull<CancellationTokenState>>, - from_parent: SynchronizedThroughParent, - synchronized: Mutex<SynchronizedState>, -} - -impl CancellationTokenState { - fn new( - parent: Option<NonNull<CancellationTokenState>>, - state: StateSnapshot, - ) -> CancellationTokenState { - CancellationTokenState { - parent, - from_parent: SynchronizedThroughParent { - prev_peer: None, - next_peer: None, - }, - state: AtomicUsize::new(state.pack()), - synchronized: Mutex::new(SynchronizedState::new()), - } - } - - /// Returns a snapshot of the current atomic state of the token - fn snapshot(&self) -> StateSnapshot { - StateSnapshot::unpack(self.state.load(Ordering::SeqCst)) - } - - fn atomic_update_state<F>(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot - where - F: Fn(StateSnapshot) -> StateSnapshot, - { - let mut current_packed_state = current_state.pack(); - loop { - let next_state = func(current_state); - match self.state.compare_exchange( - current_packed_state, - next_state.pack(), - Ordering::SeqCst, - Ordering::SeqCst, - ) { - Ok(_) => { - return next_state; - } - Err(actual) => { - current_packed_state = actual; - current_state = StateSnapshot::unpack(actual); - } - } - } - } - - fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { - self.atomic_update_state(current_state, |mut state: StateSnapshot| { - if state.refcount >= MAX_REFS as usize { - eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded"); - std::process::abort(); - } - state.refcount += 1; - state - }) - } - - fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot { - let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { - state.refcount -= 1; - state - }); - - // Drop the State if it is not referenced anymore - if !current_state.has_refs() { - // Safety: `CancellationTokenState` is always stored in refcounted - // Boxes - let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; - } - - current_state - } - - fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot { - let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| { - state.has_parent_ref = false; - state - }); - - // Drop the State if it is not referenced anymore - if !current_state.has_refs() { - // Safety: `CancellationTokenState` is always stored in refcounted - // Boxes - let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) }; - } - - current_state - } - - /// Unregisters a child from the parent token. - /// The child tokens state is not exactly known at this point in time. - /// If the parent token is cancelled, the child token gets removed from the - /// parents list, and might therefore already have been freed. If the parent - /// token is not cancelled, the child token is still valid. - fn unregister_child( - &mut self, - mut child_state: NonNull<CancellationTokenState>, - current_child_state: StateSnapshot, - ) { - let removed_child = { - // Remove the child toke from the parents linked list - let mut guard = self.synchronized.lock().unwrap(); - if !guard.is_cancelled { - // Safety: Since the token was not cancelled, the child must - // still be in the list and valid. - let mut child_state = unsafe { child_state.as_mut() }; - debug_assert!(child_state.snapshot().has_parent_ref); - - if guard.first_child == Some(child_state.into()) { - guard.first_child = child_state.from_parent.next_peer; - } - // Safety: If peers wouldn't be valid anymore, they would try - // to remove themselves from the list. This would require locking - // the Mutex that we currently own. - unsafe { - if let Some(mut prev_peer) = child_state.from_parent.prev_peer { - prev_peer.as_mut().from_parent.next_peer = - child_state.from_parent.next_peer; - } - if let Some(mut next_peer) = child_state.from_parent.next_peer { - next_peer.as_mut().from_parent.prev_peer = - child_state.from_parent.prev_peer; - } - } - child_state.from_parent.prev_peer = None; - child_state.from_parent.next_peer = None; - - // The child is no longer referenced by the parent, since we were able - // to remove its reference from the parents list. - true - } else { - // Do not touch the linked list anymore. If the parent is cancelled - // it will move all childs outside of the Mutex and manipulate - // the pointers there. Manipulating the pointers here too could - // lead to races. Therefore leave them just as as and let the - // parent deal with it. The parent will make sure to retain a - // reference to this state as long as it manipulates the list - // pointers. Therefore the pointers are not dangling. - false - } - }; - - if removed_child { - // If the token removed itself from the parents list, it can reset - // the the parent ref status. If it is isn't able to do so, because the - // parent removed it from the list, there is no need to do this. - // The parent ref acts as as another reference count. Therefore - // removing this reference can free the object. - // Safety: The token was in the list. This means the parent wasn't - // cancelled before, and the token must still be alive. - unsafe { child_state.as_mut().remove_parent_ref(current_child_state) }; - } - - // Decrement the refcount on the parent and free it if necessary - self.decrement_refcount(self.snapshot()); - } - - fn cancel(&self) { - // Move the state of the CancellationToken from `NotCancelled` to `Cancelling` - let mut current_state = self.snapshot(); - - let state_after_cancellation = loop { - if current_state.cancel_state != CancellationState::NotCancelled { - // Another task already initiated the cancellation - return; - } - - let mut next_state = current_state; - next_state.cancel_state = CancellationState::Cancelling; - match self.state.compare_exchange( - current_state.pack(), - next_state.pack(), - Ordering::SeqCst, - Ordering::SeqCst, - ) { - Ok(_) => break next_state, - Err(actual) => current_state = StateSnapshot::unpack(actual), - } - }; - - // This task cancelled the token - - // Take the task list out of the Token - // We do not want to cancel child token inside this lock. If one of the - // child tasks would have additional child tokens, we would recursively - // take locks. - - // Doing this action has an impact if the child token is dropped concurrently: - // It will try to deregister itself from the parent task, but can not find - // itself in the task list anymore. Therefore it needs to assume the parent - // has extracted the list and will process it. It may not modify the list. - // This is OK from a memory safety perspective, since the parent still - // retains a reference to the child task until it finished iterating over - // it. - - let mut first_child = { - let mut guard = self.synchronized.lock().unwrap(); - // Save the cancellation also inside the Mutex - // This allows child tokens which want to detach themselves to detect - // that this is no longer required since the parent cleared the list. - guard.is_cancelled = true; - - // Wakeup all waiters - // This happens inside the lock to make cancellation reliable - // If we would access waiters outside of the lock, the pointers - // may no longer be valid. - // Typically this shouldn't be an issue, since waking a task should - // only move it from the blocked into the ready state and not have - // further side effects. - - // Use a reverse iterator, so that the oldest waiter gets - // scheduled first - guard.waiters.reverse_drain(|waiter| { - // We are not allowed to move the `Waker` out of the list node. - // The `Future` relies on the fact that the old `Waker` stays there - // as long as the `Future` has not completed in order to perform - // the `will_wake()` check. - // Therefore `wake_by_ref` is used instead of `wake()` - if let Some(handle) = &mut waiter.task { - handle.wake_by_ref(); - } - // Mark the waiter to have been removed from the list. - waiter.state = PollState::Done; - }); - - guard.first_child.take() - }; - - while let Some(mut child) = first_child { - // Safety: We know this is a valid pointer since it is in our child pointer - // list. It can't have been freed in between, since we retain a a reference - // to each child. - let mut_child = unsafe { child.as_mut() }; - - // Get the next child and clean up list pointers - first_child = mut_child.from_parent.next_peer; - mut_child.from_parent.prev_peer = None; - mut_child.from_parent.next_peer = None; - - // Cancel the child task - mut_child.cancel(); - - // Drop the parent reference. This `CancellationToken` is not interested - // in interacting with the child anymore. - // This is ONLY allowed once we promised not to touch the state anymore - // after this interaction. - mut_child.remove_parent_ref(mut_child.snapshot()); - } - - // The cancellation has completed - // At this point in time tasks which registered a wait node can be sure - // that this wait node already had been dequeued from the list without - // needing to inspect the list. - self.atomic_update_state(state_after_cancellation, |mut state| { - state.cancel_state = CancellationState::Cancelled; - state - }); - } - - /// Returns `true` if the `CancellationToken` had been cancelled - fn is_cancelled(&self) -> bool { - let current_state = self.snapshot(); - current_state.cancel_state != CancellationState::NotCancelled - } - - /// Registers a waiting task at the `CancellationToken`. - /// Safety: This method is only safe as long as the waiting waiting task - /// will properly unregister the wait node before it gets moved. - unsafe fn register( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - debug_assert_eq!(PollState::New, wait_node.state); - let current_state = self.snapshot(); - - // Perform an optimistic cancellation check before. This is not strictly - // necessary since we also check for cancellation in the Mutex, but - // reduces the necessary work to be performed for tasks which already - // had been cancelled. - if current_state.cancel_state != CancellationState::NotCancelled { - return Poll::Ready(()); - } - - // So far the token is not cancelled. However it could be cancelld before - // we get the chance to store the `Waker`. Therfore we need to check - // for cancellation again inside the mutex. - let mut guard = self.synchronized.lock().unwrap(); - if guard.is_cancelled { - // Cancellation was signalled - wait_node.state = PollState::Done; - Poll::Ready(()) - } else { - // Added the task to the wait queue - wait_node.task = Some(cx.waker().clone()); - wait_node.state = PollState::Waiting; - guard.waiters.add_front(wait_node); - Poll::Pending - } - } - - fn check_for_cancellation( - &self, - wait_node: &mut ListNode<WaitQueueEntry>, - cx: &mut Context<'_>, - ) -> Poll<()> { - debug_assert!( - wait_node.task.is_some(), - "Method can only be called after task had been registered" - ); - - let current_state = self.snapshot(); - - if current_state.cancel_state != CancellationState::NotCancelled { - // If the cancellation had been fully completed we know that our `Waker` - // is no longer registered at the `CancellationToken`. - // Otherwise the cancel call may or may not yet have iterated - // through the waiters list and removed the wait nodes. - // If it hasn't yet, we need to remove it. Otherwise an attempt to - // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture` - // getting dropped before the cancellation had interacted with it. - if current_state.cancel_state != CancellationState::Cancelled { - self.unregister(wait_node); - } - Poll::Ready(()) - } else { - // Check if we need to swap the `Waker`. This will make the check more - // expensive, since the `Waker` is synchronized through the Mutex. - // If we don't need to perform a `Waker` update, an atomic check for - // cancellation is sufficient. - let need_waker_update = wait_node - .task - .as_ref() - .map(|waker| waker.will_wake(cx.waker())) - .unwrap_or(true); - - if need_waker_update { - let guard = self.synchronized.lock().unwrap(); - if guard.is_cancelled { - // Cancellation was signalled. Since this cancellation signal - // is set inside the Mutex, the old waiter must already have - // been removed from the waiting list - debug_assert_eq!(PollState::Done, wait_node.state); - wait_node.task = None; - Poll::Ready(()) - } else { - // The WaitForCancellationFuture is already in the queue. - // The CancellationToken can't have been cancelled, - // since this would change the is_cancelled flag inside the mutex. - // Therefore we just have to update the Waker. A follow-up - // cancellation will always use the new waker. - wait_node.task = Some(cx.waker().clone()); - Poll::Pending - } - } else { - // Do nothing. If the token gets cancelled, this task will get - // woken again and can fetch the cancellation. - Poll::Pending - } - } - } - - fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) { - debug_assert!( - wait_node.task.is_some(), - "waiter can not be active without task" - ); - - let mut guard = self.synchronized.lock().unwrap(); - // WaitForCancellationFuture only needs to get removed if it has been added to - // the wait queue of the CancellationToken. - // This has happened in the PollState::Waiting case. - if let PollState::Waiting = wait_node.state { - // Safety: Due to the state, we know that the node must be part - // of the waiter list - if !unsafe { guard.waiters.remove(wait_node) } { - // Panic if the address isn't found. This can only happen if the contract was - // violated, e.g. the WaitQueueEntry got moved after the initial poll. - panic!("Future could not be removed from wait queue"); - } - wait_node.state = PollState::Done; - } - wait_node.task = None; - } -} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 3d96106..57ae277 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -20,7 +20,7 @@ //! few flavors of channels provided by Tokio. Each channel flavor supports //! different message passing patterns. When a channel supports multiple //! producers, many separate tasks may **send** messages. When a channel -//! supports muliple consumers, many different separate tasks may **receive** +//! supports multiple consumers, many different separate tasks may **receive** //! messages. //! //! Tokio provides many different channel flavors as different message passing @@ -106,7 +106,7 @@ //! //! #[tokio::main] //! async fn main() { -//! let (mut tx, mut rx) = mpsc::channel(100); +//! let (tx, mut rx) = mpsc::channel(100); //! //! tokio::spawn(async move { //! for i in 0..10 { @@ -150,7 +150,7 @@ //! for _ in 0..10 { //! // Each task needs its own `tx` handle. This is done by cloning the //! // original handle. -//! let mut tx = tx.clone(); +//! let tx = tx.clone(); //! //! tokio::spawn(async move { //! tx.send(&b"data to write"[..]).await.unwrap(); @@ -213,7 +213,7 @@ //! //! // Spawn tasks that will send the increment command. //! for _ in 0..10 { -//! let mut cmd_tx = cmd_tx.clone(); +//! let cmd_tx = cmd_tx.clone(); //! //! join_handles.push(tokio::spawn(async move { //! let (resp_tx, resp_rx) = oneshot::channel(); @@ -322,7 +322,7 @@ //! tokio::spawn(async move { //! loop { //! // Wait 10 seconds between checks -//! time::delay_for(Duration::from_secs(10)).await; +//! time::sleep(Duration::from_secs(10)).await; //! //! // Load the configuration file //! let new_config = Config::load_from_file().await.unwrap(); @@ -330,7 +330,7 @@ //! // If the configuration changed, send the new config value //! // on the watch channel. //! if new_config != config { -//! tx.broadcast(new_config.clone()).unwrap(); +//! tx.send(new_config.clone()).unwrap(); //! config = new_config; //! } //! } @@ -355,17 +355,15 @@ //! let op = my_async_operation(); //! tokio::pin!(op); //! -//! // Receive the **initial** configuration value. As this is the -//! // first time the config is received from the watch, it will -//! // always complete immediatedly. -//! let mut conf = rx.recv().await.unwrap(); +//! // Get the initial config value +//! let mut conf = rx.borrow().clone(); //! //! let mut op_start = Instant::now(); -//! let mut delay = time::delay_until(op_start + conf.timeout); +//! let mut sleep = time::sleep_until(op_start + conf.timeout); //! //! loop { //! tokio::select! { -//! _ = &mut delay => { +//! _ = &mut sleep => { //! // The operation elapsed. Restart it //! op.set(my_async_operation()); //! @@ -373,14 +371,14 @@ //! op_start = Instant::now(); //! //! // Restart the timeout -//! delay = time::delay_until(op_start + conf.timeout); +//! sleep = time::sleep_until(op_start + conf.timeout); //! } -//! new_conf = rx.recv() => { -//! conf = new_conf.unwrap(); +//! _ = rx.changed() => { +//! conf = rx.borrow().clone(); //! //! // The configuration has been updated. Update the -//! // `delay` using the new `timeout` value. -//! delay.reset(op_start + conf.timeout); +//! // `sleep` using the new `timeout` value. +//! sleep.reset(op_start + conf.timeout); //! } //! _ = &mut op => { //! // The operation completed! @@ -399,14 +397,14 @@ //! } //! ``` //! -//! [`watch` channel]: crate::sync::watch -//! [`broadcast` channel]: crate::sync::broadcast +//! [`watch` channel]: mod@crate::sync::watch +//! [`broadcast` channel]: mod@crate::sync::broadcast //! //! # State synchronization //! //! The remaining synchronization primitives focus on synchronizing state. //! These are asynchronous equivalents to versions provided by `std`. They -//! operate in a similar way as their `std` counterparts parts but will wait +//! operate in a similar way as their `std` counterparts but will wait //! asynchronously instead of blocking the thread. //! //! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to @@ -434,23 +432,17 @@ cfg_sync! { pub mod broadcast; - cfg_unstable! { - mod cancellation_token; - pub use cancellation_token::{CancellationToken, WaitForCancellationFuture}; - } - pub mod mpsc; mod mutex; pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard}; - mod notify; + pub(crate) mod notify; pub use notify::Notify; pub mod oneshot; pub(crate) mod batch_semaphore; - pub(crate) mod semaphore_ll; mod semaphore; pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; @@ -464,20 +456,30 @@ cfg_sync! { } cfg_not_sync! { + #[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod batch_semaphore; + + cfg_fs! { + mod mutex; + pub(crate) use mutex::Mutex; + } + + #[cfg(any(feature = "rt", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod notify; + cfg_atomic_waker_impl! { mod task; pub(crate) use task::AtomicWaker; } #[cfg(any( - feature = "rt-core", + feature = "rt", feature = "process", feature = "signal"))] pub(crate) mod oneshot; - cfg_signal! { + cfg_signal_internal! { pub(crate) mod mpsc; - pub(crate) mod semaphore_ll; } } diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs index 7bf1619..e062f2b 100644 --- a/src/sync/mpsc/block.rs +++ b/src/sync/mpsc/block.rs @@ -1,8 +1,6 @@ -use crate::loom::{ - cell::UnsafeCell, - sync::atomic::{AtomicPtr, AtomicUsize}, - thread, -}; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; use std::mem::MaybeUninit; use std::ops; diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index afca8c5..06b3717 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -1,6 +1,6 @@ +use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError}; -use crate::sync::semaphore_ll as semaphore; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -8,6 +8,7 @@ cfg_time! { } use std::fmt; +#[cfg(any(feature = "signal", feature = "process", feature = "stream"))] use std::task::{Context, Poll}; /// Send values to the associated `Receiver`. @@ -17,20 +18,14 @@ pub struct Sender<T> { chan: chan::Tx<T, Semaphore>, } -impl<T> Clone for Sender<T> { - fn clone(&self) -> Self { - Sender { - chan: self.chan.clone(), - } - } -} - -impl<T> fmt::Debug for Sender<T> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Sender") - .field("chan", &self.chan) - .finish() - } +/// Permit to send one value into the channel. +/// +/// `Permit` values are returned by [`Sender::reserve()`] and are used to +/// guarantee channel capacity before generating a message to send. +/// +/// [`Sender::reserve()`]: Sender::reserve +pub struct Permit<'a, T> { + chan: &'a chan::Tx<T, Semaphore>, } /// Receive values from the associated `Sender`. @@ -41,16 +36,12 @@ pub struct Receiver<T> { chan: chan::Rx<T, Semaphore>, } -impl<T> fmt::Debug for Receiver<T> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Receiver") - .field("chan", &self.chan) - .finish() - } -} - -/// Creates a bounded mpsc channel for communicating between asynchronous tasks, -/// returning the sender/receiver halves. +/// Creates a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `send` new messages will wait until a message is +/// received from the channel. The provided buffer capacity must be at least 1. /// /// All data sent on `Sender` will become available on `Receiver` in the same /// order as it was sent. @@ -62,6 +53,10 @@ impl<T> fmt::Debug for Receiver<T> { /// will return a `SendError`. Similarly, if `Sender` is disconnected while /// trying to `recv`, the `recv` method will return a `RecvError`. /// +/// # Panics +/// +/// Panics if the buffer capacity is 0. +/// /// # Examples /// /// ```rust @@ -69,7 +64,7 @@ impl<T> fmt::Debug for Receiver<T> { /// /// #[tokio::main] /// async fn main() { -/// let (mut tx, mut rx) = mpsc::channel(100); +/// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -117,7 +112,7 @@ impl<T> Receiver<T> { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tokio::spawn(async move { /// tx.send("hello").await.unwrap(); @@ -135,7 +130,7 @@ impl<T> Receiver<T> { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(100); + /// let (tx, mut rx) = mpsc::channel(100); /// /// tx.send("hello").await.unwrap(); /// tx.send("world").await.unwrap(); @@ -146,15 +141,48 @@ impl<T> Receiver<T> { /// ``` pub async fn recv(&mut self) -> Option<T> { use crate::future::poll_fn; - - poll_fn(|cx| self.poll_recv(cx)).await + poll_fn(|cx| self.chan.recv(cx)).await } - #[doc(hidden)] // TODO: document - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + #[cfg(any(feature = "signal", feature = "process"))] + pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } + /// Blocking receive to call outside of asynchronous contexts. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples + /// + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(10); + /// + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// Runtime::new() + /// .unwrap() + /// .block_on(async move { + /// let _ = tx.send(10).await; + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) + } + /// Attempts to return a pending value on this receiver without blocking. /// /// This method will never block the caller in order to wait for data to @@ -173,12 +201,53 @@ impl<T> Receiver<T> { /// Closes the receiving half of a channel, without dropping it. /// /// This prevents any further messages from being sent on the channel while - /// still enabling the receiver to drain messages that are buffered. + /// still enabling the receiver to drain messages that are buffered. Any + /// outstanding [`Permit`] values will still be able to send messages. + /// + /// In order to guarantee no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. + /// + /// [`Permit`]: Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(20); + /// + /// tokio::spawn(async move { + /// let mut i = 0; + /// while let Ok(permit) = tx.reserve().await { + /// permit.send(i); + /// i += 1; + /// } + /// }); + /// + /// rx.close(); + /// + /// while let Some(msg) = rx.recv().await { + /// println!("got {}", msg); + /// } + /// + /// // Channel closed and no messages are lost. + /// } + /// ``` pub fn close(&mut self) { self.chan.close(); } } +impl<T> fmt::Debug for Receiver<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Receiver") + .field("chan", &self.chan) + .finish() + } +} + impl<T> Unpin for Receiver<T> {} cfg_stream! { @@ -186,7 +255,7 @@ cfg_stream! { type Item = T; fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.poll_recv(cx) + self.chan.recv(cx) } } } @@ -225,7 +294,7 @@ impl<T> Sender<T> { /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -241,18 +310,49 @@ impl<T> Sender<T> { /// } /// } /// ``` - pub async fn send(&mut self, value: T) -> Result<(), SendError<T>> { - use crate::future::poll_fn; - - if poll_fn(|cx| self.poll_ready(cx)).await.is_err() { - return Err(SendError(value)); + pub async fn send(&self, value: T) -> Result<(), SendError<T>> { + match self.reserve().await { + Ok(permit) => { + permit.send(value); + Ok(()) + } + Err(_) => Err(SendError(value)), } + } - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendError(value)), - } + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::channel::<()>(1); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + //// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await } /// Attempts to immediately send a message on this `Sender` @@ -262,9 +362,6 @@ impl<T> Sender<T> { /// with [`send`], this function has two failure cases instead of one (one for /// disconnection, one for a full buffer). /// - /// This function may be paired with [`poll_ready`] in order to wait for - /// channel capacity before trying to send a value. - /// /// # Errors /// /// If the channel capacity has been reached, i.e., the channel has `n` @@ -276,7 +373,6 @@ impl<T> Sender<T> { /// an error. The error includes the value passed to `send`. /// /// [`send`]: Sender::send - /// [`poll_ready`]: Sender::poll_ready /// [`channel`]: channel /// [`close`]: Receiver::close /// @@ -288,8 +384,8 @@ impl<T> Sender<T> { /// #[tokio::main] /// async fn main() { /// // Create a channel with buffer size 1 - /// let (mut tx1, mut rx) = mpsc::channel(1); - /// let mut tx2 = tx1.clone(); + /// let (tx1, mut rx) = mpsc::channel(1); + /// let tx2 = tx1.clone(); /// /// tokio::spawn(async move { /// tx1.send(1).await.unwrap(); @@ -317,8 +413,15 @@ impl<T> Sender<T> { /// } /// } /// ``` - pub fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { - self.chan.try_send(message)?; + pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)), + } + + // Send the message + self.chan.send(message); Ok(()) } @@ -346,11 +449,11 @@ impl<T> Sender<T> { /// /// ```rust /// use tokio::sync::mpsc; - /// use tokio::time::{delay_for, Duration}; + /// use tokio::time::{sleep, Duration}; /// /// #[tokio::main] /// async fn main() { - /// let (mut tx, mut rx) = mpsc::channel(1); + /// let (tx, mut rx) = mpsc::channel(1); /// /// tokio::spawn(async move { /// for i in 0..10 { @@ -363,117 +466,213 @@ impl<T> Sender<T> { /// /// while let Some(i) = rx.recv().await { /// println!("got = {}", i); - /// delay_for(Duration::from_millis(200)).await; + /// sleep(Duration::from_millis(200)).await; /// } /// } /// ``` #[cfg(feature = "time")] #[cfg_attr(docsrs, doc(cfg(feature = "time")))] pub async fn send_timeout( - &mut self, + &self, value: T, timeout: Duration, ) -> Result<(), SendTimeoutError<T>> { - use crate::future::poll_fn; - - match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await { + let permit = match crate::time::timeout(timeout, self.reserve()).await { Err(_) => { return Err(SendTimeoutError::Timeout(value)); } Ok(Err(_)) => { return Err(SendTimeoutError::Closed(value)); } - Ok(_) => {} - } + Ok(Ok(permit)) => permit, + }; - match self.try_send(value) { - Ok(()) => Ok(()), - Err(TrySendError::Full(_)) => unreachable!(), - Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)), - } + permit.send(value); + Ok(()) } - /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item. + /// Blocking send to call outside of asynchronous contexts. /// - /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a - /// slot becomes available. + /// # Panics /// - /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless - /// the channel has since been closed. To provide this guarantee, the channel reserves one slot - /// in the channel for the coming send. This reserved slot is not available to other `Sender` - /// instances, so you need to be careful to not end up with deadlocks by blocking after calling - /// `poll_ready` but before sending an element. + /// This function panics if called within an asynchronous execution + /// context. /// - /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you - /// can use [`disarm`](Sender::disarm) to release the reserved slot. + /// # Examples /// - /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to - /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel - /// is closed. - pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> { - self.chan.poll_ready(cx).map_err(|_| ClosedError::new()) + /// ``` + /// use std::thread; + /// use tokio::runtime::Runtime; + /// use tokio::sync::mpsc; + /// + /// fn main() { + /// let (tx, mut rx) = mpsc::channel::<u8>(1); + /// + /// let sync_code = thread::spawn(move || { + /// tx.blocking_send(10).unwrap(); + /// }); + /// + /// Runtime::new().unwrap().block_on(async move { + /// assert_eq!(Some(10), rx.recv().await); + /// }); + /// sync_code.join().unwrap() + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> { + crate::future::block_on(self.send(value)) } - /// Undo a successful call to `poll_ready`. + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. /// - /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the - /// channel to make room for the coming send. `disarm` allows you to give up that slot if you - /// decide you do not wish to send an item after all. After calling `disarm`, you must call - /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again. + /// [`Receiver`]: crate::sync::mpsc::Receiver + /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close /// - /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was - /// not previously called, or did not succeed). + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(42); + /// assert!(!tx.is_closed()); /// - /// # Motivation + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); /// - /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers - /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may - /// take up all the slots of the channel, and prevent active senders from getting any requests - /// through. Consider this code that forwards from one channel to another: + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } + + /// Wait for channel capacity. Once capacity to send one message is + /// available, it is reserved for the caller. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. Capacity to send one + /// message is reserved for the caller. A [`Permit`] is returned to track + /// the reserved capacity. The [`send`] function on [`Permit`] consumes the + /// reserved capacity. + /// + /// Dropping [`Permit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`Permit`]: Permit + /// [`send`]: Permit::send + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// if let Some(item) = ready!(rx.poll_recv(cx)) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` + pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> { + match self.chan.semaphore().0.acquire(1).await { + Ok(_) => {} + Err(_) => return Err(SendError(())), + } + + Ok(Permit { chan: &self.chan }) + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + Sender { + chan: self.chan.clone(), + } + } +} + +impl<T> fmt::Debug for Sender<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Sender") + .field("chan", &self.chan) + .finish() + } +} + +// ===== impl Permit ===== + +impl<T> Permit<'_, T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve().await.unwrap(); /// - /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then - /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do, - /// they are effectively each reducing the channel's capacity by 1. If enough of these - /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot - /// for them through `poll_ready`, and the system will deadlock. - /// - /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that - /// you have to block. We can then fix the code above by writing: - /// - /// ```rust,ignore - /// loop { - /// ready!(tx.poll_ready(cx))?; - /// let item = rx.poll_recv(cx); - /// if let Poll::Ready(Ok(_)) = item { - /// // we're going to send the item below, so don't disarm - /// } else { - /// // give up our send slot, we won't need it for a while - /// tx.disarm(); - /// } - /// if let Some(item) = ready!(item) { - /// tx.try_send(item)?; - /// } else { - /// break; - /// } + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Send a message on the permit + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); /// } /// ``` - pub fn disarm(&mut self) -> bool { - if self.chan.is_ready() { - self.chan.disarm(); - true - } else { - false + pub fn send(self, value: T) { + use std::mem; + + self.chan.send(value); + + // Avoid the drop logic + mem::forget(self); + } +} + +impl<T> Drop for Permit<'_, T> { + fn drop(&mut self) { + use chan::Semaphore; + + let semaphore = self.chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + if semaphore.is_closed() && semaphore.is_idle() { + self.chan.wake_rx(); } } } + +impl<T> fmt::Debug for Permit<'_, T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Permit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs index 148ee3a..c78fb50 100644 --- a/src/sync/mpsc/chan.rs +++ b/src/sync/mpsc/chan.rs @@ -2,8 +2,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; -use crate::sync::mpsc::error::{ClosedError, TryRecvError}; -use crate::sync::mpsc::{error, list}; +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::list; +use crate::sync::notify::Notify; use std::fmt; use std::process; @@ -12,21 +13,13 @@ use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; /// Channel sender -pub(crate) struct Tx<T, S: Semaphore> { +pub(crate) struct Tx<T, S> { inner: Arc<Chan<T, S>>, - permit: S::Permit, } -impl<T, S: Semaphore> fmt::Debug for Tx<T, S> -where - S::Permit: fmt::Debug, - S: fmt::Debug, -{ +impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Tx") - .field("inner", &self.inner) - .field("permit", &self.permit) - .finish() + fmt.debug_struct("Tx").field("inner", &self.inner).finish() } } @@ -35,70 +28,26 @@ pub(crate) struct Rx<T, S: Semaphore> { inner: Arc<Chan<T, S>>, } -impl<T, S: Semaphore> fmt::Debug for Rx<T, S> -where - S: fmt::Debug, -{ +impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Rx").field("inner", &self.inner).finish() } } -#[derive(Debug, Eq, PartialEq)] -pub(crate) enum TrySendError { - Closed, - Full, -} - -impl<T> From<(T, TrySendError)> for error::SendError<T> { - fn from(src: (T, TrySendError)) -> error::SendError<T> { - match src.1 { - TrySendError::Closed => error::SendError(src.0), - TrySendError::Full => unreachable!(), - } - } -} - -impl<T> From<(T, TrySendError)> for error::TrySendError<T> { - fn from(src: (T, TrySendError)) -> error::TrySendError<T> { - match src.1 { - TrySendError::Closed => error::TrySendError::Closed(src.0), - TrySendError::Full => error::TrySendError::Full(src.0), - } - } -} - pub(crate) trait Semaphore { - type Permit; - - fn new_permit() -> Self::Permit; - - /// The permit is dropped without a value being sent. In this case, the - /// permit must be returned to the semaphore. - fn drop_permit(&self, permit: &mut Self::Permit); - fn is_idle(&self) -> bool; fn add_permit(&self); - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Self::Permit, - ) -> Poll<Result<(), ClosedError>>; - - fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>; - - /// A value was sent into the channel and the permit held by `tx` is - /// dropped. In this case, the permit should not immeditely be returned to - /// the semaphore. Instead, the permit is returnred to the semaphore once - /// the sent value is read by the rx handle. - fn forget(&self, permit: &mut Self::Permit); - fn close(&self); + + fn is_closed(&self) -> bool; } struct Chan<T, S> { + /// Notifies all tasks listening for the receiver being dropped + notify_rx_closed: Notify, + /// Handle to the push half of the lock-free list. tx: list::Tx<T>, @@ -153,13 +102,11 @@ impl<T> fmt::Debug for RxFields<T> { unsafe impl<T: Send, S: Send> Send for Chan<T, S> {} unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {} -pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) -where - S: Semaphore, -{ +pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) { let (tx, rx) = list::channel(); let chan = Arc::new(Chan { + notify_rx_closed: Notify::new(), tx, semaphore, rx_waker: AtomicWaker::new(), @@ -175,48 +122,60 @@ where // ===== impl Tx ===== -impl<T, S> Tx<T, S> -where - S: Semaphore, -{ +impl<T, S> Tx<T, S> { fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> { - Tx { - inner: chan, - permit: S::new_permit(), - } - } - - pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> { - self.inner.semaphore.poll_acquire(cx, &mut self.permit) + Tx { inner: chan } } - pub(crate) fn disarm(&mut self) { - // TODO: should this error if not acquired? - self.inner.semaphore.drop_permit(&mut self.permit) + pub(super) fn semaphore(&self) -> &S { + &self.inner.semaphore } /// Send a message and notify the receiver. - pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut self.permit) + pub(crate) fn send(&self, value: T) { + self.inner.send(value); } -} -impl<T> Tx<T, (crate::sync::semaphore_ll::Semaphore, usize)> { - pub(crate) fn is_ready(&self) -> bool { - self.permit.is_acquired() + /// Wake the receive half + pub(crate) fn wake_rx(&self) { + self.inner.rx_waker.wake(); } } -impl<T> Tx<T, AtomicUsize> { - pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> { - self.inner.try_send(value, &mut ()) +impl<T, S: Semaphore> Tx<T, S> { + pub(crate) fn is_closed(&self) -> bool { + self.inner.semaphore.is_closed() + } + + pub(crate) async fn closed(&self) { + use std::future::Future; + use std::pin::Pin; + use std::task::Poll; + + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. Requesting the notification + // requires polling the future once. + let notified = self.inner.notify_rx_closed.notified(); + pin!(notified); + + // Polling the future once is guaranteed to return `Pending` as `watch` + // only notifies using `notify_waiters`. + crate::future::poll_fn(|cx| { + let res = Pin::new(&mut notified).poll(cx); + assert!(!res.is_ready()); + Poll::Ready(()) + }) + .await; + + if self.inner.semaphore.is_closed() { + return; + } + notified.await; } } -impl<T, S> Clone for Tx<T, S> -where - S: Semaphore, -{ +impl<T, S> Clone for Tx<T, S> { fn clone(&self) -> Tx<T, S> { // Using a Relaxed ordering here is sufficient as the caller holds a // strong ref to `self`, preventing a concurrent decrement to zero. @@ -224,18 +183,12 @@ where Tx { inner: self.inner.clone(), - permit: S::new_permit(), } } } -impl<T, S> Drop for Tx<T, S> -where - S: Semaphore, -{ +impl<T, S> Drop for Tx<T, S> { fn drop(&mut self) { - self.inner.semaphore.drop_permit(&mut self.permit); - if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 { return; } @@ -244,16 +197,13 @@ where self.inner.tx.close(); // Notify the receiver - self.inner.rx_waker.wake(); + self.wake_rx(); } } // ===== impl Rx ===== -impl<T, S> Rx<T, S> -where - S: Semaphore, -{ +impl<T, S: Semaphore> Rx<T, S> { fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> { Rx { inner: chan } } @@ -270,6 +220,7 @@ where }); self.inner.semaphore.close(); + self.inner.notify_rx_closed.notify_waiters(); } /// Receive the next value @@ -341,10 +292,7 @@ where } } -impl<T, S> Drop for Rx<T, S> -where - S: Semaphore, -{ +impl<T, S: Semaphore> Drop for Rx<T, S> { fn drop(&mut self) { use super::block::Read::Value; @@ -362,25 +310,13 @@ where // ===== impl Chan ===== -impl<T, S> Chan<T, S> -where - S: Semaphore, -{ - fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> { - if let Err(e) = self.semaphore.try_acquire(permit) { - return Err((value, e)); - } - +impl<T, S> Chan<T, S> { + fn send(&self, value: T) { // Push the value self.tx.push(value); // Notify the rx task self.rx_waker.wake(); - - // Release the permit - self.semaphore.forget(permit); - - Ok(()) } } @@ -399,72 +335,24 @@ impl<T, S> Drop for Chan<T, S> { } } -use crate::sync::semaphore_ll::TryAcquireError; - -impl From<TryAcquireError> for TrySendError { - fn from(src: TryAcquireError) -> TrySendError { - if src.is_closed() { - TrySendError::Closed - } else if src.is_no_permits() { - TrySendError::Full - } else { - unreachable!(); - } - } -} - // ===== impl Semaphore for (::Semaphore, capacity) ===== -use crate::sync::semaphore_ll::Permit; - -impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) { - type Permit = Permit; - - fn new_permit() -> Permit { - Permit::new() - } - - fn drop_permit(&self, permit: &mut Permit) { - permit.release(1, &self.0); - } - +impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) { fn add_permit(&self) { - self.0.add_permits(1) + self.0.release(1) } fn is_idle(&self) -> bool { self.0.available_permits() == self.1 } - fn poll_acquire( - &self, - cx: &mut Context<'_>, - permit: &mut Permit, - ) -> Poll<Result<(), ClosedError>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - permit - .poll_acquire(cx, 1, &self.0) - .map_err(|_| ClosedError::new()) - .map(move |r| { - coop.made_progress(); - r - }) - } - - fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> { - permit.try_acquire(1, &self.0)?; - Ok(()) - } - - fn forget(&self, permit: &mut Self::Permit) { - permit.forget(1); - } - fn close(&self) { self.0.close(); } + + fn is_closed(&self) -> bool { + self.0.is_closed() + } } // ===== impl Semaphore for AtomicUsize ===== @@ -473,12 +361,6 @@ use std::sync::atomic::Ordering::{Acquire, Release}; use std::usize; impl Semaphore for AtomicUsize { - type Permit = (); - - fn new_permit() {} - - fn drop_permit(&self, _permit: &mut ()) {} - fn add_permit(&self) { let prev = self.fetch_sub(2, Release); @@ -492,40 +374,11 @@ impl Semaphore for AtomicUsize { self.load(Acquire) >> 1 == 0 } - fn poll_acquire( - &self, - _cx: &mut Context<'_>, - permit: &mut (), - ) -> Poll<Result<(), ClosedError>> { - Ready(self.try_acquire(permit).map_err(|_| ClosedError::new())) - } - - fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> { - let mut curr = self.load(Acquire); - - loop { - if curr & 1 == 1 { - return Err(TrySendError::Closed); - } - - if curr == usize::MAX ^ 1 { - // Overflowed the ref count. There is no safe way to recover, so - // abort the process. In practice, this should never happen. - process::abort() - } - - match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) { - Ok(_) => return Ok(()), - Err(actual) => { - curr = actual; - } - } - } - } - - fn forget(&self, _permit: &mut ()) {} - fn close(&self) { self.fetch_or(1, Release); } + + fn is_closed(&self) -> bool { + self.load(Acquire) & 1 == 1 + } } diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs index 72c42aa..7705452 100644 --- a/src/sync/mpsc/error.rs +++ b/src/sync/mpsc/error.rs @@ -94,26 +94,6 @@ impl fmt::Display for TryRecvError { impl Error for TryRecvError {} -// ===== ClosedError ===== - -/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready). -#[derive(Debug)] -pub struct ClosedError(()); - -impl ClosedError { - pub(crate) fn new() -> ClosedError { - ClosedError(()) - } -} - -impl fmt::Display for ClosedError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "channel closed") - } -} - -impl Error for ClosedError {} - cfg_time! { // ===== SendTimeoutError ===== diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs index 53f82a2..2f4c532 100644 --- a/src/sync/mpsc/list.rs +++ b/src/sync/mpsc/list.rs @@ -1,9 +1,7 @@ //! A concurrent, lock-free, FIFO list. -use crate::loom::{ - sync::atomic::{AtomicPtr, AtomicUsize}, - thread, -}; +use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; +use crate::loom::thread; use crate::sync::mpsc::block::{self, Block}; use std::fmt; diff --git a/src/sync/mpsc/mod.rs b/src/sync/mpsc/mod.rs index c489c9f..a2bcf83 100644 --- a/src/sync/mpsc/mod.rs +++ b/src/sync/mpsc/mod.rs @@ -1,23 +1,29 @@ #![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] -//! A multi-producer, single-consumer queue for sending values across +//! A multi-producer, single-consumer queue for sending values between //! asynchronous tasks. //! -//! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`] -//! handles. [`Receiver`] implements `Stream` and allows a task to read values -//! out of the channel. If there is no message to read, the current task will be -//! notified when a new value is sent. If the channel is at capacity, the send -//! is rejected and the task will be notified when additional capacity is -//! available. In other words, the channel provides backpressure. -//! //! This module provides two variants of the channel: bounded and unbounded. The //! bounded variant has a limit on the number of messages that the channel can //! store, and if this limit is reached, trying to send another message will //! wait until a message is received from the channel. An unbounded channel has -//! an infinite capacity, so the `send` method never does any kind of sleeping. +//! an infinite capacity, so the `send` method will always complete immediately. //! This makes the [`UnboundedSender`] usable from both synchronous and //! asynchronous code. //! +//! Similar to the `mpsc` channels provided by `std`, the channel constructor +//! functions provide separate send and receive handles, [`Sender`] and +//! [`Receiver`] for the bounded channel, [`UnboundedSender`] and +//! [`UnboundedReceiver`] for the unbounded channel. Both [`Receiver`] and +//! [`UnboundedReceiver`] implement [`Stream`] and allow a task to read +//! values out of the channel. If there is no message to read, the current task +//! will be notified when a new value is sent. [`Sender`] and +//! [`UnboundedSender`] allow sending values into the channel. If the bounded +//! channel is at capacity, the send is rejected and the task will be notified +//! when additional capacity is available. In other words, the channel provides +//! backpressure. +//! +//! //! # Disconnection //! //! When all [`Sender`] handles have been dropped, it is no longer @@ -43,11 +49,10 @@ //! are two situations to consider: //! //! **Bounded channel**: If you need a bounded channel, you should use a bounded -//! Tokio `mpsc` channel for both directions of communication. To call the async -//! [`send`][bounded-send] or [`recv`][bounded-recv] methods in sync code, you -//! will need to use [`Handle::block_on`], which allow you to execute an async -//! method in synchronous code. This is necessary because a bounded channel may -//! need to wait for additional capacity to become available. +//! Tokio `mpsc` channel for both directions of communication. Instead of calling +//! the async [`send`][bounded-send] or [`recv`][bounded-recv] methods, in +//! synchronous code you will need to use the [`blocking_send`][blocking-send] or +//! [`blocking_recv`][blocking-recv] methods. //! //! **Unbounded channel**: You should use the kind of channel that matches where //! the receiver is. So for sending a message _from async to sync_, you should @@ -57,9 +62,13 @@ //! //! [`Sender`]: crate::sync::mpsc::Sender //! [`Receiver`]: crate::sync::mpsc::Receiver +//! [`Stream`]: crate::stream::Stream //! [bounded-send]: crate::sync::mpsc::Sender::send() //! [bounded-recv]: crate::sync::mpsc::Receiver::recv() +//! [blocking-send]: crate::sync::mpsc::Sender::blocking_send() +//! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv() //! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender +//! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver //! [`Handle::block_on`]: crate::runtime::Handle::block_on() //! [std-unbounded]: std::sync::mpsc::channel //! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html @@ -67,7 +76,7 @@ pub(super) mod block; mod bounded; -pub use self::bounded::{channel, Receiver, Sender}; +pub use self::bounded::{channel, Permit, Receiver, Sender}; mod chan; diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs index 1b2288a..fe882d5 100644 --- a/src/sync/mpsc/unbounded.rs +++ b/src/sync/mpsc/unbounded.rs @@ -47,7 +47,7 @@ impl<T> fmt::Debug for UnboundedReceiver<T> { } /// Creates an unbounded mpsc channel for communicating between asynchronous -/// tasks. +/// tasks without backpressure. /// /// A `send` on this channel will always succeed as long as the receive half has /// not been closed. If the receiver falls behind, messages will be arbitrarily @@ -73,8 +73,7 @@ impl<T> UnboundedReceiver<T> { UnboundedReceiver { chan } } - #[doc(hidden)] // TODO: doc - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } @@ -174,7 +173,97 @@ impl<T> UnboundedSender<T> { /// [`close`]: UnboundedReceiver::close /// [`UnboundedReceiver`]: UnboundedReceiver pub fn send(&self, message: T) -> Result<(), SendError<T>> { - self.chan.send_unbounded(message)?; + if !self.inc_num_messages() { + return Err(SendError(message)); + } + + self.chan.send(message); Ok(()) } + + fn inc_num_messages(&self) -> bool { + use std::process; + use std::sync::atomic::Ordering::{AcqRel, Acquire}; + + let mut curr = self.chan.semaphore().load(Acquire); + + loop { + if curr & 1 == 1 { + return false; + } + + if curr == usize::MAX ^ 1 { + // Overflowed the ref count. There is no safe way to recover, so + // abort the process. In practice, this should never happen. + process::abort() + } + + match self + .chan + .semaphore() + .compare_exchange(curr, curr + 2, AcqRel, Acquire) + { + Ok(_) => return true, + Err(actual) => { + curr = actual; + } + } + } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx1, rx) = mpsc::unbounded_channel::<()>(); + /// let tx2 = tx1.clone(); + /// let tx3 = tx1.clone(); + /// let tx4 = tx1.clone(); + /// let tx5 = tx1.clone(); + /// tokio::spawn(async move { + /// drop(rx); + /// }); + /// + /// futures::join!( + /// tx1.closed(), + /// tx2.closed(), + /// tx3.closed(), + /// tx4.closed(), + /// tx5.closed() + /// ); + //// println!("Receiver dropped"); + /// } + /// ``` + pub async fn closed(&self) { + self.chan.closed().await + } + /// Checks if the channel has been closed. This happens when the + /// [`UnboundedReceiver`] is dropped, or when the + /// [`UnboundedReceiver::close`] method is called. + /// + /// [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver + /// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close + /// + /// ``` + /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + /// assert!(!tx.is_closed()); + /// + /// let tx2 = tx.clone(); + /// assert!(!tx2.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// assert!(tx2.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.chan.is_closed() + } } diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index 642058b..21e44ca 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; @@ -115,7 +117,6 @@ use std::sync::Arc; /// [`std::sync::Mutex`]: struct@std::sync::Mutex /// [`Send`]: trait@std::marker::Send /// [`lock`]: method@Mutex::lock -#[derive(Debug)] pub struct Mutex<T: ?Sized> { s: semaphore::Semaphore, c: UnsafeCell<T>, @@ -220,6 +221,27 @@ impl<T: ?Sized> Mutex<T> { } } + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// static LOCK: Mutex<i32> = Mutex::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(t: T) -> Self + where + T: Sized, + { + Self { + c: UnsafeCell::new(t), + s: semaphore::Semaphore::const_new(1), + } + } + /// Locks this mutex, causing the current task /// to yield until the lock has been acquired. /// When the lock has been acquired, function returns a [`MutexGuard`]. @@ -305,6 +327,30 @@ impl<T: ?Sized> Mutex<T> { } } + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// + /// fn main() { + /// let mut mutex = Mutex::new(1); + /// + /// let n = mutex.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } + } + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock /// is currently held somewhere else. /// @@ -373,6 +419,20 @@ where } } +impl<T> std::fmt::Debug for Mutex<T> +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("Mutex"); + match self.try_lock() { + Ok(inner) => d.field("data", &*inner), + Err(_) => d.field("data", &format_args!("<locked>")), + }; + d.finish() + } +} + // === impl MutexGuard === impl<T: ?Sized> Drop for MutexGuard<'_, T> { diff --git a/src/sync/notify.rs b/src/sync/notify.rs index 5cb41e8..922f109 100644 --- a/src/sync/notify.rs +++ b/src/sync/notify.rs @@ -1,3 +1,10 @@ +// Allow `unreachable_pub` warnings when sync is not enabled +// due to the usage of `Notify` within the `rt` feature set. +// When this module is compiled with `sync` enabled we will warn on +// this lint. When `rt` is enabled we use `pub(crate)` which +// triggers this warning but it is safe to ignore in this case. +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + use crate::loom::sync::atomic::AtomicU8; use crate::loom::sync::Mutex; use crate::util::linked_list::{self, LinkedList}; @@ -10,6 +17,8 @@ use std::ptr::NonNull; use std::sync::atomic::Ordering::SeqCst; use std::task::{Context, Poll, Waker}; +type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; + /// Notify a single task to wake up. /// /// `Notify` provides a basic mechanism to notify a single task of an event. @@ -17,20 +26,20 @@ use std::task::{Context, Poll, Waker}; /// another task to perform an operation. /// /// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits. -/// [`notified().await`] waits for a permit to become available, and [`notify()`] +/// [`notified().await`] waits for a permit to become available, and [`notify_one()`] /// sets a permit **if there currently are no available permits**. /// /// The synchronization details of `Notify` are similar to /// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`] /// value contains a single permit. [`notified().await`] waits for the permit to -/// be made available, consumes the permit, and resumes. [`notify()`] sets the +/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the /// permit, waking a pending task if there is one. /// -/// If `notify()` is called **before** `notfied().await`, then the next call to +/// If `notify_one()` is called **before** `notified().await`, then the next call to /// `notified().await` will complete immediately, consuming the permit. Any /// subsequent calls to `notified().await` will wait for a new permit. /// -/// If `notify()` is called **multiple** times before `notified().await`, only a +/// If `notify_one()` is called **multiple** times before `notified().await`, only a /// **single** permit is stored. The next call to `notified().await` will /// complete immediately, but the one after will wait for a new permit. /// @@ -53,7 +62,7 @@ use std::task::{Context, Poll, Waker}; /// }); /// /// println!("sending notification"); -/// notify.notify(); +/// notify.notify_one(); /// } /// ``` /// @@ -76,7 +85,7 @@ use std::task::{Context, Poll, Waker}; /// .push_back(value); /// /// // Notify the consumer a value is available -/// self.notify.notify(); +/// self.notify.notify_one(); /// } /// /// pub async fn recv(&self) -> T { @@ -96,12 +105,20 @@ use std::task::{Context, Poll, Waker}; /// [park]: std::thread::park /// [unpark]: std::thread::Thread::unpark /// [`notified().await`]: Notify::notified() -/// [`notify()`]: Notify::notify() +/// [`notify_one()`]: Notify::notify_one() /// [`Semaphore`]: crate::sync::Semaphore #[derive(Debug)] pub struct Notify { state: AtomicU8, - waiters: Mutex<LinkedList<Waiter>>, + waiters: Mutex<WaitList>, +} + +#[derive(Debug, Clone, Copy)] +enum NotificationType { + // Notification triggered by calling `notify_waiters` + AllWaiters, + // Notification triggered by calling `notify_one` + OneWaiter, } #[derive(Debug)] @@ -113,7 +130,7 @@ struct Waiter { waker: Option<Waker>, /// `true` if the notification has been assigned to this waiter. - notified: bool, + notified: Option<NotificationType>, /// Should not be `Unpin`. _p: PhantomPinned, @@ -121,7 +138,7 @@ struct Waiter { /// Future returned from `notified()` #[derive(Debug)] -struct Notified<'a> { +pub struct Notified<'a> { /// The `Notify` being received on. notify: &'a Notify, @@ -168,14 +185,38 @@ impl Notify { } } + /// Create a new `Notify`, initialized without a permit. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Notify; + /// + /// static NOTIFY: Notify = Notify::const_new(); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Notify { + Notify { + state: AtomicU8::new(0), + waiters: Mutex::const_new(LinkedList::new()), + } + } + /// Wait for a notification. /// + /// Equivalent to: + /// + /// ```ignore + /// async fn notified(&self); + /// ``` + /// /// Each `Notify` value holds a single permit. If a permit is available from - /// an earlier call to [`notify()`], then `notified().await` will complete + /// an earlier call to [`notify_one()`], then `notified().await` will complete /// immediately, consuming that permit. Otherwise, `notified().await` waits - /// for a permit to be made available by the next call to `notify()`. + /// for a permit to be made available by the next call to `notify_one()`. /// - /// [`notify()`]: Notify::notify + /// [`notify_one()`]: Notify::notify_one /// /// # Examples /// @@ -194,21 +235,20 @@ impl Notify { /// }); /// /// println!("sending notification"); - /// notify.notify(); + /// notify.notify_one(); /// } /// ``` - pub async fn notified(&self) { + pub fn notified(&self) -> Notified<'_> { Notified { notify: self, state: State::Init, waiter: UnsafeCell::new(Waiter { pointers: linked_list::Pointers::new(), waker: None, - notified: false, + notified: None, _p: PhantomPinned, }), } - .await } /// Notifies a waiting task @@ -216,10 +256,10 @@ impl Notify { /// If a task is currently waiting, that task is notified. Otherwise, a /// permit is stored in this `Notify` value and the **next** call to /// [`notified().await`] will complete immediately consuming the permit made - /// available by this call to `notify()`. + /// available by this call to `notify_one()`. /// /// At most one permit may be stored by `Notify`. Many sequential calls to - /// `notify` will result in a single permit being stored. The next call to + /// `notify_one` will result in a single permit being stored. The next call to /// `notified().await` will complete immediately, but the one after that /// will wait. /// @@ -242,10 +282,10 @@ impl Notify { /// }); /// /// println!("sending notification"); - /// notify.notify(); + /// notify.notify_one(); /// } /// ``` - pub fn notify(&self) { + pub fn notify_one(&self) { // Load the current state let mut curr = self.state.load(SeqCst); @@ -266,7 +306,7 @@ impl Notify { } // There are waiters, the lock must be acquired to notify. - let mut waiters = self.waiters.lock().unwrap(); + let mut waiters = self.waiters.lock(); // The state must be reloaded while the lock is held. The state may only // transition out of WAITING while the lock is held. @@ -277,6 +317,45 @@ impl Notify { waker.wake(); } } + + /// Notifies all waiting tasks + pub(crate) fn notify_waiters(&self) { + // There are waiters, the lock must be acquired to notify. + let mut waiters = self.waiters.lock(); + + // The state must be reloaded while the lock is held. The state may only + // transition out of WAITING while the lock is held. + let curr = self.state.load(SeqCst); + + if let EMPTY | NOTIFIED = curr { + // There are no waiting tasks. In this case, no synchronization is + // established between `notify` and `notified().await`. + return; + } + + // At this point, it is guaranteed that the state will not + // concurrently change, as holding the lock is required to + // transition **out** of `WAITING`. + // + // Get pending waiters + while let Some(mut waiter) = waiters.pop_back() { + // Safety: `waiters` lock is still held. + let waiter = unsafe { waiter.as_mut() }; + + assert!(waiter.notified.is_none()); + + waiter.notified = Some(NotificationType::AllWaiters); + + if let Some(waker) = waiter.waker.take() { + waker.wake(); + } + } + + // All waiters have been notified, the state must be transitioned to + // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be + // held, a `store` is sufficient. + self.state.store(EMPTY, SeqCst); + } } impl Default for Notify { @@ -285,7 +364,7 @@ impl Default for Notify { } } -fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -> Option<Waker> { +fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option<Waker> { loop { match curr { EMPTY | NOTIFIED => { @@ -311,9 +390,9 @@ fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) - // Safety: `waiters` lock is still held. let waiter = unsafe { waiter.as_mut() }; - assert!(!waiter.notified); + assert!(waiter.notified.is_none()); - waiter.notified = true; + waiter.notified = Some(NotificationType::OneWaiter); let waker = waiter.waker.take(); if waiters.is_empty() { @@ -373,7 +452,7 @@ impl Future for Notified<'_> { // Acquire the lock and attempt to transition to the waiting // state. - let mut waiters = notify.waiters.lock().unwrap(); + let mut waiters = notify.waiters.lock(); // Reload the state with the lock held let mut curr = notify.state.load(SeqCst); @@ -428,6 +507,8 @@ impl Future for Notified<'_> { waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); *state = Waiting; + + return Poll::Pending; } Waiting => { // Currently in the "Waiting" state, implying the caller has @@ -435,16 +516,16 @@ impl Future for Notified<'_> { // `notify.waiters`). In order to access the waker fields, // we must hold the lock. - let waiters = notify.waiters.lock().unwrap(); + let waiters = notify.waiters.lock(); // Safety: called while locked let w = unsafe { &mut *waiter.get() }; - if w.notified { + if w.notified.is_some() { // Our waker has been notified. Reset the fields and // remove it from the list. w.waker = None; - w.notified = false; + w.notified = None; *state = Done; } else { @@ -483,12 +564,12 @@ impl Drop for Notified<'_> { // longer stored in the linked list. if let Waiting = *state { let mut notify_state = WAITING; - let mut waiters = notify.waiters.lock().unwrap(); + let mut waiters = notify.waiters.lock(); // `Notify.state` may be in any of the three states (Empty, Waiting, // Notified). It doesn't actually matter what the atomic is set to // at this point. We hold the lock and will ensure the atomic is in - // the correct state once th elock is dropped. + // the correct state once the lock is dropped. // // Because the atomic state is not checked, at first glance, it may // seem like this routine does not handle the case where the @@ -516,14 +597,13 @@ impl Drop for Notified<'_> { notify.state.store(EMPTY, SeqCst); } - // See if the node was notified but not received. In this case, the - // notification must be sent to another waiter. + // See if the node was notified but not received. In this case, if + // the notification was triggered via `notify_one`, it must be sent + // to the next waiter. // // Safety: with the entry removed from the linked list, there can be // no concurrent access to the entry - let notified = unsafe { (*waiter.get()).notified }; - - if notified { + if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } { if let Some(waker) = notify_locked(&mut waiters, ¬ify.state, notify_state) { drop(waiters); waker.wake(); diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs index 17767e7..951ab71 100644 --- a/src/sync/oneshot.rs +++ b/src/sync/oneshot.rs @@ -124,7 +124,6 @@ struct State(usize); /// } /// ``` pub fn channel<T>() -> (Sender<T>, Receiver<T>) { - #[allow(deprecated)] let inner = Arc::new(Inner { state: AtomicUsize::new(State::new().as_usize()), value: UnsafeCell::new(None), @@ -197,8 +196,7 @@ impl<T> Sender<T> { Ok(()) } - #[doc(hidden)] // TODO: remove - pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { + fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> { // Keep track of task budget let coop = ready!(crate::coop::poll_proceed(cx)); diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index 3d2a2f7..a84c4c1 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -1,5 +1,8 @@ -use crate::sync::batch_semaphore::{AcquireError, Semaphore}; +use crate::sync::batch_semaphore::Semaphore; use std::cell::UnsafeCell; +use std::fmt; +use std::marker; +use std::mem; use std::ops; #[cfg(not(loom))] @@ -8,7 +11,7 @@ const MAX_READS: usize = 32; #[cfg(loom)] const MAX_READS: usize = 10; -/// An asynchronous reader-writer lock +/// An asynchronous reader-writer lock. /// /// This type of lock allows a number of readers or at most one writer at any /// point in time. The write portion of this lock typically allows modification @@ -83,10 +86,140 @@ pub struct RwLock<T: ?Sized> { /// [`RwLock`]. /// /// [`read`]: method@RwLock::read -#[derive(Debug)] +/// [`RwLock`]: struct@RwLock pub struct RwLockReadGuard<'a, T: ?Sized> { - permit: ReleasingPermit<'a, T>, - lock: &'a RwLock<T>, + s: &'a Semaphore, + data: *const T, + marker: marker::PhantomData<&'a T>, +} + +impl<'a, T> RwLockReadGuard<'a, T> { + /// Make a new `RwLockReadGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockReadGuard::map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::map(guard, |f| &f.0); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U> + where + F: FnOnce(&T) -> &U, + { + let data = f(&*this) as *const U; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`RwLockReadGuard`] for a component of the + /// locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockReadGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockReadGuard::try_map(..)`. A method would interfere with methods of the + /// same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockReadGuard::try_map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockReadGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockReadGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// let guard = lock.read().await; + /// let guard = RwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail"); + /// + /// assert_eq!(1, *guard); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self> + where + F: FnOnce(&T) -> Option<&U>, + { + let data = match f(&*this) { + Some(data) => data as *const U, + None => return Err(this), + }; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + Ok(RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + }) + } +} + +impl<'a, T: ?Sized> fmt::Debug for RwLockReadGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> fmt::Display for RwLockReadGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1); + } } /// RAII structure used to release the exclusive write access of a lock when @@ -97,32 +230,195 @@ pub struct RwLockReadGuard<'a, T: ?Sized> { /// /// [`write`]: method@RwLock::write /// [`RwLock`]: struct@RwLock -#[derive(Debug)] pub struct RwLockWriteGuard<'a, T: ?Sized> { - permit: ReleasingPermit<'a, T>, - lock: &'a RwLock<T>, + s: &'a Semaphore, + data: *mut T, + marker: marker::PhantomData<&'a mut T>, } -// Wrapper arround Permit that releases on Drop -#[derive(Debug)] -struct ReleasingPermit<'a, T: ?Sized> { - num_permits: u16, - lock: &'a RwLock<T>, +impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { + /// Make a new `RwLockWriteGuard` for a component of the locked data. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be used as + /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of + /// the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the + /// [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockWriteGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + RwLockWriteGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`RwLockWriteGuard`] for a component of + /// the locked data. The original guard is returned if the closure returns + /// `None`. + /// + /// This operation cannot fail as the `RwLockWriteGuard` passed in already + /// locked the data. + /// + /// This is an associated function that needs to be + /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with + /// methods of the same name on the contents of the locked data. + /// + /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from + /// the [`parking_lot` crate]. + /// + /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map + /// [`parking_lot` crate]: https://crates.io/crates/parking_lot + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{RwLock, RwLockWriteGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let lock = RwLock::new(Foo(1)); + /// + /// { + /// let guard = lock.write().await; + /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); + /// *guard = 2; + /// } + /// + /// assert_eq!(Foo(2), *lock.read().await); + /// # } + /// ``` + #[inline] + pub fn try_map<F, U: ?Sized>(mut this: Self, f: F) -> Result<RwLockWriteGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + // NB: Forget to avoid drop impl from being called. + mem::forget(this); + Ok(RwLockWriteGuard { + s, + data, + marker: marker::PhantomData, + }) + } + + /// Atomically downgrades a write lock into a read lock without allowing + /// any writers to take exclusive access of the lock in the meantime. + /// + /// **Note:** This won't *necessarily* allow any additional readers to acquire + /// locks, since [`RwLock`] is fair and it is possible that a writer is next + /// in line. + /// + /// Returns an RAII guard which will drop the read access of this rwlock + /// when dropped. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::RwLock; + /// # use std::sync::Arc; + /// # + /// # #[tokio::main] + /// # async fn main() { + /// let lock = Arc::new(RwLock::new(1)); + /// + /// let n = lock.write().await; + /// + /// let cloned_lock = lock.clone(); + /// let handle = tokio::spawn(async move { + /// *cloned_lock.write().await = 2; + /// }); + /// + /// let n = n.downgrade(); + /// assert_eq!(*n, 1, "downgrade is atomic"); + /// + /// assert_eq!(*lock.read().await, 1, "additional readers can obtain locks"); + /// + /// drop(n); + /// handle.await.unwrap(); + /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock"); + /// # } + /// ``` + /// + /// [`RwLock`]: struct@RwLock + pub fn downgrade(self) -> RwLockReadGuard<'a, T> { + let RwLockWriteGuard { s, data, .. } = self; + + // Release all but one of the permits held by the write guard + s.release(MAX_READS - 1); + + RwLockReadGuard { + s, + data, + marker: marker::PhantomData, + } + } } -impl<'a, T: ?Sized> ReleasingPermit<'a, T> { - async fn acquire( - lock: &'a RwLock<T>, - num_permits: u16, - ) -> Result<ReleasingPermit<'a, T>, AcquireError> { - lock.s.acquire(num_permits.into()).await?; - Ok(Self { num_permits, lock }) +impl<'a, T: ?Sized> fmt::Debug for RwLockWriteGuard<'a, T> +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) } } -impl<T: ?Sized> Drop for ReleasingPermit<'_, T> { +impl<'a, T: ?Sized> fmt::Display for RwLockWriteGuard<'a, T> +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> { fn drop(&mut self) { - self.lock.s.release(self.num_permits as usize); + self.s.release(MAX_READS); } } @@ -139,9 +435,11 @@ fn bounds() { check_sync::<RwLock<u32>>(); check_unpin::<RwLock<u32>>(); + check_send::<RwLockReadGuard<'_, u32>>(); check_sync::<RwLockReadGuard<'_, u32>>(); check_unpin::<RwLockReadGuard<'_, u32>>(); + check_send::<RwLockWriteGuard<'_, u32>>(); check_sync::<RwLockWriteGuard<'_, u32>>(); check_unpin::<RwLockWriteGuard<'_, u32>>(); @@ -155,8 +453,17 @@ fn bounds() { // RwLock<T>. unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {} unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {} +// NB: These impls need to be explicit since we're storing a raw pointer. +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send`. +unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {} unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {} unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} +// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over +// `T` is `Send` - but since this is also provides mutable access, we need to +// make sure that `T` is `Send` since its value can be sent across thread +// boundaries. +unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {} impl<T: ?Sized> RwLock<T> { /// Creates a new instance of an `RwLock<T>` which is unlocked. @@ -178,6 +485,27 @@ impl<T: ?Sized> RwLock<T> { } } + /// Creates a new instance of an `RwLock<T>` which is unlocked. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// static LOCK: RwLock<i32> = RwLock::const_new(5); + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(value: T) -> RwLock<T> + where + T: Sized, + { + RwLock { + c: UnsafeCell::new(value), + s: Semaphore::const_new(MAX_READS), + } + } + /// Locks this rwlock with shared read access, causing the current task /// to yield until the lock has been acquired. /// @@ -210,12 +538,16 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn read(&self) -> RwLockReadGuard<'_, T> { - let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| { + self.s.acquire(1).await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); - RwLockReadGuard { lock: self, permit } + RwLockReadGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + } } /// Locks this rwlock with exclusive write access, causing the current task @@ -241,15 +573,40 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - let permit = ReleasingPermit::acquire(self, MAX_READS as u16) - .await - .unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. - unreachable!() - }); - - RwLockWriteGuard { lock: self, permit } + self.s.acquire(MAX_READS as u32).await.unwrap_or_else(|_| { + // The semaphore was closed. but, we never explicitly close it, and we have a + // handle to it through the Arc, which means that this can never happen. + unreachable!() + }); + RwLockWriteGuard { + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + } + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `RwLock` mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::RwLock; + /// + /// fn main() { + /// let mut lock = RwLock::new(1); + /// + /// let n = lock.get_mut(); + /// *n = 2; + /// } + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { + // Safety: This is https://github.com/rust-lang/rust/pull/76936 + &mut *self.c.get() + } } /// Consumes the lock, returning the underlying data. @@ -265,7 +622,7 @@ impl<T: ?Sized> ops::Deref for RwLockReadGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.lock.c.get() } + unsafe { &*self.data } } } @@ -273,13 +630,13 @@ impl<T: ?Sized> ops::Deref for RwLockWriteGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.lock.c.get() } + unsafe { &*self.data } } } impl<T: ?Sized> ops::DerefMut for RwLockWriteGuard<'_, T> { fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.lock.c.get() } + unsafe { &mut *self.data } } } @@ -289,7 +646,7 @@ impl<T> From<T> for RwLock<T> { } } -impl<T> Default for RwLock<T> +impl<T: ?Sized> Default for RwLock<T> where T: Default, { diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs index 2489d34..43dd976 100644 --- a/src/sync/semaphore.rs +++ b/src/sync/semaphore.rs @@ -1,7 +1,7 @@ use super::batch_semaphore as ll; // low level implementation use std::sync::Arc; -/// Counting semaphore performing asynchronous permit aquisition. +/// Counting semaphore performing asynchronous permit acquisition. /// /// A semaphore maintains a set of permits. Permits are used to synchronize /// access to a shared resource. A semaphore differs from a mutex in that it @@ -74,6 +74,15 @@ impl Semaphore { } } + /// Creates a new semaphore with the initial number of permits. + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new(permits: usize) -> Self { + Self { + ll_sem: ll::Semaphore::const_new(permits), + } + } + /// Returns the current number of available permits. pub fn available_permits(&self) -> usize { self.ll_sem.available_permits() @@ -114,7 +123,7 @@ impl Semaphore { pub async fn acquire_owned(self: Arc<Self>) -> OwnedSemaphorePermit { self.ll_sem.acquire(1).await.unwrap(); OwnedSemaphorePermit { - sem: self.clone(), + sem: self, permits: 1, } } @@ -127,7 +136,7 @@ impl Semaphore { pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> { match self.ll_sem.try_acquire(1) { Ok(_) => Ok(OwnedSemaphorePermit { - sem: self.clone(), + sem: self, permits: 1, }), Err(_) => Err(TryAcquireError(())), diff --git a/src/sync/semaphore_ll.rs b/src/sync/semaphore_ll.rs deleted file mode 100644 index 25d25ac..0000000 --- a/src/sync/semaphore_ll.rs +++ /dev/null @@ -1,1221 +0,0 @@ -#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))] - -//! Thread-safe, asynchronous counting semaphore. -//! -//! A `Semaphore` instance holds a set of permits. Permits are used to -//! synchronize access to a shared resource. -//! -//! Before accessing the shared resource, callers acquire a permit from the -//! semaphore. Once the permit is acquired, the caller then enters the critical -//! section. If no permits are available, then acquiring the semaphore returns -//! `Pending`. The task is woken once a permit becomes available. - -use crate::loom::cell::UnsafeCell; -use crate::loom::future::AtomicWaker; -use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; - -use std::cmp; -use std::fmt; -use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; -use std::usize; - -/// Futures-aware semaphore. -pub(crate) struct Semaphore { - /// Tracks both the waiter queue tail pointer and the number of remaining - /// permits. - state: AtomicUsize, - - /// waiter queue head pointer. - head: UnsafeCell<NonNull<Waiter>>, - - /// Coordinates access to the queue head. - rx_lock: AtomicUsize, - - /// Stub waiter node used as part of the MPSC channel algorithm. - stub: Box<Waiter>, -} - -/// A semaphore permit -/// -/// Tracks the lifecycle of a semaphore permit. -/// -/// An instance of `Permit` is intended to be used with a **single** instance of -/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore -/// instances will result in unexpected behavior. -/// -/// `Permit` does **not** release the permit back to the semaphore on drop. It -/// is the user's responsibility to ensure that `Permit::release` is called -/// before dropping the permit. -#[derive(Debug)] -pub(crate) struct Permit { - waiter: Option<Box<Waiter>>, - state: PermitState, -} - -/// Error returned by `Permit::poll_acquire`. -#[derive(Debug)] -pub(crate) struct AcquireError(()); - -/// Error returned by `Permit::try_acquire`. -#[derive(Debug)] -pub(crate) enum TryAcquireError { - Closed, - NoPermits, -} - -/// Node used to notify the semaphore waiter when permit is available. -#[derive(Debug)] -struct Waiter { - /// Stores waiter state. - /// - /// See `WaiterState` for more details. - state: AtomicUsize, - - /// Task to wake when a permit is made available. - waker: AtomicWaker, - - /// Next pointer in the queue of waiting senders. - next: AtomicPtr<Waiter>, -} - -/// Semaphore state -/// -/// The 2 low bits track the modes. -/// -/// - Closed -/// - Full -/// -/// When not full, the rest of the `usize` tracks the total number of messages -/// in the channel. When full, the rest of the `usize` is a pointer to the tail -/// of the "waiting senders" queue. -#[derive(Copy, Clone)] -struct SemState(usize); - -/// Permit state -#[derive(Debug, Copy, Clone)] -enum PermitState { - /// Currently waiting for permits to be made available and assigned to the - /// waiter. - Waiting(u16), - - /// The number of acquired permits - Acquired(u16), -} - -/// State for an individual waker node -#[derive(Debug, Copy, Clone)] -struct WaiterState(usize); - -/// Waiter node is in the semaphore queue -const QUEUED: usize = 0b001; - -/// Semaphore has been closed, no more permits will be issued. -const CLOSED: usize = 0b10; - -/// The permit that owns the `Waiter` dropped. -const DROPPED: usize = 0b100; - -/// Represents "one requested permit" in the waiter state -const PERMIT_ONE: usize = 0b1000; - -/// Masks the waiter state to only contain bits tracking number of requested -/// permits. -const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1); - -/// How much to shift a permit count to pack it into the waker state -const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros(); - -/// Flag differentiating between available permits and waiter pointers. -/// -/// If we assume pointers are properly aligned, then the least significant bit -/// will always be zero. So, we use that bit to track if the value represents a -/// number. -const NUM_FLAG: usize = 0b01; - -/// Signal the semaphore is closed -const CLOSED_FLAG: usize = 0b10; - -/// Maximum number of permits a semaphore can manage -const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT; - -/// When representing "numbers", the state has to be shifted this much (to get -/// rid of the flag bit). -const NUM_SHIFT: usize = 2; - -// ===== impl Semaphore ===== - -impl Semaphore { - /// Creates a new semaphore with the initial number of permits - /// - /// # Panics - /// - /// Panics if `permits` is zero. - pub(crate) fn new(permits: usize) -> Semaphore { - let stub = Box::new(Waiter::new()); - let ptr = NonNull::from(&*stub); - - // Allocations are aligned - debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0); - - let state = SemState::new(permits, &stub); - - Semaphore { - state: AtomicUsize::new(state.to_usize()), - head: UnsafeCell::new(ptr), - rx_lock: AtomicUsize::new(0), - stub, - } - } - - /// Returns the current number of available permits - pub(crate) fn available_permits(&self) -> usize { - let curr = SemState(self.state.load(Acquire)); - curr.available_permits() - } - - /// Tries to acquire the requested number of permits, registering the waiter - /// if not enough permits are available. - fn poll_acquire( - &self, - cx: &mut Context<'_>, - num_permits: u16, - permit: &mut Permit, - ) -> Poll<Result<(), AcquireError>> { - self.poll_acquire2(num_permits, || { - let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new())); - - waiter.waker.register_by_ref(cx.waker()); - - Some(NonNull::from(&**waiter)) - }) - } - - fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> { - match self.poll_acquire2(num_permits, || None) { - Poll::Ready(res) => res.map_err(to_try_acquire), - Poll::Pending => Err(TryAcquireError::NoPermits), - } - } - - /// Polls for a permit - /// - /// Tries to acquire available permits first. If unable to acquire a - /// sufficient number of permits, the caller's waiter is pushed onto the - /// semaphore's wait queue. - fn poll_acquire2<F>( - &self, - num_permits: u16, - mut get_waiter: F, - ) -> Poll<Result<(), AcquireError>> - where - F: FnMut() -> Option<NonNull<Waiter>>, - { - let num_permits = num_permits as usize; - - // Load the current state - let mut curr = SemState(self.state.load(Acquire)); - - // Saves a ref to the waiter node - let mut maybe_waiter: Option<NonNull<Waiter>> = None; - - /// Used in branches where we attempt to push the waiter into the wait - /// queue but fail due to permits becoming available or the wait queue - /// transitioning to "closed". In this case, the waiter must be - /// transitioned back to the "idle" state. - macro_rules! revert_to_idle { - () => { - if let Some(waiter) = maybe_waiter { - unsafe { waiter.as_ref() }.revert_to_idle(); - } - }; - } - - loop { - let mut next = curr; - - if curr.is_closed() { - revert_to_idle!(); - return Ready(Err(AcquireError::closed())); - } - - let acquired = next.acquire_permits(num_permits, &self.stub); - - if !acquired { - // There are not enough available permits to satisfy the - // request. The permit transitions to a waiting state. - debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits); - - if let Some(waiter) = maybe_waiter.as_ref() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - w.set_permits_to_acquire(num_permits - curr.available_permits()); - } else { - // Get the waiter for the permit. - if let Some(waiter) = get_waiter() { - // Safety: the caller owns the waiter. - let w = unsafe { waiter.as_ref() }; - - // If there are any currently available permits, the - // waiter acquires those immediately and waits for the - // remaining permits to become available. - if !w.to_queued(num_permits - curr.available_permits()) { - // The node is alrady queued, there is no further work - // to do. - return Pending; - } - - maybe_waiter = Some(waiter); - } else { - // No waiter, this indicates the caller does not wish to - // "wait", so there is nothing left to do. - return Pending; - } - } - - next.set_waiter(maybe_waiter.unwrap()); - } - - debug_assert_ne!(curr.0, 0); - debug_assert_ne!(next.0, 0); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if acquired { - // Successfully acquire permits **without** queuing the - // waiter node. The waiter node is not currently in the - // queue. - revert_to_idle!(); - return Ready(Ok(())); - } else { - // The node is pushed into the queue, the final step is - // to set the node's "next" pointer to return the wait - // queue into a consistent state. - - let prev_waiter = - curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub)); - - let waiter = maybe_waiter.unwrap(); - - // Link the nodes. - // - // Safety: the mpsc algorithm guarantees the old tail of - // the queue is not removed from the queue during the - // push process. - unsafe { - prev_waiter.as_ref().store_next(waiter); - } - - return Pending; - } - } - Err(actual) => { - curr = SemState(actual); - } - } - } - } - - /// Closes the semaphore. This prevents the semaphore from issuing new - /// permits and notifies all pending waiters. - pub(crate) fn close(&self) { - // Acquire the `rx_lock`, setting the "closed" flag on the lock. - let prev = self.rx_lock.fetch_or(1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(0, true); - } - /// Adds `n` new permits to the semaphore. - /// - /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded. - pub(crate) fn add_permits(&self, n: usize) { - if n == 0 { - return; - } - - // TODO: Handle overflow. A panic is not sufficient, the process must - // abort. - let prev = self.rx_lock.fetch_add(n << 1, AcqRel); - - if prev != 0 { - // Another thread has the lock and will be responsible for notifying - // pending waiters. - return; - } - - self.add_permits_locked(n, false); - } - - fn add_permits_locked(&self, mut rem: usize, mut closed: bool) { - while rem > 0 || closed { - if closed { - SemState::fetch_set_closed(&self.state, AcqRel); - } - - // Release the permits and notify - self.add_permits_locked2(rem, closed); - - let n = rem << 1; - - let actual = if closed { - let actual = self.rx_lock.fetch_sub(n | 1, AcqRel); - closed = false; - actual - } else { - let actual = self.rx_lock.fetch_sub(n, AcqRel); - closed = actual & 1 == 1; - actual - }; - - rem = (actual >> 1) - rem; - } - } - - /// Releases a specific amount of permits to the semaphore - /// - /// This function is called by `add_permits` after the add lock has been - /// acquired. - fn add_permits_locked2(&self, mut n: usize, closed: bool) { - // If closing the semaphore, we want to drain the entire queue. The - // number of permits being assigned doesn't matter. - if closed { - n = usize::MAX; - } - - 'outer: while n > 0 { - unsafe { - let mut head = self.head.with(|head| *head); - let mut next_ptr = head.as_ref().next.load(Acquire); - - let stub = self.stub(); - - if head == stub { - // The stub node indicates an empty queue. Any remaining - // permits get assigned back to the semaphore. - let next = match NonNull::new(next_ptr) { - Some(next) => next, - None => { - // This loop is not part of the standard intrusive mpsc - // channel algorithm. This is where we atomically pop - // the last task and add `n` to the remaining capacity. - // - // This modification to the pop algorithm works because, - // at this point, we have not done any work (only done - // reading). We have a *pretty* good idea that there is - // no concurrent pusher. - // - // The capacity is then atomically added by doing an - // AcqRel CAS on `state`. The `state` cell is the - // linchpin of the algorithm. - // - // By successfully CASing `head` w/ AcqRel, we ensure - // that, if any thread was racing and entered a push, we - // see that and abort pop, retrying as it is - // "inconsistent". - let mut curr = SemState::load(&self.state, Acquire); - - loop { - if curr.has_waiter(&self.stub) { - // A waiter is being added concurrently. - // This is the MPSC queue's "inconsistent" - // state and we must loop and try again. - thread::yield_now(); - continue 'outer; - } - - // If closing, nothing more to do. - if closed { - debug_assert!(curr.is_closed(), "state = {:?}", curr); - return; - } - - let mut next = curr; - next.release_permits(n, &self.stub); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = SemState(actual); - } - } - } - } - }; - - self.head.with_mut(|head| *head = next); - head = next; - next_ptr = next.as_ref().next.load(Acquire); - } - - // `head` points to a waiter assign permits to the waiter. If - // all requested permits are satisfied, then we can continue, - // otherwise the node stays in the wait queue. - if !head.as_ref().assign_permits(&mut n, closed) { - assert_eq!(n, 0); - return; - } - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - let state = SemState::load(&self.state, Acquire); - - // This must always be a pointer as the wait list is not empty. - let tail = state.waiter().unwrap(); - - if tail != head { - // Inconsistent - thread::yield_now(); - continue 'outer; - } - - self.push_stub(closed); - - next_ptr = head.as_ref().next.load(Acquire); - - if let Some(next) = NonNull::new(next_ptr) { - self.head.with_mut(|head| *head = next); - - self.remove_queued(head, closed); - continue 'outer; - } - - // Inconsistent state, loop - thread::yield_now(); - } - } - } - - /// The wait node has had all of its permits assigned and has been removed - /// from the wait queue. - /// - /// Attempt to remove the QUEUED bit from the node. If additional permits - /// are concurrently requested, the node must be pushed back into the wait - /// queued. - fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) { - let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire)); - - loop { - if curr.is_dropped() { - // The Permit dropped, it is on us to release the memory - let _ = unsafe { Box::from_raw(waiter.as_ptr()) }; - return; - } - - // The node is removed from the queue. We attempt to unset the - // queued bit, but concurrently the waiter has requested more - // permits. When the waiter requested more permits, it saw the - // queued bit set so took no further action. This requires us to - // push the node back into the queue. - if curr.permits_to_acquire() > 0 { - // More permits are requested. The waiter must be re-queued - unsafe { - self.push_waiter(waiter, closed); - } - return; - } - - let mut next = curr; - next.unset_queued(); - - let w = unsafe { waiter.as_ref() }; - - match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return, - Err(actual) => { - curr = WaiterState(actual); - } - } - } - } - - unsafe fn push_stub(&self, closed: bool) { - self.push_waiter(self.stub(), closed); - } - - unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) { - // Set the next pointer. This does not require an atomic operation as - // this node is not accessible. The write will be flushed with the next - // operation - waiter.as_ref().next.store(ptr::null_mut(), Relaxed); - - // Update the tail to point to the new node. We need to see the previous - // node in order to update the next pointer as well as release `task` - // to any other threads calling `push`. - let next = SemState::new_ptr(waiter, closed); - let prev = SemState(self.state.swap(next.0, AcqRel)); - - debug_assert_eq!(closed, prev.is_closed()); - - // This function is only called when there are pending tasks. Because of - // this, the state must *always* be in pointer mode. - let prev = prev.waiter().unwrap(); - - // No cycles plz - debug_assert_ne!(prev, waiter); - - // Release `task` to the consume end. - prev.as_ref().next.store(waiter.as_ptr(), Release); - } - - fn stub(&self) -> NonNull<Waiter> { - unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) } - } -} - -impl Drop for Semaphore { - fn drop(&mut self) { - self.close(); - } -} - -impl fmt::Debug for Semaphore { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("Semaphore") - .field("state", &SemState::load(&self.state, Relaxed)) - .field("head", &self.head.with(|ptr| ptr)) - .field("rx_lock", &self.rx_lock.load(Relaxed)) - .field("stub", &self.stub) - .finish() - } -} - -unsafe impl Send for Semaphore {} -unsafe impl Sync for Semaphore {} - -// ===== impl Permit ===== - -impl Permit { - /// Creates a new `Permit`. - /// - /// The permit begins in the "unacquired" state. - pub(crate) fn new() -> Permit { - use PermitState::Acquired; - - Permit { - waiter: None, - state: Acquired(0), - } - } - - /// Returns `true` if the permit has been acquired - #[allow(dead_code)] // may be used later - pub(crate) fn is_acquired(&self) -> bool { - match self.state { - PermitState::Acquired(num) if num > 0 => true, - _ => false, - } - } - - /// Tries to acquire the permit. If no permits are available, the current task - /// is notified once a new permit becomes available. - pub(crate) fn poll_acquire( - &mut self, - cx: &mut Context<'_>, - num_permits: u16, - semaphore: &Semaphore, - ) -> Poll<Result<(), AcquireError>> { - use std::cmp::Ordering::*; - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - match requested.cmp(&num_permits) { - Less => { - let delta = num_permits - requested; - - // Request additional permits. If the waiter has been - // dequeued, it must be re-queued. - if !waiter.try_inc_permits_to_acquire(delta as usize) { - let waiter = NonNull::from(&**waiter); - - // Ignore the result. The check for - // `permits_to_acquire()` will converge the state as - // needed - let _ = semaphore.poll_acquire2(delta, || Some(waiter))?; - } - - self.state = Waiting(num_permits); - } - Greater => { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - Equal => {} - } - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - waiter.waker.register_by_ref(cx.waker()); - - if waiter.permits_to_acquire()? == 0 { - self.state = Acquired(requested); - return Ready(Ok(())); - } - - Pending - } - Acquired(acquired) => { - if acquired >= num_permits { - Ready(Ok(())) - } else { - match semaphore.poll_acquire(cx, num_permits - acquired, self)? { - Ready(()) => { - self.state = Acquired(num_permits); - Ready(Ok(())) - } - Pending => { - self.state = Waiting(num_permits); - Pending - } - } - } - } - } - } - - /// Tries to acquire the permit. - pub(crate) fn try_acquire( - &mut self, - num_permits: u16, - semaphore: &Semaphore, - ) -> Result<(), TryAcquireError> { - use PermitState::*; - - match self.state { - Waiting(requested) => { - // There must be a waiter - let waiter = self.waiter.as_ref().unwrap(); - - if requested > num_permits { - let delta = requested - num_permits; - let to_release = waiter.try_dec_permits_to_acquire(delta as usize); - - semaphore.add_permits(to_release); - self.state = Waiting(num_permits); - } - - let res = waiter.permits_to_acquire().map_err(to_try_acquire)?; - - if res == 0 { - if requested < num_permits { - // Try to acquire the additional permits - semaphore.try_acquire(num_permits - requested)?; - } - - self.state = Acquired(num_permits); - Ok(()) - } else { - Err(TryAcquireError::NoPermits) - } - } - Acquired(acquired) => { - if acquired < num_permits { - semaphore.try_acquire(num_permits - acquired)?; - self.state = Acquired(num_permits); - } - - Ok(()) - } - } - } - - /// Releases a permit back to the semaphore - pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) { - let n = self.forget(n); - semaphore.add_permits(n as usize); - } - - /// Forgets the permit **without** releasing it back to the semaphore. - /// - /// After calling `forget`, `poll_acquire` is able to acquire new permit - /// from the semaphore. - /// - /// Repeatedly calling `forget` without associated calls to `add_permit` - /// will result in the semaphore losing all permits. - /// - /// Will forget **at most** the number of acquired permits. This number is - /// returned. - pub(crate) fn forget(&mut self, n: u16) -> u16 { - use PermitState::*; - - match self.state { - Waiting(requested) => { - let n = cmp::min(n, requested); - - // Decrement - let acquired = self - .waiter - .as_ref() - .unwrap() - .try_dec_permits_to_acquire(n as usize) as u16; - - if n == requested { - self.state = Acquired(0); - } else if acquired == requested - n { - self.state = Waiting(acquired); - } else { - self.state = Waiting(requested - n); - } - - acquired - } - Acquired(acquired) => { - let n = cmp::min(n, acquired); - self.state = Acquired(acquired - n); - n - } - } - } -} - -impl Default for Permit { - fn default() -> Self { - Self::new() - } -} - -impl Drop for Permit { - fn drop(&mut self) { - if let Some(waiter) = self.waiter.take() { - // Set the dropped flag - let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel)); - - if state.is_queued() { - // The waiter is stored in the queue. The semaphore will drop it - std::mem::forget(waiter); - } - } - } -} - -// ===== impl AcquireError ==== - -impl AcquireError { - fn closed() -> AcquireError { - AcquireError(()) - } -} - -fn to_try_acquire(_: AcquireError) -> TryAcquireError { - TryAcquireError::Closed -} - -impl fmt::Display for AcquireError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "semaphore closed") - } -} - -impl std::error::Error for AcquireError {} - -// ===== impl TryAcquireError ===== - -impl TryAcquireError { - /// Returns `true` if the error was caused by a closed semaphore. - pub(crate) fn is_closed(&self) -> bool { - match self { - TryAcquireError::Closed => true, - _ => false, - } - } - - /// Returns `true` if the error was caused by calling `try_acquire` on a - /// semaphore with no available permits. - pub(crate) fn is_no_permits(&self) -> bool { - match self { - TryAcquireError::NoPermits => true, - _ => false, - } - } -} - -impl fmt::Display for TryAcquireError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TryAcquireError::Closed => write!(fmt, "semaphore closed"), - TryAcquireError::NoPermits => write!(fmt, "no permits available"), - } - } -} - -impl std::error::Error for TryAcquireError {} - -// ===== impl Waiter ===== - -impl Waiter { - fn new() -> Waiter { - Waiter { - state: AtomicUsize::new(0), - waker: AtomicWaker::new(), - next: AtomicPtr::new(ptr::null_mut()), - } - } - - fn permits_to_acquire(&self) -> Result<usize, AcquireError> { - let state = WaiterState(self.state.load(Acquire)); - - if state.is_closed() { - Err(AcquireError(())) - } else { - Ok(state.permits_to_acquire()) - } - } - - /// Only increments the number of permits *if* the waiter is currently - /// queued. - /// - /// # Returns - /// - /// `true` if the number of permits to acquire has been incremented. `false` - /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`. - fn try_inc_permits_to_acquire(&self, n: usize) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - if !curr.is_queued() { - assert_eq!(0, curr.permits_to_acquire()); - return false; - } - - let mut next = curr; - next.set_permits_to_acquire(n + curr.permits_to_acquire()); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return true, - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Try to decrement the number of permits to acquire. This returns the - /// actual number of permits that were decremented. The delta betweeen `n` - /// and the return has been assigned to the permit and the caller must - /// assign these back to the semaphore. - fn try_dec_permits_to_acquire(&self, n: usize) -> usize { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - if !curr.is_queued() { - assert_eq!(0, curr.permits_to_acquire()); - } - - let delta = cmp::min(n, curr.permits_to_acquire()); - let rem = curr.permits_to_acquire() - delta; - - let mut next = curr; - next.set_permits_to_acquire(rem); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => return n - delta, - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Store the number of remaining permits needed to satisfy the waiter and - /// transition to the "QUEUED" state. - /// - /// # Returns - /// - /// `true` if the `QUEUED` bit was set as part of the transition. - fn to_queued(&self, num_permits: usize) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - // The waiter should **not** be waiting for any permits. - debug_assert_eq!(curr.permits_to_acquire(), 0); - - loop { - let mut next = curr; - next.set_permits_to_acquire(num_permits); - next.set_queued(); - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - if curr.is_queued() { - return false; - } else { - // Make sure the next pointer is null - self.next.store(ptr::null_mut(), Relaxed); - return true; - } - } - Err(actual) => curr = WaiterState(actual), - } - } - } - - /// Set the number of permits to acquire. - /// - /// This function is only called when the waiter is being inserted into the - /// wait queue. Because of this, there are no concurrent threads that can - /// modify the state and using `store` is safe. - fn set_permits_to_acquire(&self, num_permits: usize) { - debug_assert!(WaiterState(self.state.load(Acquire)).is_queued()); - - let mut state = WaiterState(QUEUED); - state.set_permits_to_acquire(num_permits); - - self.state.store(state.0, Release); - } - - /// Assign permits to the waiter. - /// - /// Returns `true` if the waiter should be removed from the queue - fn assign_permits(&self, n: &mut usize, closed: bool) -> bool { - let mut curr = WaiterState(self.state.load(Acquire)); - - loop { - let mut next = curr; - - // Number of permits to assign to this waiter - let assign = cmp::min(curr.permits_to_acquire(), *n); - - // Assign the permits - next.set_permits_to_acquire(curr.permits_to_acquire() - assign); - - if closed { - next.set_closed(); - } - - match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) { - Ok(_) => { - // Update `n` - *n -= assign; - - if next.permits_to_acquire() == 0 { - if curr.permits_to_acquire() > 0 { - self.waker.wake(); - } - - return true; - } else { - return false; - } - } - Err(actual) => curr = WaiterState(actual), - } - } - } - - fn revert_to_idle(&self) { - // An idle node is not waiting on any permits - self.state.store(0, Relaxed); - } - - fn store_next(&self, next: NonNull<Waiter>) { - self.next.store(next.as_ptr(), Release); - } -} - -// ===== impl SemState ===== - -impl SemState { - /// Returns a new default `State` value. - fn new(permits: usize, stub: &Waiter) -> SemState { - assert!(permits <= MAX_PERMITS); - - if permits > 0 { - SemState((permits << NUM_SHIFT) | NUM_FLAG) - } else { - SemState(stub as *const _ as usize) - } - } - - /// Returns a `State` tracking `ptr` as the tail of the queue. - fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState { - let mut val = tail.as_ptr() as usize; - - if closed { - val |= CLOSED_FLAG; - } - - SemState(val) - } - - /// Returns the amount of remaining capacity - fn available_permits(self) -> usize { - if !self.has_available_permits() { - return 0; - } - - self.0 >> NUM_SHIFT - } - - /// Returns `true` if the state has permits that can be claimed by a waiter. - fn has_available_permits(self) -> bool { - self.0 & NUM_FLAG == NUM_FLAG - } - - fn has_waiter(self, stub: &Waiter) -> bool { - !self.has_available_permits() && !self.is_stub(stub) - } - - /// Tries to atomically acquire specified number of permits. - /// - /// # Return - /// - /// Returns `true` if the specified number of permits were acquired, `false` - /// otherwise. Returning false does not mean that there are no more - /// available permits. - fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool { - debug_assert!(num > 0); - - if self.available_permits() < num { - return false; - } - - debug_assert!(self.waiter().is_none()); - - self.0 -= num << NUM_SHIFT; - - if self.0 == NUM_FLAG { - // Set the state to the stub pointer. - self.0 = stub as *const _ as usize; - } - - true - } - - /// Releases permits - /// - /// Returns `true` if the permits were accepted. - fn release_permits(&mut self, permits: usize, stub: &Waiter) { - debug_assert!(permits > 0); - - if self.is_stub(stub) { - self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG); - return; - } - - debug_assert!(self.has_available_permits()); - - self.0 += permits << NUM_SHIFT; - } - - fn is_waiter(self) -> bool { - self.0 & NUM_FLAG == 0 - } - - /// Returns the waiter, if one is set. - fn waiter(self) -> Option<NonNull<Waiter>> { - if self.is_waiter() { - let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored"); - - Some(waiter) - } else { - None - } - } - - /// Assumes `self` represents a pointer - fn as_ptr(self) -> *mut Waiter { - (self.0 & !CLOSED_FLAG) as *mut Waiter - } - - /// Sets to a pointer to a waiter. - /// - /// This can only be done from the full state. - fn set_waiter(&mut self, waiter: NonNull<Waiter>) { - let waiter = waiter.as_ptr() as usize; - debug_assert!(!self.is_closed()); - - self.0 = waiter; - } - - fn is_stub(self, stub: &Waiter) -> bool { - self.as_ptr() as usize == stub as *const _ as usize - } - - /// Loads the state from an AtomicUsize. - fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState { - let value = cell.load(ordering); - SemState(value) - } - - fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState { - let value = cell.fetch_or(CLOSED_FLAG, ordering); - SemState(value) - } - - fn is_closed(self) -> bool { - self.0 & CLOSED_FLAG == CLOSED_FLAG - } - - /// Converts the state into a `usize` representation. - fn to_usize(self) -> usize { - self.0 - } -} - -impl fmt::Debug for SemState { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut fmt = fmt.debug_struct("SemState"); - - if self.is_waiter() { - fmt.field("state", &"<waiter>"); - } else { - fmt.field("permits", &self.available_permits()); - } - - fmt.finish() - } -} - -// ===== impl WaiterState ===== - -impl WaiterState { - fn permits_to_acquire(self) -> usize { - self.0 >> PERMIT_SHIFT - } - - fn set_permits_to_acquire(&mut self, val: usize) { - self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK) - } - - fn is_queued(self) -> bool { - self.0 & QUEUED == QUEUED - } - - fn set_queued(&mut self) { - self.0 |= QUEUED; - } - - fn is_closed(self) -> bool { - self.0 & CLOSED == CLOSED - } - - fn set_closed(&mut self) { - self.0 |= CLOSED; - } - - fn unset_queued(&mut self) { - assert!(self.is_queued()); - self.0 -= QUEUED; - } - - fn is_dropped(self) -> bool { - self.0 & DROPPED == DROPPED - } -} diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs index 73b1745..ae4cac7 100644 --- a/src/sync/task/atomic_waker.rs +++ b/src/sync/task/atomic_waker.rs @@ -141,13 +141,12 @@ impl AtomicWaker { } } + /* /// Registers the current waker to be notified on calls to `wake`. - /// - /// This is the same as calling `register_task` with `task::current()`. - #[cfg(feature = "io-driver")] pub(crate) fn register(&self, waker: Waker) { self.do_register(waker); } + */ /// Registers the provided waker to be notified on calls to `wake`. /// diff --git a/src/sync/tests/loom_broadcast.rs b/src/sync/tests/loom_broadcast.rs index da12fb9..4b1f034 100644 --- a/src/sync/tests/loom_broadcast.rs +++ b/src/sync/tests/loom_broadcast.rs @@ -1,5 +1,5 @@ use crate::sync::broadcast; -use crate::sync::broadcast::RecvError::{Closed, Lagged}; +use crate::sync::broadcast::error::RecvError::{Closed, Lagged}; use loom::future::block_on; use loom::sync::Arc; diff --git a/src/sync/tests/loom_cancellation_token.rs b/src/sync/tests/loom_cancellation_token.rs deleted file mode 100644 index e9c9f3d..0000000 --- a/src/sync/tests/loom_cancellation_token.rs +++ /dev/null @@ -1,155 +0,0 @@ -use crate::sync::CancellationToken; - -use loom::{future::block_on, thread}; -use tokio_test::assert_ok; - -#[test] -fn cancel_token() { - loom::model(|| { - let token = CancellationToken::new(); - let token1 = token.clone(); - - let th1 = thread::spawn(move || { - block_on(async { - token1.cancelled().await; - }); - }); - - let th2 = thread::spawn(move || { - token.cancel(); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - }); -} - -#[test] -fn cancel_with_child() { - loom::model(|| { - let token = CancellationToken::new(); - let token1 = token.clone(); - let token2 = token.clone(); - let child_token = token.child_token(); - - let th1 = thread::spawn(move || { - block_on(async { - token1.cancelled().await; - }); - }); - - let th2 = thread::spawn(move || { - token2.cancel(); - }); - - let th3 = thread::spawn(move || { - block_on(async { - child_token.cancelled().await; - }); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn drop_token_no_child() { - loom::model(|| { - let token = CancellationToken::new(); - let token1 = token.clone(); - let token2 = token.clone(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - drop(token2); - }); - - let th3 = thread::spawn(move || { - drop(token); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn drop_token_with_childs() { - loom::model(|| { - let token1 = CancellationToken::new(); - let child_token1 = token1.child_token(); - let child_token2 = token1.child_token(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - drop(child_token1); - }); - - let th3 = thread::spawn(move || { - drop(child_token2); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn drop_and_cancel_token() { - loom::model(|| { - let token1 = CancellationToken::new(); - let token2 = token1.clone(); - let child_token = token1.child_token(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - token2.cancel(); - }); - - let th3 = thread::spawn(move || { - drop(child_token); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} - -#[test] -fn cancel_parent_and_child() { - loom::model(|| { - let token1 = CancellationToken::new(); - let token2 = token1.clone(); - let child_token = token1.child_token(); - - let th1 = thread::spawn(move || { - drop(token1); - }); - - let th2 = thread::spawn(move || { - token2.cancel(); - }); - - let th3 = thread::spawn(move || { - child_token.cancel(); - }); - - assert_ok!(th1.join()); - assert_ok!(th2.join()); - assert_ok!(th3.join()); - }); -} diff --git a/src/sync/tests/loom_mpsc.rs b/src/sync/tests/loom_mpsc.rs index 6a1a6ab..c12313b 100644 --- a/src/sync/tests/loom_mpsc.rs +++ b/src/sync/tests/loom_mpsc.rs @@ -2,22 +2,24 @@ use crate::sync::mpsc; use futures::future::poll_fn; use loom::future::block_on; +use loom::sync::Arc; use loom::thread; +use tokio_test::assert_ok; #[test] fn closing_tx() { loom::model(|| { - let (mut tx, mut rx) = mpsc::channel(16); + let (tx, mut rx) = mpsc::channel(16); thread::spawn(move || { tx.try_send(()).unwrap(); drop(tx); }); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_some()); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -32,15 +34,70 @@ fn closing_unbounded_tx() { drop(tx); }); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_some()); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } #[test] +fn closing_bounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::channel::<()>(16); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] +fn closing_and_sending() { + loom::model(|| { + let (tx1, mut rx) = mpsc::channel::<()>(16); + let tx1 = Arc::new(tx1); + let tx2 = tx1.clone(); + + let th1 = thread::spawn(move || { + tx1.try_send(()).unwrap(); + }); + + let th2 = thread::spawn(move || { + block_on(tx2.closed()); + }); + + let th3 = thread::spawn(move || { + let v = block_on(rx.recv()); + assert!(v.is_some()); + drop(rx); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn closing_unbounded_rx() { + loom::model(|| { + let (tx1, rx) = mpsc::unbounded_channel::<()>(); + let tx2 = tx1.clone(); + thread::spawn(move || { + drop(rx); + }); + + block_on(tx1.closed()); + block_on(tx2.closed()); + }); +} + +#[test] fn dropping_tx() { loom::model(|| { let (tx, mut rx) = mpsc::channel::<()>(16); @@ -53,7 +110,7 @@ fn dropping_tx() { } drop(tx); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } @@ -71,7 +128,7 @@ fn dropping_unbounded_tx() { } drop(tx); - let v = block_on(poll_fn(|cx| rx.poll_recv(cx))); + let v = block_on(rx.recv()); assert!(v.is_none()); }); } diff --git a/src/sync/tests/loom_notify.rs b/src/sync/tests/loom_notify.rs index 60981d4..79a5bf8 100644 --- a/src/sync/tests/loom_notify.rs +++ b/src/sync/tests/loom_notify.rs @@ -16,7 +16,7 @@ fn notify_one() { }); }); - tx.notify(); + tx.notify_one(); th.join().unwrap(); }); } @@ -34,12 +34,12 @@ fn notify_multi() { ths.push(thread::spawn(move || { block_on(async { notify.notified().await; - notify.notify(); + notify.notify_one(); }) })); } - notify.notify(); + notify.notify_one(); for th in ths.drain(..) { th.join().unwrap(); @@ -67,7 +67,7 @@ fn notify_drop() { block_on(poll_fn(|cx| { if recv.as_mut().poll(cx).is_ready() { - rx1.notify(); + rx1.notify_one(); } Poll::Ready(()) })); @@ -77,12 +77,12 @@ fn notify_drop() { block_on(async { rx2.notified().await; // Trigger second notification - rx2.notify(); + rx2.notify_one(); rx2.notified().await; }); }); - notify.notify(); + notify.notify_one(); th1.join().unwrap(); th2.join().unwrap(); diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs index dfa7459..9729cfb 100644 --- a/src/sync/tests/loom_oneshot.rs +++ b/src/sync/tests/loom_oneshot.rs @@ -75,8 +75,10 @@ impl Future for OnClose<'_> { type Output = bool; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> { - let res = self.get_mut().tx.poll_closed(cx); - Ready(res.is_ready()) + let fut = self.get_mut().tx.closed(); + crate::pin!(fut); + + Ready(fut.poll(cx).is_ready()) } } diff --git a/src/sync/tests/loom_semaphore_ll.rs b/src/sync/tests/loom_semaphore_ll.rs deleted file mode 100644 index b5e5efb..0000000 --- a/src/sync/tests/loom_semaphore_ll.rs +++ /dev/null @@ -1,192 +0,0 @@ -use crate::sync::semaphore_ll::*; - -use futures::future::poll_fn; -use loom::future::block_on; -use loom::thread; -use std::future::Future; -use std::pin::Pin; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::Arc; -use std::task::Poll::Ready; -use std::task::{Context, Poll}; - -#[test] -fn basic_usage() { - const NUM: usize = 2; - - struct Actor { - waiter: Permit, - shared: Arc<Shared>, - } - - struct Shared { - semaphore: Semaphore, - active: AtomicUsize, - } - - impl Future for Actor { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let me = &mut *self; - - ready!(me.waiter.poll_acquire(cx, 1, &me.shared.semaphore)).unwrap(); - - let actual = me.shared.active.fetch_add(1, SeqCst); - assert!(actual <= NUM - 1); - - let actual = me.shared.active.fetch_sub(1, SeqCst); - assert!(actual <= NUM); - - me.waiter.release(1, &me.shared.semaphore); - - Ready(()) - } - } - - loom::model(|| { - let shared = Arc::new(Shared { - semaphore: Semaphore::new(NUM), - active: AtomicUsize::new(0), - }); - - for _ in 0..NUM { - let shared = shared.clone(); - - thread::spawn(move || { - block_on(Actor { - waiter: Permit::new(), - shared, - }); - }); - } - - block_on(Actor { - waiter: Permit::new(), - shared, - }); - }); -} - -#[test] -fn release() { - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - { - let semaphore = semaphore.clone(); - thread::spawn(move || { - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); - - permit.release(1, &semaphore); - }); - } - - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap(); - - permit.release(1, &semaphore); - }); -} - -#[test] -fn basic_closing() { - const NUM: usize = 2; - - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - for _ in 0..NUM { - let semaphore = semaphore.clone(); - - thread::spawn(move || { - let mut permit = Permit::new(); - - for _ in 0..2 { - block_on(poll_fn(|cx| { - permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) - }))?; - - permit.release(1, &semaphore); - } - - Ok::<(), ()>(()) - }); - } - - semaphore.close(); - }); -} - -#[test] -fn concurrent_close() { - const NUM: usize = 3; - - loom::model(|| { - let semaphore = Arc::new(Semaphore::new(1)); - - for _ in 0..NUM { - let semaphore = semaphore.clone(); - - thread::spawn(move || { - let mut permit = Permit::new(); - - block_on(poll_fn(|cx| { - permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ()) - }))?; - - permit.release(1, &semaphore); - - semaphore.close(); - - Ok::<(), ()>(()) - }); - } - }); -} - -#[test] -fn batch() { - let mut b = loom::model::Builder::new(); - b.preemption_bound = Some(1); - - b.check(|| { - let semaphore = Arc::new(Semaphore::new(10)); - let active = Arc::new(AtomicUsize::new(0)); - let mut ths = vec![]; - - for _ in 0..2 { - let semaphore = semaphore.clone(); - let active = active.clone(); - - ths.push(thread::spawn(move || { - let mut permit = Permit::new(); - - for n in &[4, 10, 8] { - block_on(poll_fn(|cx| permit.poll_acquire(cx, *n, &semaphore))).unwrap(); - - active.fetch_add(*n as usize, SeqCst); - - let num_active = active.load(SeqCst); - assert!(num_active <= 10); - - thread::yield_now(); - - active.fetch_sub(*n as usize, SeqCst); - - permit.release(*n, &semaphore); - } - })); - } - - for th in ths.into_iter() { - th.join().unwrap(); - } - - assert_eq!(10, semaphore.available_permits()); - }); -} diff --git a/src/sync/tests/loom_watch.rs b/src/sync/tests/loom_watch.rs new file mode 100644 index 0000000..c575b5b --- /dev/null +++ b/src/sync/tests/loom_watch.rs @@ -0,0 +1,36 @@ +use crate::sync::watch; + +use loom::future::block_on; +use loom::thread; + +#[test] +fn smoke() { + loom::model(|| { + let (tx, mut rx1) = watch::channel(1); + let mut rx2 = rx1.clone(); + let mut rx3 = rx1.clone(); + let mut rx4 = rx1.clone(); + let mut rx5 = rx1.clone(); + + let th = thread::spawn(move || { + tx.send(2).unwrap(); + }); + + block_on(rx1.changed()).unwrap(); + assert_eq!(*rx1.borrow(), 2); + + block_on(rx2.changed()).unwrap(); + assert_eq!(*rx2.borrow(), 2); + + block_on(rx3.changed()).unwrap(); + assert_eq!(*rx3.borrow(), 2); + + block_on(rx4.changed()).unwrap(); + assert_eq!(*rx4.borrow(), 2); + + block_on(rx5.changed()).unwrap(); + assert_eq!(*rx5.borrow(), 2); + + th.join().unwrap(); + }) +} diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs index 6ba8c1f..a78be6f 100644 --- a/src/sync/tests/mod.rs +++ b/src/sync/tests/mod.rs @@ -1,18 +1,15 @@ cfg_not_loom! { mod atomic_waker; - mod semaphore_ll; mod semaphore_batch; } cfg_loom! { mod loom_atomic_waker; mod loom_broadcast; - #[cfg(tokio_unstable)] - mod loom_cancellation_token; mod loom_list; mod loom_mpsc; mod loom_notify; mod loom_oneshot; mod loom_semaphore_batch; - mod loom_semaphore_ll; + mod loom_watch; } diff --git a/src/sync/tests/semaphore_ll.rs b/src/sync/tests/semaphore_ll.rs deleted file mode 100644 index bfb0757..0000000 --- a/src/sync/tests/semaphore_ll.rs +++ /dev/null @@ -1,470 +0,0 @@ -use crate::sync::semaphore_ll::{Permit, Semaphore}; -use tokio_test::*; - -#[test] -fn poll_acquire_one_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - assert!(!permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); -} - -#[test] -fn poll_acquire_many_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - assert!(!permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling for a larger number of permits acquires more - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 8, &s))); - assert_eq!(s.available_permits(), 92); - assert!(permit.is_acquired()); -} - -#[test] -fn try_acquire_one_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = Permit::new(); - assert!(!permit.is_acquired()); - - assert_ok!(permit.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ok!(permit.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 99); - assert!(permit.is_acquired()); -} - -#[test] -fn try_acquire_many_available() { - let s = Semaphore::new(100); - assert_eq!(s.available_permits(), 100); - - // Polling for a permit succeeds immediately - let mut permit = Permit::new(); - assert!(!permit.is_acquired()); - - assert_ok!(permit.try_acquire(5, &s)); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); - - // Polling again on the same waiter does not claim a new permit - assert_ok!(permit.try_acquire(5, &s)); - assert_eq!(s.available_permits(), 95); - assert!(permit.is_acquired()); -} - -#[test] -fn poll_acquire_one_unavailable() { - let s = Semaphore::new(1); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - - // Acquire the first permit - assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 0); - - permit_2.enter(|cx, mut p| { - // Try to acquire the second permit - assert_pending!(p.poll_acquire(cx, 1, &s)); - }); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 0); - assert!(permit_2.is_woken()); - assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn forget_acquired() { - let s = Semaphore::new(1); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); - - permit.forget(1); - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn forget_waiting() { - let s = Semaphore::new(0); - - // Polling for a permit succeeds immediately - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); - - permit.forget(1); - - s.add_permits(1); - - assert!(!permit.is_woken()); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn poll_acquire_many_unavailable() { - let s = Semaphore::new(5); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - let mut permit_3 = task::spawn(Permit::new()); - - // Acquire the first permit - assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 4); - - permit_2.enter(|cx, mut p| { - // Try to acquire the second permit - assert_pending!(p.poll_acquire(cx, 5, &s)); - }); - - assert_eq!(s.available_permits(), 0); - - permit_3.enter(|cx, mut p| { - // Try to acquire the third permit - assert_pending!(p.poll_acquire(cx, 3, &s)); - }); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 0); - assert!(permit_2.is_woken()); - assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - - assert!(!permit_3.is_woken()); - assert_eq!(s.available_permits(), 0); - - permit_2.release(1, &s); - assert!(!permit_3.is_woken()); - assert_eq!(s.available_permits(), 0); - - permit_2.release(2, &s); - assert!(permit_3.is_woken()); - - assert_ready_ok!(permit_3.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn try_acquire_one_unavailable() { - let s = Semaphore::new(1); - - let mut permit_1 = Permit::new(); - let mut permit_2 = Permit::new(); - - // Acquire the first permit - assert_ok!(permit_1.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 0); - - assert_err!(permit_2.try_acquire(1, &s)); - - permit_1.release(1, &s); - - assert_eq!(s.available_permits(), 1); - assert_ok!(permit_2.try_acquire(1, &s)); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); -} - -#[test] -fn try_acquire_many_unavailable() { - let s = Semaphore::new(5); - - let mut permit_1 = Permit::new(); - let mut permit_2 = Permit::new(); - - // Acquire the first permit - assert_ok!(permit_1.try_acquire(1, &s)); - assert_eq!(s.available_permits(), 4); - - assert_err!(permit_2.try_acquire(5, &s)); - - permit_1.release(1, &s); - assert_eq!(s.available_permits(), 5); - - assert_ok!(permit_2.try_acquire(5, &s)); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 1); - - permit_2.release(1, &s); - assert_eq!(s.available_permits(), 2); -} - -#[test] -fn poll_acquire_one_zero_permits() { - let s = Semaphore::new(0); - assert_eq!(s.available_permits(), 0); - - let mut permit = task::spawn(Permit::new()); - - // Try to acquire the permit - permit.enter(|cx, mut p| { - assert_pending!(p.poll_acquire(cx, 1, &s)); - }); - - s.add_permits(1); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -#[should_panic] -fn validates_max_permits() { - use std::usize; - Semaphore::new((usize::MAX >> 2) + 1); -} - -#[test] -fn close_semaphore_prevents_acquire() { - let s = Semaphore::new(5); - s.close(); - - assert_eq!(5, s.available_permits()); - - let mut permit_1 = task::spawn(Permit::new()); - let mut permit_2 = task::spawn(Permit::new()); - - assert_ready_err!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(5, s.available_permits()); - - assert_ready_err!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(5, s.available_permits()); -} - -#[test] -fn close_semaphore_notifies_permit1() { - let s = Semaphore::new(0); - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.close(); - - assert!(permit.is_woken()); - assert_ready_err!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -fn close_semaphore_notifies_permit2() { - let s = Semaphore::new(2); - - let mut permit1 = task::spawn(Permit::new()); - let mut permit2 = task::spawn(Permit::new()); - let mut permit3 = task::spawn(Permit::new()); - let mut permit4 = task::spawn(Permit::new()); - - // Acquire a couple of permits - assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_pending!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_pending!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.close(); - - assert!(permit3.is_woken()); - assert!(permit4.is_woken()); - - assert_ready_err!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_ready_err!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(0, s.available_permits()); - - permit1.release(1, &s); - - assert_eq!(1, s.available_permits()); - - assert_ready_err!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - permit2.release(1, &s); - - assert_eq!(2, s.available_permits()); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_before_assigned() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); - - s.add_permits(1); - assert!(!permit.is_woken()); - - s.add_permits(1); - assert!(permit.is_woken()); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn try_acquire_additional_permits_while_waiting_before_assigned() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - assert_err!(permit.enter(|_, mut p| p.try_acquire(3, &s))); - - s.add_permits(1); - assert!(permit.is_woken()); - - assert_ok!(permit.enter(|_, mut p| p.try_acquire(2, &s))); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_after_assigned_success() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - s.add_permits(2); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); -} - -#[test] -fn poll_acquire_additional_permits_while_waiting_after_assigned_requeue() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - - s.add_permits(2); - - assert!(permit.is_woken()); - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); - - s.add_permits(1); - - assert!(permit.is_woken()); - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s))); -} - -#[test] -fn poll_acquire_fewer_permits_while_waiting() { - let s = Semaphore::new(1); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(s.available_permits(), 0); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn poll_acquire_fewer_permits_after_assigned() { - let s = Semaphore::new(1); - - let mut permit1 = task::spawn(Permit::new()); - let mut permit2 = task::spawn(Permit::new()); - - assert_pending!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 5, &s))); - assert_eq!(s.available_permits(), 0); - - assert_pending!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - s.add_permits(4); - assert!(permit1.is_woken()); - assert!(!permit2.is_woken()); - - assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 3, &s))); - - assert!(permit2.is_woken()); - assert_eq!(s.available_permits(), 1); - - assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); -} - -#[test] -fn forget_partial_1() { - let s = Semaphore::new(0); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - s.add_permits(1); - - assert_eq!(0, s.available_permits()); - - permit.release(1, &s); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s))); - - assert_eq!(s.available_permits(), 0); -} - -#[test] -fn forget_partial_2() { - let s = Semaphore::new(0); - - let mut permit = task::spawn(Permit::new()); - - assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - s.add_permits(1); - - assert_eq!(0, s.available_permits()); - - permit.release(1, &s); - - s.add_permits(1); - - assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s))); - assert_eq!(s.available_permits(), 0); -} diff --git a/src/sync/watch.rs b/src/sync/watch.rs index 13033d9..ec73832 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -6,13 +6,11 @@ //! //! # Usage //! -//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are -//! the producer and sender halves of the channel. The channel is -//! created with an initial value. [`Receiver::recv`] will always -//! be ready upon creation and will yield either this initial value or -//! the latest value that has been sent by `Sender`. -//! -//! Calls to [`Receiver::recv`] will always yield the latest value. +//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer +//! and sender halves of the channel. The channel is created with an initial +//! value. The **latest** value stored in the channel is accessed with +//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new +//! value to sent by the [`Sender`] half. //! //! # Examples //! @@ -23,21 +21,21 @@ //! let (tx, mut rx) = watch::channel("hello"); //! //! tokio::spawn(async move { -//! while let Some(value) = rx.recv().await { -//! println!("received = {:?}", value); +//! while rx.changed().await.is_ok() { +//! println!("received = {:?}", *rx.borrow()); //! } //! }); //! -//! tx.broadcast("world")?; +//! tx.send("world")?; //! # Ok(()) //! # } //! ``` //! //! # Closing //! -//! [`Sender::closed`] allows the producer to detect when all [`Receiver`] -//! handles have been dropped. This indicates that there is no further interest -//! in the values being produced and work can be stopped. +//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect +//! when all [`Receiver`] handles have been dropped. This indicates that there +//! is no further interest in the values being produced and work can be stopped. //! //! # Thread safety //! @@ -47,20 +45,18 @@ //! //! [`Sender`]: crate::sync::watch::Sender //! [`Receiver`]: crate::sync::watch::Receiver -//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv +//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed +//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow //! [`channel`]: crate::sync::watch::channel +//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed //! [`Sender::closed`]: crate::sync::watch::Sender::closed -use crate::future::poll_fn; -use crate::sync::task::AtomicWaker; +use crate::sync::Notify; -use fnv::FnvHashSet; use std::ops; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; -use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak}; -use std::task::Poll::{Pending, Ready}; -use std::task::{Context, Poll}; +use std::sync::{Arc, RwLock, RwLockReadGuard}; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -70,8 +66,8 @@ pub struct Receiver<T> { /// Pointer to the shared state shared: Arc<Shared<T>>, - /// Pointer to the watcher's internal state - inner: Watcher, + /// Last observed version + version: usize, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -79,7 +75,7 @@ pub struct Receiver<T> { /// Instances are created by the [`channel`](fn@channel) function. #[derive(Debug)] pub struct Sender<T> { - shared: Weak<Shared<T>>, + shared: Arc<Shared<T>>, } /// Returns a reference to the inner value @@ -92,6 +88,27 @@ pub struct Ref<'a, T> { inner: RwLockReadGuard<'a, T>, } +#[derive(Debug)] +struct Shared<T> { + /// The most recent value + value: RwLock<T>, + + /// The current version + /// + /// The lowest bit represents a "closed" state. The rest of the bits + /// represent the current version. + version: AtomicUsize, + + /// Tracks the number of `Receiver` instances + ref_count_rx: AtomicUsize, + + /// Notifies waiting receivers that the value changed. + notify_rx: Notify, + + /// Notifies any task listening for `Receiver` dropped events + notify_tx: Notify, +} + pub mod error { //! Watch error types @@ -112,37 +129,20 @@ pub mod error { } impl<T: fmt::Debug> std::error::Error for SendError<T> {} -} - -#[derive(Debug)] -struct Shared<T> { - /// The most recent value - value: RwLock<T>, - /// The current version - /// - /// The lowest bit represents a "closed" state. The rest of the bits - /// represent the current version. - version: AtomicUsize, - - /// All watchers - watchers: Mutex<Watchers>, - - /// Task to notify when all watchers drop - cancel: AtomicWaker, -} + /// Error produced when receiving a change notification. + #[derive(Debug)] + pub struct RecvError(pub(super) ()); -type Watchers = FnvHashSet<Watcher>; + // ===== impl RecvError ===== -/// The watcher's ID is based on the Arc's pointer. -#[derive(Clone, Debug)] -struct Watcher(Arc<WatchInner>); + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } -#[derive(Debug)] -struct WatchInner { - /// Last observed version - version: AtomicUsize, - waker: AtomicWaker, + impl std::error::Error for RecvError {} } const CLOSED: usize = 1; @@ -162,41 +162,32 @@ const CLOSED: usize = 1; /// let (tx, mut rx) = watch::channel("hello"); /// /// tokio::spawn(async move { -/// while let Some(value) = rx.recv().await { -/// println!("received = {:?}", value); +/// while rx.changed().await.is_ok() { +/// println!("received = {:?}", *rx.borrow()); /// } /// }); /// -/// tx.broadcast("world")?; +/// tx.send("world")?; /// # Ok(()) /// # } /// ``` /// /// [`Sender`]: struct@Sender /// [`Receiver`]: struct@Receiver -pub fn channel<T: Clone>(init: T) -> (Sender<T>, Receiver<T>) { - const VERSION_0: usize = 0; - const VERSION_1: usize = 2; - - // We don't start knowing VERSION_1 - let inner = Watcher::new_version(VERSION_0); - - // Insert the watcher - let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default()); - watchers.insert(inner.clone()); - +pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { let shared = Arc::new(Shared { value: RwLock::new(init), - version: AtomicUsize::new(VERSION_1), - watchers: Mutex::new(watchers), - cancel: AtomicWaker::new(), + version: AtomicUsize::new(0), + ref_count_rx: AtomicUsize::new(1), + notify_rx: Notify::new(), + notify_tx: Notify::new(), }); let tx = Sender { - shared: Arc::downgrade(&shared), + shared: shared.clone(), }; - let rx = Receiver { shared, inner }; + let rx = Receiver { shared, version: 0 }; (tx, rx) } @@ -221,39 +212,13 @@ impl<T> Receiver<T> { Ref { inner } } - // TODO: document - #[doc(hidden)] - pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Option<Ref<'a, T>>> { - // Make sure the task is up to date - self.inner.waker.register_by_ref(cx.waker()); - - let state = self.shared.version.load(SeqCst); - let version = state & !CLOSED; - - if self.inner.version.swap(version, Relaxed) != version { - let inner = self.shared.value.read().unwrap(); - - return Ready(Some(Ref { inner })); - } - - if CLOSED == state & CLOSED { - // The `Store` handle has been dropped. - return Ready(None); - } - - Pending - } -} - -impl<T: Clone> Receiver<T> { - /// Attempts to clone the latest value sent via the channel. + /// Wait for a change notification /// - /// If this is the first time the function is called on a `Receiver` - /// instance, then the function completes immediately with the **current** - /// value held by the channel. On the next call, the function waits until - /// a new value is sent in the channel. + /// Returns when a new value has been sent by the [`Sender`] since the last + /// time `changed()` was called. When the `Sender` half is dropped, `Err` is + /// returned. /// - /// `None` is returned if the `Sender` half is dropped. + /// [`Sender`]: struct@Sender /// /// # Examples /// @@ -264,118 +229,170 @@ impl<T: Clone> Receiver<T> { /// async fn main() { /// let (tx, mut rx) = watch::channel("hello"); /// - /// let v = rx.recv().await.unwrap(); - /// assert_eq!(v, "hello"); - /// /// tokio::spawn(async move { - /// tx.broadcast("goodbye").unwrap(); + /// tx.send("goodbye").unwrap(); /// }); /// - /// // Waits for the new task to spawn and send the value. - /// let v = rx.recv().await.unwrap(); - /// assert_eq!(v, "goodbye"); + /// assert!(rx.changed().await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); /// - /// let v = rx.recv().await; - /// assert!(v.is_none()); + /// // The `tx` handle has been dropped + /// assert!(rx.changed().await.is_err()); /// } /// ``` - pub async fn recv(&mut self) -> Option<T> { - poll_fn(|cx| { - let v_ref = ready!(self.poll_recv_ref(cx)); - Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone())) + pub async fn changed(&mut self) -> Result<(), error::RecvError> { + use std::future::Future; + use std::pin::Pin; + use std::task::Poll; + + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. Requesting the notification + // requires polling the future once. + let notified = self.shared.notify_rx.notified(); + pin!(notified); + + // Polling the future once is guaranteed to return `Pending` as `watch` + // only notifies using `notify_waiters`. + crate::future::poll_fn(|cx| { + let res = Pin::new(&mut notified).poll(cx); + assert!(!res.is_ready()); + Poll::Ready(()) }) - .await + .await; + + if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { + return ret; + } + + notified.await; + + maybe_changed(&self.shared, &mut self.version) + .expect("[bug] failed to observe change after notificaton.") } } -#[cfg(feature = "stream")] -impl<T: Clone> crate::stream::Stream for Receiver<T> { - type Item = T; - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - let v_ref = ready!(self.poll_recv_ref(cx)); +fn maybe_changed<T>( + shared: &Shared<T>, + version: &mut usize, +) -> Option<Result<(), error::RecvError>> { + // Load the version from the state + let state = shared.version.load(SeqCst); + let new_version = state & !CLOSED; + + if *version != new_version { + // Observe the new version and return + *version = new_version; + return Some(Ok(())); + } - Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone())) + if CLOSED == state & CLOSED { + // All receivers have dropped. + return Some(Err(error::RecvError(()))); } + + None } impl<T> Clone for Receiver<T> { fn clone(&self) -> Self { - let ver = self.inner.version.load(Relaxed); - let inner = Watcher::new_version(ver); + let version = self.version; let shared = self.shared.clone(); - shared.watchers.lock().unwrap().insert(inner.clone()); + // No synchronization necessary as this is only used as a counter and + // not memory access. + shared.ref_count_rx.fetch_add(1, Relaxed); - Receiver { shared, inner } + Receiver { version, shared } } } impl<T> Drop for Receiver<T> { fn drop(&mut self) { - self.shared.watchers.lock().unwrap().remove(&self.inner); + // No synchronization necessary as this is only used as a counter and + // not memory access. + if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) { + // This is the last `Receiver` handle, tasks waiting on `Sender::closed()` + self.shared.notify_tx.notify_waiters(); + } } } impl<T> Sender<T> { - /// Broadcasts a new value via the channel, notifying all receivers. - pub fn broadcast(&self, value: T) -> Result<(), error::SendError<T>> { - let shared = match self.shared.upgrade() { - Some(shared) => shared, - // All `Watch` handles have been canceled - None => return Err(error::SendError { inner: value }), - }; - - // Replace the value - { - let mut lock = shared.value.write().unwrap(); - *lock = value; + /// Sends a new value via the channel, notifying all receivers. + pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { + // This is pretty much only useful as a hint anyway, so synchronization isn't critical. + if 0 == self.shared.ref_count_rx.load(Relaxed) { + return Err(error::SendError { inner: value }); } + *self.shared.value.write().unwrap() = value; + // Update the version. 2 is used so that the CLOSED bit is not set. - shared.version.fetch_add(2, SeqCst); + self.shared.version.fetch_add(2, SeqCst); // Notify all watchers - notify_all(&*shared); + self.shared.notify_rx.notify_waiters(); Ok(()) } + /// Checks if the channel has been closed. This happens when all receivers + /// have dropped. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::watch::channel(()); + /// assert!(!tx.is_closed()); + /// + /// drop(rx); + /// assert!(tx.is_closed()); + /// ``` + pub fn is_closed(&self) -> bool { + self.shared.ref_count_rx.load(Relaxed) == 0 + } + /// Completes when all receivers have dropped. /// /// This allows the producer to get notified when interest in the produced /// values is canceled and immediately stop doing work. - pub async fn closed(&mut self) { - poll_fn(|cx| self.poll_close(cx)).await - } + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = watch::channel("hello"); + /// + /// tokio::spawn(async move { + /// // use `rx` + /// drop(rx); + /// }); + /// + /// // Waits for `rx` to drop + /// tx.closed().await; + /// println!("the `rx` handles dropped") + /// } + /// ``` + pub async fn closed(&self) { + let notified = self.shared.notify_tx.notified(); - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self.shared.upgrade() { - Some(shared) => { - shared.cancel.register_by_ref(cx.waker()); - Pending - } - None => Ready(()), + if self.shared.ref_count_rx.load(Relaxed) == 0 { + return; } - } -} - -/// Notifies all watchers of a change -fn notify_all<T>(shared: &Shared<T>) { - let watchers = shared.watchers.lock().unwrap(); - for watcher in watchers.iter() { - // Notify the task - watcher.waker.wake(); + notified.await; + debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed)); } } impl<T> Drop for Sender<T> { fn drop(&mut self) { - if let Some(shared) = self.shared.upgrade() { - shared.version.fetch_or(CLOSED, SeqCst); - notify_all(&*shared); - } + self.shared.version.fetch_or(CLOSED, SeqCst); + self.shared.notify_rx.notify_waiters(); } } @@ -388,44 +405,3 @@ impl<T> ops::Deref for Ref<'_, T> { self.inner.deref() } } - -// ===== impl Shared ===== - -impl<T> Drop for Shared<T> { - fn drop(&mut self) { - self.cancel.wake(); - } -} - -// ===== impl Watcher ===== - -impl Watcher { - fn new_version(version: usize) -> Self { - Watcher(Arc::new(WatchInner { - version: AtomicUsize::new(version), - waker: AtomicWaker::new(), - })) - } -} - -impl std::cmp::PartialEq for Watcher { - fn eq(&self, other: &Watcher) -> bool { - Arc::ptr_eq(&self.0, &other.0) - } -} - -impl std::cmp::Eq for Watcher {} - -impl std::hash::Hash for Watcher { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - (&*self.0 as *const WatchInner).hash(state) - } -} - -impl std::ops::Deref for Watcher { - type Target = WatchInner; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} diff --git a/src/task/blocking.rs b/src/task/blocking.rs index ed60f4c..fc6632b 100644 --- a/src/task/blocking.rs +++ b/src/task/blocking.rs @@ -1,6 +1,6 @@ use crate::task::JoinHandle; -cfg_rt_threaded! { +cfg_rt_multi_thread! { /// Runs the provided blocking function on the current thread without /// blocking the executor. /// @@ -17,7 +17,7 @@ cfg_rt_threaded! { /// using the [`join!`] macro. To avoid this issue, use [`spawn_blocking`] /// instead. /// - /// Note that this function can only be used on the [threaded scheduler]. + /// Note that this function can only be used when using the `multi_thread` runtime. /// /// Code running behind `block_in_place` cannot be cancelled. When you shut /// down the executor, it will wait indefinitely for all blocking operations @@ -27,7 +27,6 @@ cfg_rt_threaded! { /// returns. /// /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code - /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler /// [`spawn_blocking`]: fn@crate::task::spawn_blocking /// [`join!`]: macro@join /// [`thread::spawn`]: fn@std::thread::spawn @@ -44,7 +43,6 @@ cfg_rt_threaded! { /// }); /// # } /// ``` - #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))] pub fn block_in_place<F, R>(f: F) -> R where F: FnOnce() -> R, @@ -53,80 +51,63 @@ cfg_rt_threaded! { } } -cfg_blocking! { - /// Runs the provided closure on a thread where blocking is acceptable. - /// - /// In general, issuing a blocking call or performing a lot of compute in a - /// future without yielding is not okay, as it may prevent the executor from - /// driving other futures forward. This function runs the provided closure - /// on a thread dedicated to blocking operations. See the [CPU-bound tasks - /// and blocking code][blocking] section for more information. - /// - /// Tokio will spawn more blocking threads when they are requested through - /// this function until the upper limit configured on the [`Builder`] is - /// reached. This limit is very large by default, because `spawn_blocking` is - /// often used for various kinds of IO operations that cannot be performed - /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you - /// should keep this large upper limit in mind; to run your CPU-bound - /// computations on only a few threads, you should use a separate thread - /// pool such as [rayon] rather than configuring the number of blocking - /// threads. - /// - /// This function is intended for non-async operations that eventually - /// finish on their own. If you want to spawn an ordinary thread, you should - /// use [`thread::spawn`] instead. - /// - /// Closures spawned using `spawn_blocking` cannot be cancelled. When you - /// shut down the executor, it will wait indefinitely for all blocking - /// operations to finish. You can use [`shutdown_timeout`] to stop waiting - /// for them after a certain timeout. Be aware that this will still not - /// cancel the tasks — they are simply allowed to keep running after the - /// method returns. - /// - /// Note that if you are using the [basic scheduler], this function will - /// still spawn additional threads for blocking operations. The basic - /// scheduler's single thread is only used for asynchronous code. - /// - /// [`Builder`]: struct@crate::runtime::Builder - /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code - /// [rayon]: https://docs.rs/rayon - /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler - /// [`thread::spawn`]: fn@std::thread::spawn - /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout - /// - /// # Examples - /// - /// ``` - /// use tokio::task; - /// - /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ - /// let res = task::spawn_blocking(move || { - /// // do some compute-heavy work or call synchronous code - /// "done computing" - /// }).await?; - /// - /// assert_eq!(res, "done computing"); - /// # Ok(()) - /// # } - /// ``` - pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> - where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, - { - #[cfg(feature = "tracing")] - let f = { - let span = tracing::trace_span!( - target: "tokio::task", - "task", - kind = %"blocking", - function = %std::any::type_name::<F>(), - ); - move || { - let _g = span.enter(); - f() - } - }; - crate::runtime::spawn_blocking(f) - } +/// Runs the provided closure on a thread where blocking is acceptable. +/// +/// In general, issuing a blocking call or performing a lot of compute in a +/// future without yielding is problematic, as it may prevent the executor from +/// driving other futures forward. This function runs the provided closure on a +/// thread dedicated to blocking operations. See the [CPU-bound tasks and +/// blocking code][blocking] section for more information. +/// +/// Tokio will spawn more blocking threads when they are requested through this +/// function until the upper limit configured on the [`Builder`] is reached. +/// This limit is very large by default, because `spawn_blocking` is often used +/// for various kinds of IO operations that cannot be performed asynchronously. +/// When you run CPU-bound code using `spawn_blocking`, you should keep this +/// large upper limit in mind. When running many CPU-bound computations, a +/// semaphore or some other synchronization primitive should be used to limit +/// the number of computation executed in parallel. Specialized CPU-bound +/// executors, such as [rayon], may also be a good fit. +/// +/// This function is intended for non-async operations that eventually finish on +/// their own. If you want to spawn an ordinary thread, you should use +/// [`thread::spawn`] instead. +/// +/// Closures spawned using `spawn_blocking` cannot be cancelled. When you shut +/// down the executor, it will wait indefinitely for all blocking operations to +/// finish. You can use [`shutdown_timeout`] to stop waiting for them after a +/// certain timeout. Be aware that this will still not cancel the tasks — they +/// are simply allowed to keep running after the method returns. +/// +/// Note that if you are using the single threaded runtime, this function will +/// still spawn additional threads for blocking operations. The basic +/// scheduler's single thread is only used for asynchronous code. +/// +/// [`Builder`]: struct@crate::runtime::Builder +/// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code +/// [rayon]: https://docs.rs/rayon +/// [`thread::spawn`]: fn@std::thread::spawn +/// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout +/// +/// # Examples +/// +/// ``` +/// use tokio::task; +/// +/// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ +/// let res = task::spawn_blocking(move || { +/// // do some compute-heavy work or call synchronous code +/// "done computing" +/// }).await?; +/// +/// assert_eq!(res, "done computing"); +/// # Ok(()) +/// # } +/// ``` +pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + crate::runtime::spawn_blocking(f) } diff --git a/src/task/local.rs b/src/task/local.rs index 3c409ed..5896126 100644 --- a/src/task/local.rs +++ b/src/task/local.rs @@ -1,7 +1,7 @@ //! Runs `!Send` futures on the current thread. use crate::runtime::task::{self, JoinHandle, Task}; use crate::sync::AtomicWaker; -use crate::util::linked_list::LinkedList; +use crate::util::linked_list::{Link, LinkedList}; use std::cell::{Cell, RefCell}; use std::collections::VecDeque; @@ -14,7 +14,7 @@ use std::task::Poll; use pin_project_lite::pin_project; -cfg_rt_util! { +cfg_rt! { /// A set of tasks which are executed on the same thread. /// /// In some cases, it is necessary to run one or more futures that do not @@ -95,7 +95,7 @@ cfg_rt_util! { /// }); /// /// local.spawn_local(async move { - /// time::delay_for(time::Duration::from_millis(100)).await; + /// time::sleep(time::Duration::from_millis(100)).await; /// println!("goodbye {}", unsend_data) /// }); /// @@ -132,7 +132,7 @@ struct Context { struct Tasks { /// Collection of all active tasks spawned onto this executor. - owned: LinkedList<Task<Arc<Shared>>>, + owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>, /// Local run queue sender and receiver. queue: VecDeque<task::Notified<Arc<Shared>>>, @@ -158,7 +158,7 @@ pin_project! { scoped_thread_local!(static CURRENT: Context); -cfg_rt_util! { +cfg_rt! { /// Spawns a `!Send` future on the local task set. /// /// The spawned future will be run on the same thread that called `spawn_local.` @@ -312,9 +312,9 @@ impl LocalSet { /// use tokio::runtime::Runtime; /// use tokio::task; /// - /// let mut rt = Runtime::new().unwrap(); + /// let rt = Runtime::new().unwrap(); /// let local = task::LocalSet::new(); - /// local.block_on(&mut rt, async { + /// local.block_on(&rt, async { /// let join = task::spawn_local(async { /// let blocking_result = task::block_in_place(|| { /// // ... @@ -329,9 +329,9 @@ impl LocalSet { /// use tokio::runtime::Runtime; /// use tokio::task; /// - /// let mut rt = Runtime::new().unwrap(); + /// let rt = Runtime::new().unwrap(); /// let local = task::LocalSet::new(); - /// local.block_on(&mut rt, async { + /// local.block_on(&rt, async { /// let join = task::spawn_local(async { /// let blocking_result = task::spawn_blocking(|| { /// // ... @@ -346,7 +346,9 @@ impl LocalSet { /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on /// [in-place blocking]: fn@crate::task::block_in_place /// [`spawn_blocking`]: fn@crate::task::spawn_blocking - pub fn block_on<F>(&self, rt: &mut crate::runtime::Runtime, future: F) -> F::Output + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + pub fn block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output where F: Future, { diff --git a/src/task/mod.rs b/src/task/mod.rs index 5c89393..5dc5e72 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -102,7 +102,7 @@ //! # } //! ``` //! -//! `spawn`, `JoinHandle`, and `JoinError` are present when the "rt-core" +//! `spawn`, `JoinHandle`, and `JoinError` are present when the "rt" //! feature flag is enabled. //! //! [`task::spawn`]: crate::task::spawn() @@ -159,7 +159,7 @@ //! //! #### block_in_place //! -//! When using the [threaded runtime][rt-threaded], the [`task::block_in_place`] +//! When using the [multi-threaded runtime][rt-multi-thread], the [`task::block_in_place`] //! function is also available. Like `task::spawn_blocking`, this function //! allows running a blocking operation from an asynchronous context. Unlike //! `spawn_blocking`, however, `block_in_place` works by transitioning the @@ -211,29 +211,26 @@ //! //! [`task::spawn_blocking`]: crate::task::spawn_blocking //! [`task::block_in_place`]: crate::task::block_in_place -//! [rt-threaded]: ../runtime/index.html#threaded-scheduler +//! [rt-multi-thread]: ../runtime/index.html#threaded-scheduler //! [`task::yield_now`]: crate::task::yield_now() //! [`thread::yield_now`]: std::thread::yield_now -cfg_blocking! { - mod blocking; - pub use blocking::spawn_blocking; - cfg_rt_threaded! { - pub use blocking::block_in_place; - } -} - -cfg_rt_core! { +cfg_rt! { pub use crate::runtime::task::{JoinError, JoinHandle}; + mod blocking; + pub use blocking::spawn_blocking; + mod spawn; pub use spawn::spawn; + cfg_rt_multi_thread! { + pub use blocking::block_in_place; + } + mod yield_now; pub use yield_now::yield_now; -} -cfg_rt_util! { mod local; pub use local::{spawn_local, LocalSet}; diff --git a/src/task/spawn.rs b/src/task/spawn.rs index d6e7711..77acb57 100644 --- a/src/task/spawn.rs +++ b/src/task/spawn.rs @@ -3,7 +3,7 @@ use crate::task::JoinHandle; use std::future::Future; -doc_rt_core! { +cfg_rt! { /// Spawns a new asynchronous task, returning a /// [`JoinHandle`](super::JoinHandle) for it. /// @@ -18,7 +18,7 @@ doc_rt_core! { /// /// This function must be called from the context of a Tokio runtime. Tasks running on /// the Tokio runtime are always inside its context, but you can also enter the context - /// using the [`Handle::enter`](crate::runtime::Handle::enter()) method. + /// using the [`Runtime::enter`](crate::runtime::Runtime::enter()) method. /// /// # Examples /// @@ -37,7 +37,7 @@ doc_rt_core! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; /// /// loop { /// let (socket, _) = listener.accept().await?; diff --git a/src/task/yield_now.rs b/src/task/yield_now.rs index e0e2084..251cb93 100644 --- a/src/task/yield_now.rs +++ b/src/task/yield_now.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -doc_rt_core! { +cfg_rt! { /// Yields execution back to the Tokio runtime. /// /// A task yields by awaiting on `yield_now()`, and may resume when that diff --git a/src/time/clock.rs b/src/time/clock.rs index bd67d7a..fab7eca 100644 --- a/src/time/clock.rs +++ b/src/time/clock.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + //! Source of time abstraction. //! //! By default, `std::time::Instant::now()` is used. However, when the @@ -36,7 +38,18 @@ cfg_not_test_util! { cfg_test_util! { use crate::time::{Duration, Instant}; use std::sync::{Arc, Mutex}; - use crate::runtime::context; + + cfg_rt! { + fn clock() -> Option<Clock> { + crate::runtime::context::clock() + } + } + + cfg_not_rt! { + fn clock() -> Option<Clock> { + None + } + } /// A handle to a source of time. #[derive(Debug, Clone)] @@ -58,14 +71,14 @@ cfg_test_util! { /// The current value of `Instant::now()` is saved and all subsequent calls /// to `Instant::now()` until the timer wheel is checked again will return the saved value. /// Once the timer wheel is checked, time will immediately advance to the next registered - /// `Delay`. This is useful for running tests that depend on time. + /// `Sleep`. This is useful for running tests that depend on time. /// /// # Panics /// /// Panics if time is already frozen or if called from outside of the Tokio /// runtime. pub fn pause() { - let clock = context::clock().expect("time cannot be frozen from outside the Tokio runtime"); + let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.pause(); } @@ -79,7 +92,7 @@ cfg_test_util! { /// Panics if time is not frozen or if called from outside of the Tokio /// runtime. pub fn resume() { - let clock = context::clock().expect("time cannot be frozen from outside the Tokio runtime"); + let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); let mut inner = clock.inner.lock().unwrap(); if inner.unfrozen.is_some() { @@ -99,14 +112,27 @@ cfg_test_util! { /// Panics if time is not frozen or if called from outside of the Tokio /// runtime. pub async fn advance(duration: Duration) { - let clock = context::clock().expect("time cannot be frozen from outside the Tokio runtime"); + use crate::future::poll_fn; + use std::task::Poll; + + let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.advance(duration); - crate::task::yield_now().await; + + let mut yielded = false; + poll_fn(|cx| { + if yielded { + Poll::Ready(()) + } else { + yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + }).await; } /// Return the current instant, factoring in frozen time. pub(crate) fn now() -> Instant { - if let Some(clock) = context::clock() { + if let Some(clock) = clock() { clock.now() } else { Instant::from_std(std::time::Instant::now()) diff --git a/src/time/delay_queue.rs b/src/time/delay_queue.rs deleted file mode 100644 index 55ec7cd..0000000 --- a/src/time/delay_queue.rs +++ /dev/null @@ -1,887 +0,0 @@ -//! A queue of delayed elements. -//! -//! See [`DelayQueue`] for more details. -//! -//! [`DelayQueue`]: struct@DelayQueue - -use crate::time::wheel::{self, Wheel}; -use crate::time::{delay_until, Delay, Duration, Error, Instant}; - -use slab::Slab; -use std::cmp; -use std::future::Future; -use std::marker::PhantomData; -use std::pin::Pin; -use std::task::{self, Poll}; - -/// A queue of delayed elements. -/// -/// Once an element is inserted into the `DelayQueue`, it is yielded once the -/// specified deadline has been reached. -/// -/// # Usage -/// -/// Elements are inserted into `DelayQueue` using the [`insert`] or -/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is -/// returned. The key is used to remove the entry or to change the deadline at -/// which it should be yielded back. -/// -/// Once delays have been configured, the `DelayQueue` is used via its -/// [`Stream`] implementation. [`poll`] is called. If an entry has reached its -/// deadline, it is returned. If not, `Poll::Pending` indicating that the -/// current task will be notified once the deadline has been reached. -/// -/// # `Stream` implementation -/// -/// Items are retrieved from the queue via [`Stream::poll`]. If no delays have -/// expired, no items are returned. In this case, `NotReady` is returned and the -/// current task is registered to be notified once the next item's delay has -/// expired. -/// -/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll` -/// returns `Ready(None)`. This indicates that the stream has reached an end. -/// However, if a new item is inserted *after*, `poll` will once again start -/// returning items or `NotReady. -/// -/// Items are returned ordered by their expirations. Items that are configured -/// to expire first will be returned first. There are no ordering guarantees -/// for items configured to expire the same instant. Also note that delays are -/// rounded to the closest millisecond. -/// -/// # Implementation -/// -/// The [`DelayQueue`] is backed by a separate instance of the same timer wheel used internally by -/// Tokio's standalone timer utilities such as [`delay_for`]. Because of this, it offers the same -/// performance and scalability benefits. -/// -/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, -/// and allows reuse of the memory allocated for expired entires. -/// -/// Capacity can be checked using [`capacity`] and allocated preemptively by using -/// the [`reserve`] method. -/// -/// # Usage -/// -/// Using `DelayQueue` to manage cache entries. -/// -/// ```rust,no_run -/// use tokio::time::{delay_queue, DelayQueue, Error}; -/// -/// use futures::ready; -/// use std::collections::HashMap; -/// use std::task::{Context, Poll}; -/// use std::time::Duration; -/// # type CacheKey = String; -/// # type Value = String; -/// -/// struct Cache { -/// entries: HashMap<CacheKey, (Value, delay_queue::Key)>, -/// expirations: DelayQueue<CacheKey>, -/// } -/// -/// const TTL_SECS: u64 = 30; -/// -/// impl Cache { -/// fn insert(&mut self, key: CacheKey, value: Value) { -/// let delay = self.expirations -/// .insert(key.clone(), Duration::from_secs(TTL_SECS)); -/// -/// self.entries.insert(key, (value, delay)); -/// } -/// -/// fn get(&self, key: &CacheKey) -> Option<&Value> { -/// self.entries.get(key) -/// .map(|&(ref v, _)| v) -/// } -/// -/// fn remove(&mut self, key: &CacheKey) { -/// if let Some((_, cache_key)) = self.entries.remove(key) { -/// self.expirations.remove(&cache_key); -/// } -/// } -/// -/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { -/// while let Some(res) = ready!(self.expirations.poll_expired(cx)) { -/// let entry = res?; -/// self.entries.remove(entry.get_ref()); -/// } -/// -/// Poll::Ready(Ok(())) -/// } -/// } -/// ``` -/// -/// [`insert`]: method@Self::insert -/// [`insert_at`]: method@Self::insert_at -/// [`Key`]: struct@Key -/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html -/// [`poll`]: method@Self::poll -/// [`Stream::poll`]: method@Self::poll -/// [`DelayQueue`]: struct@DelayQueue -/// [`delay_for`]: fn@super::delay_for -/// [`slab`]: slab -/// [`capacity`]: method@Self::capacity -/// [`reserve`]: method@Self::reserve -#[derive(Debug)] -pub struct DelayQueue<T> { - /// Stores data associated with entries - slab: Slab<Data<T>>, - - /// Lookup structure tracking all delays in the queue - wheel: Wheel<Stack<T>>, - - /// Delays that were inserted when already expired. These cannot be stored - /// in the wheel - expired: Stack<T>, - - /// Delay expiring when the *first* item in the queue expires - delay: Option<Delay>, - - /// Wheel polling state - poll: wheel::Poll, - - /// Instant at which the timer starts - start: Instant, -} - -/// An entry in `DelayQueue` that has expired and removed. -/// -/// Values are returned by [`DelayQueue::poll`]. -/// -/// [`DelayQueue::poll`]: method@DelayQueue::poll -#[derive(Debug)] -pub struct Expired<T> { - /// The data stored in the queue - data: T, - - /// The expiration time - deadline: Instant, - - /// The key associated with the entry - key: Key, -} - -/// Token to a value stored in a `DelayQueue`. -/// -/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`] -/// documentation for more details. -/// -/// [`DelayQueue`]: struct@DelayQueue -/// [`DelayQueue::insert`]: method@DelayQueue::insert -#[derive(Debug, Clone)] -pub struct Key { - index: usize, -} - -#[derive(Debug)] -struct Stack<T> { - /// Head of the stack - head: Option<usize>, - _p: PhantomData<fn() -> T>, -} - -#[derive(Debug)] -struct Data<T> { - /// The data being stored in the queue and will be returned at the requested - /// instant. - inner: T, - - /// The instant at which the item is returned. - when: u64, - - /// Set to true when stored in the `expired` queue - expired: bool, - - /// Next entry in the stack - next: Option<usize>, - - /// Previous entry in the stack - prev: Option<usize>, -} - -/// Maximum number of entries the queue can handle -const MAX_ENTRIES: usize = (1 << 30) - 1; - -impl<T> DelayQueue<T> { - /// Creates a new, empty, `DelayQueue` - /// - /// The queue will not allocate storage until items are inserted into it. - /// - /// # Examples - /// - /// ```rust - /// # use tokio::time::DelayQueue; - /// let delay_queue: DelayQueue<u32> = DelayQueue::new(); - /// ``` - pub fn new() -> DelayQueue<T> { - DelayQueue::with_capacity(0) - } - - /// Creates a new, empty, `DelayQueue` with the specified capacity. - /// - /// The queue will be able to hold at least `capacity` elements without - /// reallocating. If `capacity` is 0, the queue will not allocate for - /// storage. - /// - /// # Examples - /// - /// ```rust - /// # use tokio::time::DelayQueue; - /// # use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::with_capacity(10); - /// - /// // These insertions are done without further allocation - /// for i in 0..10 { - /// delay_queue.insert(i, Duration::from_secs(i)); - /// } - /// - /// // This will make the queue allocate additional storage - /// delay_queue.insert(11, Duration::from_secs(11)); - /// # } - /// ``` - pub fn with_capacity(capacity: usize) -> DelayQueue<T> { - DelayQueue { - wheel: Wheel::new(), - slab: Slab::with_capacity(capacity), - expired: Stack::default(), - delay: None, - poll: wheel::Poll::new(0), - start: Instant::now(), - } - } - - /// Inserts `value` into the queue set to expire at a specific instant in - /// time. - /// - /// This function is identical to `insert`, but takes an `Instant` instead - /// of a `Duration`. - /// - /// `value` is stored in the queue until `when` is reached. At which point, - /// `value` will be returned from [`poll`]. If `when` has already been - /// reached, then `value` is immediately made available to poll. - /// - /// The return value represents the insertion and is used at an argument to - /// [`remove`] and [`reset`]. Note that [`Key`] is token and is reused once - /// `value` is removed from the queue either by calling [`poll`] after - /// `when` is reached or by calling [`remove`]. At this point, the caller - /// must take care to not use the returned [`Key`] again as it may reference - /// a different item in the queue. - /// - /// See [type] level documentation for more details. - /// - /// # Panics - /// - /// This function panics if `when` is too far in the future. - /// - /// # Examples - /// - /// Basic usage - /// - /// ```rust - /// use tokio::time::{DelayQueue, Duration, Instant}; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// let key = delay_queue.insert_at( - /// "foo", Instant::now() + Duration::from_secs(5)); - /// - /// // Remove the entry - /// let item = delay_queue.remove(&key); - /// assert_eq!(*item.get_ref(), "foo"); - /// # } - /// ``` - /// - /// [`poll`]: method@Self::poll - /// [`remove`]: method@Self::remove - /// [`reset`]: method@Self::reset - /// [`Key`]: struct@Key - /// [type]: # - pub fn insert_at(&mut self, value: T, when: Instant) -> Key { - assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); - - // Normalize the deadline. Values cannot be set to expire in the past. - let when = self.normalize_deadline(when); - - // Insert the value in the store - let key = self.slab.insert(Data { - inner: value, - when, - expired: false, - next: None, - prev: None, - }); - - self.insert_idx(when, key); - - // Set a new delay if the current's deadline is later than the one of the new item - let should_set_delay = if let Some(ref delay) = self.delay { - let current_exp = self.normalize_deadline(delay.deadline()); - current_exp > when - } else { - true - }; - - if should_set_delay { - let delay_time = self.start + Duration::from_millis(when); - if let Some(ref mut delay) = &mut self.delay { - delay.reset(delay_time); - } else { - self.delay = Some(delay_until(delay_time)); - } - } - - Key::new(key) - } - - /// Attempts to pull out the next value of the delay queue, registering the - /// current task for wakeup if the value is not yet available, and returning - /// None if the queue is exhausted. - pub fn poll_expired( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll<Option<Result<Expired<T>, Error>>> { - let item = ready!(self.poll_idx(cx)); - Poll::Ready(item.map(|result| { - result.map(|idx| { - let data = self.slab.remove(idx); - debug_assert!(data.next.is_none()); - debug_assert!(data.prev.is_none()); - - Expired { - key: Key::new(idx), - data: data.inner, - deadline: self.start + Duration::from_millis(data.when), - } - }) - })) - } - - /// Inserts `value` into the queue set to expire after the requested duration - /// elapses. - /// - /// This function is identical to `insert_at`, but takes a `Duration` - /// instead of an `Instant`. - /// - /// `value` is stored in the queue until `when` is reached. At which point, - /// `value` will be returned from [`poll`]. If `when` has already been - /// reached, then `value` is immediately made available to poll. - /// - /// The return value represents the insertion and is used at an argument to - /// [`remove`] and [`reset`]. Note that [`Key`] is token and is reused once - /// `value` is removed from the queue either by calling [`poll`] after - /// `when` is reached or by calling [`remove`]. At this point, the caller - /// must take care to not use the returned [`Key`] again as it may reference - /// a different item in the queue. - /// - /// See [type] level documentation for more details. - /// - /// # Panics - /// - /// This function panics if `timeout` is greater than the maximum supported - /// duration. - /// - /// # Examples - /// - /// Basic usage - /// - /// ```rust - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// let key = delay_queue.insert("foo", Duration::from_secs(5)); - /// - /// // Remove the entry - /// let item = delay_queue.remove(&key); - /// assert_eq!(*item.get_ref(), "foo"); - /// # } - /// ``` - /// - /// [`poll`]: method@Self::poll - /// [`remove`]: method@Self::remove - /// [`reset`]: method@Self::reset - /// [`Key`]: struct@Key - /// [type]: # - pub fn insert(&mut self, value: T, timeout: Duration) -> Key { - self.insert_at(value, Instant::now() + timeout) - } - - fn insert_idx(&mut self, when: u64, key: usize) { - use self::wheel::{InsertError, Stack}; - - // Register the deadline with the timer wheel - match self.wheel.insert(when, key, &mut self.slab) { - Ok(_) => {} - Err((_, InsertError::Elapsed)) => { - self.slab[key].expired = true; - // The delay is already expired, store it in the expired queue - self.expired.push(key, &mut self.slab); - } - Err((_, err)) => panic!("invalid deadline; err={:?}", err), - } - } - - /// Removes the item associated with `key` from the queue. - /// - /// There must be an item associated with `key`. The function returns the - /// removed item as well as the `Instant` at which it will the delay will - /// have expired. - /// - /// # Panics - /// - /// The function panics if `key` is not contained by the queue. - /// - /// # Examples - /// - /// Basic usage - /// - /// ```rust - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// let key = delay_queue.insert("foo", Duration::from_secs(5)); - /// - /// // Remove the entry - /// let item = delay_queue.remove(&key); - /// assert_eq!(*item.get_ref(), "foo"); - /// # } - /// ``` - pub fn remove(&mut self, key: &Key) -> Expired<T> { - use crate::time::wheel::Stack; - - // Special case the `expired` queue - if self.slab[key.index].expired { - self.expired.remove(&key.index, &mut self.slab); - } else { - self.wheel.remove(&key.index, &mut self.slab); - } - - let data = self.slab.remove(key.index); - - Expired { - key: Key::new(key.index), - data: data.inner, - deadline: self.start + Duration::from_millis(data.when), - } - } - - /// Sets the delay of the item associated with `key` to expire at `when`. - /// - /// This function is identical to `reset` but takes an `Instant` instead of - /// a `Duration`. - /// - /// The item remains in the queue but the delay is set to expire at `when`. - /// If `when` is in the past, then the item is immediately made available to - /// the caller. - /// - /// # Panics - /// - /// This function panics if `when` is too far in the future or if `key` is - /// not contained by the queue. - /// - /// # Examples - /// - /// Basic usage - /// - /// ```rust - /// use tokio::time::{DelayQueue, Duration, Instant}; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// let key = delay_queue.insert("foo", Duration::from_secs(5)); - /// - /// // "foo" is scheduled to be returned in 5 seconds - /// - /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10)); - /// - /// // "foo"is now scheduled to be returned in 10 seconds - /// # } - /// ``` - pub fn reset_at(&mut self, key: &Key, when: Instant) { - self.wheel.remove(&key.index, &mut self.slab); - - // Normalize the deadline. Values cannot be set to expire in the past. - let when = self.normalize_deadline(when); - - self.slab[key.index].when = when; - self.insert_idx(when, key.index); - - let next_deadline = self.next_deadline(); - if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) { - delay.reset(deadline); - } - } - - /// Returns the next time poll as determined by the wheel - fn next_deadline(&mut self) -> Option<Instant> { - self.wheel - .poll_at() - .map(|poll_at| self.start + Duration::from_millis(poll_at)) - } - - /// Sets the delay of the item associated with `key` to expire after - /// `timeout`. - /// - /// This function is identical to `reset_at` but takes a `Duration` instead - /// of an `Instant`. - /// - /// The item remains in the queue but the delay is set to expire after - /// `timeout`. If `timeout` is zero, then the item is immediately made - /// available to the caller. - /// - /// # Panics - /// - /// This function panics if `timeout` is greater than the maximum supported - /// duration or if `key` is not contained by the queue. - /// - /// # Examples - /// - /// Basic usage - /// - /// ```rust - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// let key = delay_queue.insert("foo", Duration::from_secs(5)); - /// - /// // "foo" is scheduled to be returned in 5 seconds - /// - /// delay_queue.reset(&key, Duration::from_secs(10)); - /// - /// // "foo"is now scheduled to be returned in 10 seconds - /// # } - /// ``` - pub fn reset(&mut self, key: &Key, timeout: Duration) { - self.reset_at(key, Instant::now() + timeout); - } - - /// Clears the queue, removing all items. - /// - /// After calling `clear`, [`poll`] will return `Ok(Ready(None))`. - /// - /// Note that this method has no effect on the allocated capacity. - /// - /// [`poll`]: method@Self::poll - /// - /// # Examples - /// - /// ```rust - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// - /// delay_queue.insert("foo", Duration::from_secs(5)); - /// - /// assert!(!delay_queue.is_empty()); - /// - /// delay_queue.clear(); - /// - /// assert!(delay_queue.is_empty()); - /// # } - /// ``` - pub fn clear(&mut self) { - self.slab.clear(); - self.expired = Stack::default(); - self.wheel = Wheel::new(); - self.delay = None; - } - - /// Returns the number of elements the queue can hold without reallocating. - /// - /// # Examples - /// - /// ```rust - /// use tokio::time::DelayQueue; - /// - /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10); - /// assert_eq!(delay_queue.capacity(), 10); - /// ``` - pub fn capacity(&self) -> usize { - self.slab.capacity() - } - - /// Returns the number of elements currently in the queue. - /// - /// # Examples - /// - /// ```rust - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10); - /// assert_eq!(delay_queue.len(), 0); - /// delay_queue.insert(3, Duration::from_secs(5)); - /// assert_eq!(delay_queue.len(), 1); - /// # } - /// ``` - pub fn len(&self) -> usize { - self.slab.len() - } - - /// Reserves capacity for at least `additional` more items to be queued - /// without allocating. - /// - /// `reserve` does nothing if the queue already has sufficient capacity for - /// `additional` more values. If more capacity is required, a new segment of - /// memory will be allocated and all existing values will be copied into it. - /// As such, if the queue is already very large, a call to `reserve` can end - /// up being expensive. - /// - /// The queue may reserve more than `additional` extra space in order to - /// avoid frequent reallocations. - /// - /// # Panics - /// - /// Panics if the new capacity exceeds the maximum number of entries the - /// queue can contain. - /// - /// # Examples - /// - /// ``` - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// - /// delay_queue.insert("hello", Duration::from_secs(10)); - /// delay_queue.reserve(10); - /// - /// assert!(delay_queue.capacity() >= 11); - /// # } - /// ``` - pub fn reserve(&mut self, additional: usize) { - self.slab.reserve(additional); - } - - /// Returns `true` if there are no items in the queue. - /// - /// Note that this function returns `false` even if all items have not yet - /// expired and a call to `poll` will return `NotReady`. - /// - /// # Examples - /// - /// ``` - /// use tokio::time::DelayQueue; - /// use std::time::Duration; - /// - /// # #[tokio::main] - /// # async fn main() { - /// let mut delay_queue = DelayQueue::new(); - /// assert!(delay_queue.is_empty()); - /// - /// delay_queue.insert("hello", Duration::from_secs(5)); - /// assert!(!delay_queue.is_empty()); - /// # } - /// ``` - pub fn is_empty(&self) -> bool { - self.slab.is_empty() - } - - /// Polls the queue, returning the index of the next slot in the slab that - /// should be returned. - /// - /// A slot should be returned when the associated deadline has been reached. - fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Result<usize, Error>>> { - use self::wheel::Stack; - - let expired = self.expired.pop(&mut self.slab); - - if expired.is_some() { - return Poll::Ready(expired.map(Ok)); - } - - loop { - if let Some(ref mut delay) = self.delay { - if !delay.is_elapsed() { - ready!(Pin::new(&mut *delay).poll(cx)); - } - - let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down); - - self.poll = wheel::Poll::new(now); - } - - // We poll the wheel to get the next value out before finding the next deadline. - let wheel_idx = self.wheel.poll(&mut self.poll, &mut self.slab); - - self.delay = self.next_deadline().map(delay_until); - - if let Some(idx) = wheel_idx { - return Poll::Ready(Some(Ok(idx))); - } - - if self.delay.is_none() { - return Poll::Ready(None); - } - } - } - - fn normalize_deadline(&self, when: Instant) -> u64 { - let when = if when < self.start { - 0 - } else { - crate::time::ms(when - self.start, crate::time::Round::Up) - }; - - cmp::max(when, self.wheel.elapsed()) - } -} - -// We never put `T` in a `Pin`... -impl<T> Unpin for DelayQueue<T> {} - -impl<T> Default for DelayQueue<T> { - fn default() -> DelayQueue<T> { - DelayQueue::new() - } -} - -#[cfg(feature = "stream")] -impl<T> futures_core::Stream for DelayQueue<T> { - // DelayQueue seems much more specific, where a user may care that it - // has reached capacity, so return those errors instead of panicking. - type Item = Result<Expired<T>, Error>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { - DelayQueue::poll_expired(self.get_mut(), cx) - } -} - -impl<T> wheel::Stack for Stack<T> { - type Owned = usize; - type Borrowed = usize; - type Store = Slab<Data<T>>; - - fn is_empty(&self) -> bool { - self.head.is_none() - } - - fn push(&mut self, item: Self::Owned, store: &mut Self::Store) { - // Ensure the entry is not already in a stack. - debug_assert!(store[item].next.is_none()); - debug_assert!(store[item].prev.is_none()); - - // Remove the old head entry - let old = self.head.take(); - - if let Some(idx) = old { - store[idx].prev = Some(item); - } - - store[item].next = old; - self.head = Some(item) - } - - fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> { - if let Some(idx) = self.head { - self.head = store[idx].next; - - if let Some(idx) = self.head { - store[idx].prev = None; - } - - store[idx].next = None; - debug_assert!(store[idx].prev.is_none()); - - Some(idx) - } else { - None - } - } - - fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) { - assert!(store.contains(*item)); - - // Ensure that the entry is in fact contained by the stack - debug_assert!({ - // This walks the full linked list even if an entry is found. - let mut next = self.head; - let mut contains = false; - - while let Some(idx) = next { - if idx == *item { - debug_assert!(!contains); - contains = true; - } - - next = store[idx].next; - } - - contains - }); - - if let Some(next) = store[*item].next { - store[next].prev = store[*item].prev; - } - - if let Some(prev) = store[*item].prev { - store[prev].next = store[*item].next; - } else { - self.head = store[*item].next; - } - - store[*item].next = None; - store[*item].prev = None; - } - - fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 { - store[*item].when - } -} - -impl<T> Default for Stack<T> { - fn default() -> Stack<T> { - Stack { - head: None, - _p: PhantomData, - } - } -} - -impl Key { - pub(crate) fn new(index: usize) -> Key { - Key { index } - } -} - -impl<T> Expired<T> { - /// Returns a reference to the inner value. - pub fn get_ref(&self) -> &T { - &self.data - } - - /// Returns a mutable reference to the inner value. - pub fn get_mut(&mut self) -> &mut T { - &mut self.data - } - - /// Consumes `self` and returns the inner value. - pub fn into_inner(self) -> T { - self.data - } - - /// Returns the deadline that the expiration was set to. - pub fn deadline(&self) -> Instant { - self.deadline - } -} diff --git a/src/time/driver/atomic_stack.rs b/src/time/driver/atomic_stack.rs index 7e5a83f..5dcc472 100644 --- a/src/time/driver/atomic_stack.rs +++ b/src/time/driver/atomic_stack.rs @@ -1,5 +1,5 @@ use crate::time::driver::Entry; -use crate::time::Error; +use crate::time::error::Error; use std::ptr; use std::sync::atomic::AtomicPtr; @@ -95,7 +95,7 @@ impl Iterator for AtomicStackEntries { type Item = Arc<Entry>; fn next(&mut self) -> Option<Self::Item> { - if self.ptr.is_null() { + if self.ptr.is_null() || self.ptr == SHUTDOWN { return None; } @@ -118,7 +118,7 @@ impl Drop for AtomicStackEntries { fn drop(&mut self) { for entry in self { // Flag the entry as errored - entry.error(); + entry.error(Error::shutdown()); } } } diff --git a/src/time/driver/entry.rs b/src/time/driver/entry.rs index b375ee9..b40cae7 100644 --- a/src/time/driver/entry.rs +++ b/src/time/driver/entry.rs @@ -1,17 +1,17 @@ use crate::loom::sync::atomic::AtomicU64; use crate::sync::AtomicWaker; use crate::time::driver::{Handle, Inner}; -use crate::time::{Duration, Error, Instant}; +use crate::time::{error::Error, Duration, Instant}; use std::cell::UnsafeCell; use std::ptr; -use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::SeqCst; +use std::sync::atomic::{AtomicBool, AtomicU8}; use std::sync::{Arc, Weak}; use std::task::{self, Poll}; use std::u64; -/// Internal state shared between a `Delay` instance and the timer. +/// Internal state shared between a `Sleep` instance and the timer. /// /// This struct is used as a node in two intrusive data structures: /// @@ -28,7 +28,7 @@ pub(crate) struct Entry { time: CachePadded<UnsafeCell<Time>>, /// Timer internals. Using a weak pointer allows the timer to shutdown - /// without all `Delay` instances having completed. + /// without all `Sleep` instances having completed. /// /// When empty, it means that the entry has not yet been linked with a /// timer instance. @@ -45,6 +45,11 @@ pub(crate) struct Entry { /// instant, this value is changed. state: AtomicU64, + /// Stores the actual error. If `state` indicates that an error occurred, + /// this is guaranteed to be a non-zero value representing the first error + /// that occurred. Otherwise its value is undefined. + error: AtomicU8, + /// Task to notify once the deadline is reached. waker: AtomicWaker, @@ -64,8 +69,8 @@ pub(crate) struct Entry { /// When the entry expires, relative to the `start` of the timer /// (Inner::start). This is only used by the timer. /// - /// A `Delay` instance can be reset to a different deadline by the thread - /// that owns the `Delay` instance. In this case, the timer thread will not + /// A `Sleep` instance can be reset to a different deadline by the thread + /// that owns the `Sleep` instance. In this case, the timer thread will not /// immediately know that this has happened. The timer thread must know the /// last deadline that it saw as it uses this value to locate the entry in /// its wheel. @@ -78,7 +83,7 @@ pub(crate) struct Entry { /// Next entry in the State's linked list. /// /// This is only accessed by the timer - pub(super) next_stack: UnsafeCell<Option<Arc<Entry>>>, + pub(crate) next_stack: UnsafeCell<Option<Arc<Entry>>>, /// Previous entry in the State's linked list. /// @@ -86,10 +91,10 @@ pub(crate) struct Entry { /// entry. /// /// This is a weak reference. - pub(super) prev_stack: UnsafeCell<*const Entry>, + pub(crate) prev_stack: UnsafeCell<*const Entry>, } -/// Stores the info for `Delay`. +/// Stores the info for `Sleep`. #[derive(Debug)] pub(crate) struct Time { pub(crate) deadline: Instant, @@ -107,11 +112,12 @@ const ERROR: u64 = u64::MAX; impl Entry { pub(crate) fn new(handle: &Handle, deadline: Instant, duration: Duration) -> Arc<Entry> { let inner = handle.inner().unwrap(); - let entry: Entry; - // Increment the number of active timeouts - if inner.increment().is_err() { - entry = Entry::new2(deadline, duration, Weak::new(), ERROR) + // Attempt to increment the number of active timeouts + let entry = if let Err(err) = inner.increment() { + let entry = Entry::new2(deadline, duration, Weak::new(), ERROR); + entry.error(err); + entry } else { let when = inner.normalize_deadline(deadline); let state = if when <= inner.elapsed() { @@ -119,12 +125,12 @@ impl Entry { } else { when }; - entry = Entry::new2(deadline, duration, Arc::downgrade(&inner), state); - } + Entry::new2(deadline, duration, Arc::downgrade(&inner), state) + }; let entry = Arc::new(entry); - if inner.queue(&entry).is_err() { - entry.error(); + if let Err(err) = inner.queue(&entry) { + entry.error(err); } entry @@ -141,6 +147,10 @@ impl Entry { &mut *self.time.0.get() } + pub(crate) fn when(&self) -> u64 { + self.when_internal().expect("invalid internal state") + } + /// The current entry state as known by the timer. This is not the value of /// `state`, but lets the timer know how to converge its state to `state`. pub(crate) fn when_internal(&self) -> Option<u64> { @@ -190,7 +200,12 @@ impl Entry { self.waker.wake(); } - pub(crate) fn error(&self) { + pub(crate) fn error(&self, error: Error) { + // Record the precise nature of the error, if there isn't already an + // error present. If we don't actually transition to the error state + // below, that's fine, as the error details we set here will be ignored. + self.error.compare_and_swap(0, error.as_u8(), SeqCst); + // Only transition to the error state if not currently elapsed let mut curr = self.state.load(SeqCst); @@ -235,7 +250,7 @@ impl Entry { if is_elapsed(curr) { return Poll::Ready(if curr == ERROR { - Err(Error::shutdown()) + Err(Error::from_u8(self.error.load(SeqCst))) } else { Ok(()) }); @@ -247,7 +262,7 @@ impl Entry { if is_elapsed(curr) { return Poll::Ready(if curr == ERROR { - Err(Error::shutdown()) + Err(Error::from_u8(self.error.load(SeqCst))) } else { Ok(()) }); @@ -310,6 +325,7 @@ impl Entry { waker: AtomicWaker::new(), state: AtomicU64::new(state), queued: AtomicBool::new(false), + error: AtomicU8::new(0), next_atomic: UnsafeCell::new(ptr::null_mut()), when: UnsafeCell::new(None), next_stack: UnsafeCell::new(None), diff --git a/src/time/driver/handle.rs b/src/time/driver/handle.rs index 38b1761..54b8a8b 100644 --- a/src/time/driver/handle.rs +++ b/src/time/driver/handle.rs @@ -1,4 +1,3 @@ -use crate::runtime::context; use crate::time::driver::Inner; use std::fmt; use std::sync::{Arc, Weak}; @@ -15,22 +14,62 @@ impl Handle { Handle { inner } } - /// Tries to get a handle to the current timer. - /// - /// # Panics - /// - /// This function panics if there is no current timer set. - pub(crate) fn current() -> Self { - context::time_handle() - .expect("there is no timer running, must be called from the context of Tokio runtime") - } - /// Tries to return a strong ref to the inner pub(crate) fn inner(&self) -> Option<Arc<Inner>> { self.inner.upgrade() } } +cfg_rt! { + impl Handle { + /// Tries to get a handle to the current timer. + /// + /// # Panics + /// + /// This function panics if there is no current timer set. + /// + /// It can be triggered when `Builder::enable_time()` or + /// `Builder::enable_all()` are not included in the builder. + /// + /// It can also panic whenever a timer is created outside of a Tokio + /// runtime. That is why `rt.block_on(delay_for(...))` will panic, + /// since the function is executed outside of the runtime. + /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't + /// panic. And this is because wrapping the function on an async makes it + /// lazy, and so gets executed inside the runtime successfuly without + /// panicking. + pub(crate) fn current() -> Self { + crate::runtime::context::time_handle() + .expect("there is no timer running, must be called from the context of Tokio runtime") + } + } +} + +cfg_not_rt! { + impl Handle { + /// Tries to get a handle to the current timer. + /// + /// # Panics + /// + /// This function panics if there is no current timer set. + /// + /// It can be triggered when `Builder::enable_time()` or + /// `Builder::enable_all()` are not included in the builder. + /// + /// It can also panic whenever a timer is created outside of a Tokio + /// runtime. That is why `rt.block_on(delay_for(...))` will panic, + /// since the function is executed outside of the runtime. + /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't + /// panic. And this is because wrapping the function on an async makes it + /// lazy, and so gets executed inside the runtime successfuly without + /// panicking. + pub(crate) fn current() -> Self { + panic!("there is no timer running, must be called from the context of Tokio runtime or \ + `rt` is not enabled") + } + } +} + impl fmt::Debug for Handle { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Handle") diff --git a/src/time/driver/mod.rs b/src/time/driver/mod.rs index 554042f..8532c55 100644 --- a/src/time/driver/mod.rs +++ b/src/time/driver/mod.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + //! Time driver mod atomic_stack; @@ -9,15 +11,9 @@ pub(super) use self::entry::Entry; mod handle; pub(crate) use self::handle::Handle; -mod registration; -pub(crate) use self::registration::Registration; - -mod stack; -use self::stack::Stack; - use crate::loom::sync::atomic::{AtomicU64, AtomicUsize}; use crate::park::{Park, Unpark}; -use crate::time::{wheel, Error}; +use crate::time::{error::Error, wheel}; use crate::time::{Clock, Duration, Instant}; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst}; @@ -26,12 +22,12 @@ use std::sync::Arc; use std::usize; use std::{cmp, fmt}; -/// Time implementation that drives [`Delay`][delay], [`Interval`][interval], and [`Timeout`][timeout]. +/// Time implementation that drives [`Sleep`][sleep], [`Interval`][interval], and [`Timeout`][timeout]. /// /// A `Driver` instance tracks the state necessary for managing time and -/// notifying the [`Delay`][delay] instances once their deadlines are reached. +/// notifying the [`Sleep`][sleep] instances once their deadlines are reached. /// -/// It is expected that a single instance manages many individual [`Delay`][delay] +/// It is expected that a single instance manages many individual [`Sleep`][sleep] /// instances. The `Driver` implementation is thread-safe and, as such, is able /// to handle callers from across threads. /// @@ -42,9 +38,9 @@ use std::{cmp, fmt}; /// The driver has a resolution of one millisecond. Any unit of time that falls /// between milliseconds are rounded up to the next millisecond. /// -/// When an instance is dropped, any outstanding [`Delay`][delay] instance that has not +/// When an instance is dropped, any outstanding [`Sleep`][sleep] instance that has not /// elapsed will be notified with an error. At this point, calling `poll` on the -/// [`Delay`][delay] instance will result in panic. +/// [`Sleep`][sleep] instance will result in panic. /// /// # Implementation /// @@ -71,29 +67,32 @@ use std::{cmp, fmt}; /// * Level 5: 64 x ~12 day slots. /// /// When the timer processes entries at level zero, it will notify all the -/// `Delay` instances as their deadlines have been reached. For all higher +/// `Sleep` instances as their deadlines have been reached. For all higher /// levels, all entries will be redistributed across the wheel at the next level -/// down. Eventually, as time progresses, entries will [`Delay`][delay] instances will +/// down. Eventually, as time progresses, entries with [`Sleep`][sleep] instances will /// either be canceled (dropped) or their associated entries will reach level /// zero and be notified. /// /// [paper]: http://www.cs.columbia.edu/~nahum/w6998/papers/ton97-timing-wheels.pdf -/// [delay]: crate::time::Delay +/// [sleep]: crate::time::Sleep /// [timeout]: crate::time::Timeout /// [interval]: crate::time::Interval #[derive(Debug)] -pub(crate) struct Driver<T> { +pub(crate) struct Driver<T: Park> { /// Shared state inner: Arc<Inner>, /// Timer wheel - wheel: wheel::Wheel<Stack>, + wheel: wheel::Wheel, /// Thread parker. The `Driver` park implementation delegates to this. park: T, /// Source of "now" instances clock: Clock, + + /// True if the driver is being shutdown + is_shutdown: bool, } /// Timer state shared between `Driver`, `Handle`, and `Registration`. @@ -135,12 +134,13 @@ where wheel: wheel::Wheel::new(), park, clock, + is_shutdown: false, } } /// Returns a handle to the timer. /// - /// The `Handle` is how `Delay` instances are created. The `Delay` instances + /// The `Handle` is how `Sleep` instances are created. The `Sleep` instances /// can either be created directly or the `Handle` instance can be passed to /// `with_default`, setting the timer as the default timer for the execution /// context. @@ -159,9 +159,8 @@ where self.clock.now() - self.inner.start, crate::time::Round::Down, ); - let mut poll = wheel::Poll::new(now); - while let Some(entry) = self.wheel.poll(&mut poll, &mut ()) { + while let Some(entry) = self.wheel.poll(now) { let when = entry.when_internal().expect("invalid internal entry state"); // Fire the entry @@ -189,7 +188,7 @@ where self.clear_entry(&entry); } (None, Some(when)) => { - // Queue the entry + // Add the entry to the timer wheel self.add_entry(entry, when); } (Some(_), Some(next)) => { @@ -201,19 +200,17 @@ where } fn clear_entry(&mut self, entry: &Arc<Entry>) { - self.wheel.remove(entry, &mut ()); + self.wheel.remove(entry); entry.set_when_internal(None); } /// Fires the entry if it needs to, otherwise queue it to be processed later. - /// - /// Returns `None` if the entry was fired. fn add_entry(&mut self, entry: Arc<Entry>, when: u64) { - use crate::time::wheel::InsertError; + use crate::time::error::InsertError; entry.set_when_internal(Some(when)); - match self.wheel.insert(when, entry, &mut ()) { + match self.wheel.insert(when, entry) { Ok(_) => {} Err((entry, InsertError::Elapsed)) => { // The entry's deadline has elapsed, so fire it and update the @@ -225,7 +222,7 @@ where // The entry's deadline is invalid, so error it and update the // internal state accordingly. entry.set_when_internal(None); - entry.error(); + entry.error(Error::invalid()); } } } @@ -303,10 +300,12 @@ where Ok(()) } -} -impl<T> Drop for Driver<T> { - fn drop(&mut self) { + fn shutdown(&mut self) { + if self.is_shutdown { + return; + } + use std::u64; // Shutdown the stack of entries to process, preventing any new entries @@ -314,11 +313,24 @@ impl<T> Drop for Driver<T> { self.inner.process.shutdown(); // Clear the wheel, using u64::MAX allows us to drain everything - let mut poll = wheel::Poll::new(u64::MAX); + let end_of_time = u64::MAX; - while let Some(entry) = self.wheel.poll(&mut poll, &mut ()) { - entry.error(); + while let Some(entry) = self.wheel.poll(end_of_time) { + entry.error(Error::shutdown()); } + + self.park.shutdown(); + + self.is_shutdown = true; + } +} + +impl<T> Drop for Driver<T> +where + T: Park, +{ + fn drop(&mut self) { + self.shutdown(); } } @@ -368,6 +380,10 @@ impl Inner { debug_assert!(prev <= MAX_TIMEOUTS); } + /// add the entry to the "process queue". entries are not immediately + /// pushed into the timer wheel but are instead pushed into the + /// process queue and then moved from the process queue into the timer + /// wheel on next `process` fn queue(&self, entry: &Arc<Entry>) -> Result<(), Error> { if self.process.push(entry)? { // The timer is notified so that it can process the timeout diff --git a/src/time/driver/registration.rs b/src/time/driver/registration.rs deleted file mode 100644 index 3a0b345..0000000 --- a/src/time/driver/registration.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::time::driver::{Entry, Handle}; -use crate::time::{Duration, Error, Instant}; - -use std::sync::Arc; -use std::task::{self, Poll}; - -/// Registration with a timer. -/// -/// The association between a `Delay` instance and a timer is done lazily in -/// `poll` -#[derive(Debug)] -pub(crate) struct Registration { - entry: Arc<Entry>, -} - -impl Registration { - pub(crate) fn new(deadline: Instant, duration: Duration) -> Registration { - let handle = Handle::current(); - - Registration { - entry: Entry::new(&handle, deadline, duration), - } - } - - pub(crate) fn deadline(&self) -> Instant { - self.entry.time_ref().deadline - } - - pub(crate) fn reset(&mut self, deadline: Instant) { - unsafe { - self.entry.time_mut().deadline = deadline; - } - - Entry::reset(&mut self.entry); - } - - pub(crate) fn is_elapsed(&self) -> bool { - self.entry.is_elapsed() - } - - pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - self.entry.poll_elapsed(cx).map(move |r| { - coop.made_progress(); - r - }) - } -} - -impl Drop for Registration { - fn drop(&mut self) { - Entry::cancel(&self.entry); - } -} diff --git a/src/time/driver/stack.rs b/src/time/driver/stack.rs deleted file mode 100644 index 3e2924f..0000000 --- a/src/time/driver/stack.rs +++ /dev/null @@ -1,121 +0,0 @@ -use crate::time::driver::Entry; -use crate::time::wheel; - -use std::ptr; -use std::sync::Arc; - -/// A doubly linked stack -#[derive(Debug)] -pub(crate) struct Stack { - head: Option<Arc<Entry>>, -} - -impl Default for Stack { - fn default() -> Stack { - Stack { head: None } - } -} - -impl wheel::Stack for Stack { - type Owned = Arc<Entry>; - type Borrowed = Entry; - type Store = (); - - fn is_empty(&self) -> bool { - self.head.is_none() - } - - fn push(&mut self, entry: Self::Owned, _: &mut Self::Store) { - // Get a pointer to the entry to for the prev link - let ptr: *const Entry = &*entry as *const _; - - // Remove the old head entry - let old = self.head.take(); - - unsafe { - // Ensure the entry is not already in a stack. - debug_assert!((*entry.next_stack.get()).is_none()); - debug_assert!((*entry.prev_stack.get()).is_null()); - - if let Some(ref entry) = old.as_ref() { - debug_assert!({ - // The head is not already set to the entry - ptr != &***entry as *const _ - }); - - // Set the previous link on the old head - *entry.prev_stack.get() = ptr; - } - - // Set this entry's next pointer - *entry.next_stack.get() = old; - } - - // Update the head pointer - self.head = Some(entry); - } - - /// Pops an item from the stack - fn pop(&mut self, _: &mut ()) -> Option<Arc<Entry>> { - let entry = self.head.take(); - - unsafe { - if let Some(entry) = entry.as_ref() { - self.head = (*entry.next_stack.get()).take(); - - if let Some(entry) = self.head.as_ref() { - *entry.prev_stack.get() = ptr::null(); - } - - *entry.prev_stack.get() = ptr::null(); - } - } - - entry - } - - fn remove(&mut self, entry: &Entry, _: &mut ()) { - unsafe { - // Ensure that the entry is in fact contained by the stack - debug_assert!({ - // This walks the full linked list even if an entry is found. - let mut next = self.head.as_ref(); - let mut contains = false; - - while let Some(n) = next { - if entry as *const _ == &**n as *const _ { - debug_assert!(!contains); - contains = true; - } - - next = (*n.next_stack.get()).as_ref(); - } - - contains - }); - - // Unlink `entry` from the next node - let next = (*entry.next_stack.get()).take(); - - if let Some(next) = next.as_ref() { - (*next.prev_stack.get()) = *entry.prev_stack.get(); - } - - // Unlink `entry` from the prev node - - if let Some(prev) = (*entry.prev_stack.get()).as_ref() { - *prev.next_stack.get() = next; - } else { - // It is the head - self.head = next; - } - - // Unset the prev pointer - *entry.prev_stack.get() = ptr::null(); - } - } - - fn when(item: &Entry, _: &()) -> u64 { - item.when_internal().expect("invalid internal state") - } -} diff --git a/src/time/error.rs b/src/time/error.rs index 0667b97..24395c4 100644 --- a/src/time/error.rs +++ b/src/time/error.rs @@ -1,3 +1,5 @@ +//! Time error types. + use self::Kind::*; use std::error; use std::fmt; @@ -13,7 +15,7 @@ use std::fmt; /// succeed in the future. /// /// * `at_capacity` occurs when a timer operation is attempted, but the timer -/// instance is currently handling its maximum number of outstanding delays. +/// instance is currently handling its maximum number of outstanding sleep instances. /// In this case, the operation is not able to be performed at the current /// moment, and `at_capacity` is returned. This is a transient error, i.e., at /// some point in the future, if the operation is attempted again, it might @@ -24,12 +26,26 @@ use std::fmt; #[derive(Debug)] pub struct Error(Kind); -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] +#[repr(u8)] enum Kind { - Shutdown, - AtCapacity, + Shutdown = 1, + AtCapacity = 2, + Invalid = 3, } +/// Error returned by `Timeout`. +#[derive(Debug, PartialEq)] +pub struct Elapsed(()); + +#[derive(Debug)] +pub(crate) enum InsertError { + Elapsed, + Invalid, +} + +// ===== impl Error ===== + impl Error { /// Creates an error representing a shutdown timer. pub fn shutdown() -> Error { @@ -38,10 +54,7 @@ impl Error { /// Returns `true` if the error was caused by the timer being shutdown. pub fn is_shutdown(&self) -> bool { - match self.0 { - Kind::Shutdown => true, - _ => false, - } + matches!(self.0, Kind::Shutdown) } /// Creates an error representing a timer at capacity. @@ -51,10 +64,30 @@ impl Error { /// Returns `true` if the error was caused by the timer being at capacity. pub fn is_at_capacity(&self) -> bool { - match self.0 { - Kind::AtCapacity => true, - _ => false, - } + matches!(self.0, Kind::AtCapacity) + } + + /// Create an error representing a misconfigured timer. + pub fn invalid() -> Error { + Error(Invalid) + } + + /// Returns `true` if the error was caused by the timer being misconfigured. + pub fn is_invalid(&self) -> bool { + matches!(self.0, Kind::Invalid) + } + + pub(crate) fn as_u8(&self) -> u8 { + self.0 as u8 + } + + pub(crate) fn from_u8(n: u8) -> Self { + Error(match n { + 1 => Shutdown, + 2 => AtCapacity, + 3 => Invalid, + _ => panic!("u8 does not correspond to any time error variant"), + }) } } @@ -66,7 +99,30 @@ impl fmt::Display for Error { let descr = match self.0 { Shutdown => "the timer is shutdown, must be called from the context of Tokio runtime", AtCapacity => "timer is at capacity and cannot create a new entry", + Invalid => "timer duration exceeds maximum duration", }; write!(fmt, "{}", descr) } } + +// ===== impl Elapsed ===== + +impl Elapsed { + pub(crate) fn new() -> Self { + Elapsed(()) + } +} + +impl fmt::Display for Elapsed { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "deadline has elapsed".fmt(fmt) + } +} + +impl std::error::Error for Elapsed {} + +impl From<Elapsed> for std::io::Error { + fn from(_err: Elapsed) -> std::io::Error { + std::io::ErrorKind::TimedOut.into() + } +} diff --git a/src/time/instant.rs b/src/time/instant.rs index f2cb4bc..f4d6eac 100644 --- a/src/time/instant.rs +++ b/src/time/instant.rs @@ -4,8 +4,32 @@ use std::fmt; use std::ops; use std::time::Duration; -/// A measurement of the system clock, useful for talking to -/// external entities like the file system or other processes. +/// A measurement of a monotonically nondecreasing clock. +/// Opaque and useful only with `Duration`. +/// +/// Instants are always guaranteed to be no less than any previously measured +/// instant when created, and are often useful for tasks such as measuring +/// benchmarks or timing how long an operation takes. +/// +/// Note, however, that instants are not guaranteed to be **steady**. In other +/// words, each tick of the underlying clock may not be the same length (e.g. +/// some seconds may be longer than others). An instant may jump forwards or +/// experience time dilation (slow down or speed up), but it will never go +/// backwards. +/// +/// Instants are opaque types that can only be compared to one another. There is +/// no method to get "the number of seconds" from an instant. Instead, it only +/// allows measuring the duration between two instants (or comparing two +/// instants). +/// +/// The size of an `Instant` struct may vary depending on the target operating +/// system. +/// +/// # Note +/// +/// This type wraps the inner `std` variant and is used to align the Tokio +/// clock for uses of `now()`. This can be useful for testing where you can +/// take advantage of `time::pause()` and `time::advance()`. #[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] pub struct Instant { std: std::time::Instant, @@ -50,12 +74,12 @@ impl Instant { /// # Examples /// /// ``` - /// use tokio::time::{Duration, Instant, delay_for}; + /// use tokio::time::{Duration, Instant, sleep}; /// /// #[tokio::main] /// async fn main() { /// let now = Instant::now(); - /// delay_for(Duration::new(1, 0)).await; + /// sleep(Duration::new(1, 0)).await; /// let new_now = Instant::now(); /// println!("{:?}", new_now.checked_duration_since(now)); /// println!("{:?}", now.checked_duration_since(new_now)); // None @@ -71,12 +95,12 @@ impl Instant { /// # Examples /// /// ``` - /// use tokio::time::{Duration, Instant, delay_for}; + /// use tokio::time::{Duration, Instant, sleep}; /// /// #[tokio::main] /// async fn main() { /// let now = Instant::now(); - /// delay_for(Duration::new(1, 0)).await; + /// sleep(Duration::new(1, 0)).await; /// let new_now = Instant::now(); /// println!("{:?}", new_now.saturating_duration_since(now)); /// println!("{:?}", now.saturating_duration_since(new_now)); // 0ns @@ -97,13 +121,13 @@ impl Instant { /// # Examples /// /// ``` - /// use tokio::time::{Duration, Instant, delay_for}; + /// use tokio::time::{Duration, Instant, sleep}; /// /// #[tokio::main] /// async fn main() { /// let instant = Instant::now(); /// let three_secs = Duration::from_secs(3); - /// delay_for(three_secs).await; + /// sleep(three_secs).await; /// assert!(instant.elapsed() >= three_secs); /// } /// ``` diff --git a/src/time/interval.rs b/src/time/interval.rs index 1fa21e6..c7c58e1 100644 --- a/src/time/interval.rs +++ b/src/time/interval.rs @@ -1,5 +1,5 @@ use crate::future::poll_fn; -use crate::time::{delay_until, Delay, Duration, Instant}; +use crate::time::{sleep_until, Duration, Instant, Sleep}; use std::future::Future; use std::pin::Pin; @@ -36,12 +36,12 @@ use std::task::{Context, Poll}; /// /// A simple example using `interval` to execute a task every two seconds. /// -/// The difference between `interval` and [`delay_for`] is that an `interval` +/// The difference between `interval` and [`sleep`] is that an `interval` /// measures the time since the last tick, which means that `.tick().await` /// may wait for a shorter time than the duration specified for the interval /// if some time has passed between calls to `.tick().await`. /// -/// If the tick in the example below was replaced with [`delay_for`], the task +/// If the tick in the example below was replaced with [`sleep`], the task /// would only be executed once every three seconds, and not every two /// seconds. /// @@ -50,7 +50,7 @@ use std::task::{Context, Poll}; /// /// async fn task_that_takes_a_second() { /// println!("hello"); -/// time::delay_for(time::Duration::from_secs(1)).await +/// time::sleep(time::Duration::from_secs(1)).await /// } /// /// #[tokio::main] @@ -63,7 +63,7 @@ use std::task::{Context, Poll}; /// } /// ``` /// -/// [`delay_for`]: crate::time::delay_for() +/// [`sleep`]: crate::time::sleep() pub fn interval(period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); @@ -71,7 +71,7 @@ pub fn interval(period: Duration) -> Interval { } /// Creates new `Interval` that yields with interval of `period` with the -/// first tick completing at `at`. +/// first tick completing at `start`. /// /// An interval will tick indefinitely. At any time, the `Interval` value can be /// dropped. This cancels the interval. @@ -101,24 +101,28 @@ pub fn interval_at(start: Instant, period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); Interval { - delay: delay_until(start), + delay: sleep_until(start), period, } } /// Stream returned by [`interval`](interval) and [`interval_at`](interval_at). +/// +/// This type only implements the [`Stream`] trait if the "stream" feature is +/// enabled. +/// +/// [`Stream`]: trait@crate::stream::Stream #[derive(Debug)] pub struct Interval { /// Future that completes the next time the `Interval` yields a value. - delay: Delay, + delay: Sleep, /// The duration between values yielded by `Interval`. period: Duration, } impl Interval { - #[doc(hidden)] // TODO: document - pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { + fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { // Wait for the delay to be done ready!(Pin::new(&mut self.delay).poll(cx)); @@ -154,7 +158,6 @@ impl Interval { /// // approximately 20ms have elapsed. /// } /// ``` - #[allow(clippy::should_implement_trait)] // TODO: rename (tokio-rs/tokio#1261) pub async fn tick(&mut self) -> Instant { poll_fn(|cx| self.poll_tick(cx)).await } diff --git a/src/time/mod.rs b/src/time/mod.rs index c532b2c..29af717 100644 --- a/src/time/mod.rs +++ b/src/time/mod.rs @@ -3,7 +3,7 @@ //! This module provides a number of types for executing code after a set period //! of time. //! -//! * `Delay` is a future that does no work and completes at a specific `Instant` +//! * `Sleep` is a future that does no work and completes at a specific `Instant` //! in time. //! //! * `Interval` is a stream yielding a value at a fixed period. It is @@ -14,9 +14,6 @@ //! of time it is allowed to execute. If the future or stream does not //! complete in time, then it is canceled and an error is returned. //! -//! * `DelayQueue`: A queue where items are returned once the requested delay -//! has expired. -//! //! These types are sufficient for handling a large number of scenarios //! involving time. //! @@ -27,14 +24,14 @@ //! Wait 100ms and print "100 ms have elapsed" //! //! ``` -//! use tokio::time::delay_for; +//! use tokio::time::sleep; //! //! use std::time::Duration; //! //! //! #[tokio::main] //! async fn main() { -//! delay_for(Duration::from_millis(100)).await; +//! sleep(Duration::from_millis(100)).await; //! println!("100 ms have elapsed"); //! } //! ``` @@ -61,12 +58,12 @@ //! //! A simple example using [`interval`] to execute a task every two seconds. //! -//! The difference between [`interval`] and [`delay_for`] is that an +//! The difference between [`interval`] and [`sleep`] is that an //! [`interval`] measures the time since the last tick, which means that //! `.tick().await` may wait for a shorter time than the duration specified //! for the interval if some time has passed between calls to `.tick().await`. //! -//! If the tick in the example below was replaced with [`delay_for`], the task +//! If the tick in the example below was replaced with [`sleep`], the task //! would only be executed once every three seconds, and not every two //! seconds. //! @@ -75,7 +72,7 @@ //! //! async fn task_that_takes_a_second() { //! println!("hello"); -//! time::delay_for(time::Duration::from_secs(1)).await +//! time::sleep(time::Duration::from_secs(1)).await //! } //! //! #[tokio::main] @@ -88,7 +85,7 @@ //! } //! ``` //! -//! [`delay_for`]: crate::time::delay_for() +//! [`sleep`]: crate::time::sleep() //! [`interval`]: crate::time::interval() mod clock; @@ -96,17 +93,12 @@ pub(crate) use self::clock::Clock; #[cfg(feature = "test-util")] pub use clock::{advance, pause, resume}; -pub mod delay_queue; -#[doc(inline)] -pub use delay_queue::DelayQueue; - -mod delay; -pub use delay::{delay_for, delay_until, Delay}; +mod sleep; +pub use sleep::{sleep, sleep_until, Sleep}; pub(crate) mod driver; -mod error; -pub use error::Error; +pub mod error; mod instant; pub use self::instant::Instant; @@ -116,12 +108,7 @@ pub use interval::{interval, interval_at, Interval}; mod timeout; #[doc(inline)] -pub use timeout::{timeout, timeout_at, Elapsed, Timeout}; - -cfg_stream! { - mod throttle; - pub use throttle::{throttle, Throttle}; -} +pub use timeout::{timeout, timeout_at, Timeout}; mod wheel; @@ -130,6 +117,7 @@ mod wheel; mod tests; // Re-export for convenience +#[doc(no_inline)] pub use std::time::Duration; // ===== Internal utils ===== diff --git a/src/time/delay.rs b/src/time/sleep.rs index 744c7e1..d3234a1 100644 --- a/src/time/delay.rs +++ b/src/time/sleep.rs @@ -1,31 +1,31 @@ -use crate::time::driver::Registration; -use crate::time::{Duration, Instant}; +use crate::time::driver::{Entry, Handle}; +use crate::time::{error::Error, Duration, Instant}; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; use std::task::{self, Poll}; /// Waits until `deadline` is reached. /// -/// No work is performed while awaiting on the delay to complete. The delay +/// No work is performed while awaiting on the sleep future to complete. `Sleep` /// operates at millisecond granularity and should not be used for tasks that /// require high-resolution timers. /// /// # Cancellation /// -/// Canceling a delay is done by dropping the returned future. No additional +/// Canceling a sleep instance is done by dropping the returned future. No additional /// cleanup work is required. -pub fn delay_until(deadline: Instant) -> Delay { - let registration = Registration::new(deadline, Duration::from_millis(0)); - Delay { registration } +pub fn sleep_until(deadline: Instant) -> Sleep { + Sleep::new_timeout(deadline, Duration::from_millis(0)) } /// Waits until `duration` has elapsed. /// -/// Equivalent to `delay_until(Instant::now() + duration)`. An asynchronous +/// Equivalent to `sleep_until(Instant::now() + duration)`. An asynchronous /// analog to `std::thread::sleep`. /// -/// No work is performed while awaiting on the delay to complete. The delay +/// No work is performed while awaiting on the sleep future to complete. `Sleep` /// operates at millisecond granularity and should not be used for tasks that /// require high-resolution timers. /// @@ -33,7 +33,7 @@ pub fn delay_until(deadline: Instant) -> Delay { /// /// # Cancellation /// -/// Canceling a delay is done by dropping the returned future. No additional +/// Canceling a sleep instance is done by dropping the returned future. No additional /// cleanup work is required. /// /// # Examples @@ -41,78 +41,99 @@ pub fn delay_until(deadline: Instant) -> Delay { /// Wait 100ms and print "100 ms have elapsed". /// /// ``` -/// use tokio::time::{delay_for, Duration}; +/// use tokio::time::{sleep, Duration}; /// /// #[tokio::main] /// async fn main() { -/// delay_for(Duration::from_millis(100)).await; +/// sleep(Duration::from_millis(100)).await; /// println!("100 ms have elapsed"); /// } /// ``` /// /// [`interval`]: crate::time::interval() -#[cfg_attr(docsrs, doc(alias = "sleep"))] -pub fn delay_for(duration: Duration) -> Delay { - delay_until(Instant::now() + duration) +pub fn sleep(duration: Duration) -> Sleep { + sleep_until(Instant::now() + duration) } -/// Future returned by [`delay_until`](delay_until) and -/// [`delay_for`](delay_for). +/// Future returned by [`sleep`](sleep) and +/// [`sleep_until`](sleep_until). #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Delay { - /// The link between the `Delay` instance and the timer that drives it. +pub struct Sleep { + /// The link between the `Sleep` instance and the timer that drives it. /// /// This also stores the `deadline` value. - registration: Registration, + entry: Arc<Entry>, } -impl Delay { - pub(crate) fn new_timeout(deadline: Instant, duration: Duration) -> Delay { - let registration = Registration::new(deadline, duration); - Delay { registration } +impl Sleep { + pub(crate) fn new_timeout(deadline: Instant, duration: Duration) -> Sleep { + let handle = Handle::current(); + let entry = Entry::new(&handle, deadline, duration); + + Sleep { entry } } /// Returns the instant at which the future will complete. pub fn deadline(&self) -> Instant { - self.registration.deadline() + self.entry.time_ref().deadline } - /// Returns `true` if the `Delay` has elapsed + /// Returns `true` if `Sleep` has elapsed. /// - /// A `Delay` is elapsed when the requested duration has elapsed. + /// A `Sleep` instance is elapsed when the requested duration has elapsed. pub fn is_elapsed(&self) -> bool { - self.registration.is_elapsed() + self.entry.is_elapsed() } - /// Resets the `Delay` instance to a new deadline. + /// Resets the `Sleep` instance to a new deadline. /// - /// Calling this function allows changing the instant at which the `Delay` + /// Calling this function allows changing the instant at which the `Sleep` /// future completes without having to create new associated state. /// /// This function can be called both before and after the future has /// completed. pub fn reset(&mut self, deadline: Instant) { - self.registration.reset(deadline); + unsafe { + self.entry.time_mut().deadline = deadline; + } + + Entry::reset(&mut self.entry); + } + + fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + + self.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }) } } -impl Future for Delay { +impl Future for Sleep { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { // `poll_elapsed` can return an error in two cases: // - // - AtCapacity: this is a pathlogical case where far too many - // delays have been scheduled. + // - AtCapacity: this is a pathological case where far too many + // sleep instances have been scheduled. // - Shutdown: No timer has been setup, which is a mis-use error. // // Both cases are extremely rare, and pretty accurately fit into // "logic errors", so we just panic in this case. A user couldn't // really do much better if we passed the error onwards. - match ready!(self.registration.poll_elapsed(cx)) { + match ready!(self.poll_elapsed(cx)) { Ok(()) => Poll::Ready(()), Err(e) => panic!("timer error: {}", e), } } } + +impl Drop for Sleep { + fn drop(&mut self) { + Entry::cancel(&self.entry); + } +} diff --git a/src/time/tests/mod.rs b/src/time/tests/mod.rs index 4710d47..fae67da 100644 --- a/src/time/tests/mod.rs +++ b/src/time/tests/mod.rs @@ -1,4 +1,4 @@ -mod test_delay; +mod test_sleep; use crate::time::{self, Instant}; use std::time::Duration; @@ -8,15 +8,15 @@ fn assert_sync<T: Sync>() {} #[test] fn registration_is_send_and_sync() { - use crate::time::driver::Registration; + use crate::time::sleep::Sleep; - assert_send::<Registration>(); - assert_sync::<Registration>(); + assert_send::<Sleep>(); + assert_sync::<Sleep>(); } #[test] #[should_panic] -fn delay_is_eager() { +fn sleep_is_eager() { let when = Instant::now() + Duration::from_millis(100); - let _ = time::delay_until(when); + let _ = time::sleep_until(when); } diff --git a/src/time/tests/test_delay.rs b/src/time/tests/test_sleep.rs index f843434..c8d931a 100644 --- a/src/time/tests/test_delay.rs +++ b/src/time/tests/test_sleep.rs @@ -1,5 +1,3 @@ -#![warn(rust_2018_idioms)] - use crate::park::{Park, Unpark}; use crate::time::driver::{Driver, Entry, Handle}; use crate::time::Clock; @@ -27,12 +25,12 @@ fn frozen_utility_returns_correct_advanced_duration() { } #[test] -fn immediate_delay() { +fn immediate_sleep() { let (mut driver, clock, handle) = setup(); let start = clock.now(); let when = clock.now(); - let mut e = task::spawn(delay_until(&handle, when)); + let mut e = task::spawn(sleep_until(&handle, when)); assert_ready_ok!(poll!(e)); @@ -43,15 +41,15 @@ fn immediate_delay() { } #[test] -fn delayed_delay_level_0() { +fn delayed_sleep_level_0() { let (mut driver, clock, handle) = setup(); let start = clock.now(); for &i in &[1, 10, 60] { - // Create a `Delay` that elapses in the future - let mut e = task::spawn(delay_until(&handle, start + ms(i))); + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, start + ms(i))); - // The delay has not elapsed. + // The sleep instance has not elapsed. assert_pending!(poll!(e)); assert_ok!(driver.park()); @@ -62,13 +60,13 @@ fn delayed_delay_level_0() { } #[test] -fn sub_ms_delayed_delay() { +fn sub_ms_delayed_sleep() { let (mut driver, clock, handle) = setup(); for _ in 0..5 { let deadline = clock.now() + ms(1) + Duration::new(0, 1); - let mut e = task::spawn(delay_until(&handle, deadline)); + let mut e = task::spawn(sleep_until(&handle, deadline)); assert_pending!(poll!(e)); @@ -82,14 +80,14 @@ fn sub_ms_delayed_delay() { } #[test] -fn delayed_delay_wrapping_level_0() { +fn delayed_sleep_wrapping_level_0() { let (mut driver, clock, handle) = setup(); let start = clock.now(); assert_ok!(driver.park_timeout(ms(5))); assert_eq!(clock.now() - start, ms(5)); - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(60))); + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(60))); assert_pending!(poll!(e)); @@ -108,15 +106,15 @@ fn timer_wrapping_with_higher_levels() { let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Set delay to hit level 1 - let mut e1 = task::spawn(delay_until(&handle, clock.now() + ms(64))); + // Set sleep to hit level 1 + let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(64))); assert_pending!(poll!(e1)); // Turn a bit assert_ok!(driver.park_timeout(ms(5))); // Set timeout such that it will hit level 0, but wrap - let mut e2 = task::spawn(delay_until(&handle, clock.now() + ms(60))); + let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(60))); assert_pending!(poll!(e2)); // This should result in s1 firing @@ -133,14 +131,14 @@ fn timer_wrapping_with_higher_levels() { } #[test] -fn delay_with_deadline_in_past() { +fn sleep_with_deadline_in_past() { let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Create `Delay` that elapsed immediately. - let mut e = task::spawn(delay_until(&handle, clock.now() - ms(100))); + // Create `Sleep` that elapsed immediately. + let mut e = task::spawn(sleep_until(&handle, clock.now() - ms(100))); - // Even though the delay expires in the past, it is not ready yet + // Even though the `Sleep` expires in the past, it is not ready yet // because the timer must observe it. assert_ready_ok!(poll!(e)); @@ -152,37 +150,37 @@ fn delay_with_deadline_in_past() { } #[test] -fn delayed_delay_level_1() { +fn delayed_sleep_level_1() { let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Create a `Delay` that elapses in the future - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(234))); + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(234))); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); // Turn the timer, this will wake up to cascade the timer down. assert_ok!(driver.park_timeout(ms(1000))); assert_eq!(clock.now() - start, ms(192)); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); // Turn the timer again assert_ok!(driver.park_timeout(ms(1000))); assert_eq!(clock.now() - start, ms(234)); - // The delay has elapsed. + // The sleep has elapsed. assert_ready_ok!(poll!(e)); let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Create a `Delay` that elapses in the future - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(234))); + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(234))); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); // Turn the timer with a smaller timeout than the cascade. @@ -195,14 +193,14 @@ fn delayed_delay_level_1() { assert_ok!(driver.park_timeout(ms(1000))); assert_eq!(clock.now() - start, ms(192)); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); // Turn the timer again assert_ok!(driver.park_timeout(ms(1000))); assert_eq!(clock.now() - start, ms(234)); - // The delay has elapsed. + // The sleep has elapsed. assert_ready_ok!(poll!(e)); } @@ -211,22 +209,22 @@ fn concurrently_set_two_timers_second_one_shorter() { let (mut driver, clock, handle) = setup(); let start = clock.now(); - let mut e1 = task::spawn(delay_until(&handle, clock.now() + ms(500))); - let mut e2 = task::spawn(delay_until(&handle, clock.now() + ms(200))); + let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(500))); + let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(200))); - // The delay has not elapsed + // The sleep has not elapsed assert_pending!(poll!(e1)); assert_pending!(poll!(e2)); - // Delay until a cascade + // Sleep until a cascade assert_ok!(driver.park()); assert_eq!(clock.now() - start, ms(192)); - // Delay until the second timer. + // Sleep until the second timer. assert_ok!(driver.park()); assert_eq!(clock.now() - start, ms(200)); - // The shorter delay fires + // The shorter sleep fires assert_ready_ok!(poll!(e2)); assert_pending!(poll!(e1)); @@ -235,7 +233,7 @@ fn concurrently_set_two_timers_second_one_shorter() { assert_pending!(poll!(e1)); - // Turn again, this time the time will advance to the second delay + // Turn again, this time the time will advance to the second sleep assert_ok!(driver.park()); assert_eq!(clock.now() - start, ms(500)); @@ -243,37 +241,37 @@ fn concurrently_set_two_timers_second_one_shorter() { } #[test] -fn short_delay() { +fn short_sleep() { let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Create a `Delay` that elapses in the future - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(1))); + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(1))); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); // Turn the timer, but not enough time will go by. assert_ok!(driver.park()); - // The delay has elapsed. + // The sleep has elapsed. assert_ready_ok!(poll!(e)); - // The time has advanced to the point of the delay elapsing. + // The time has advanced to the point of the sleep elapsing. assert_eq!(clock.now() - start, ms(1)); } #[test] -fn sorta_long_delay_until() { +fn sorta_long_sleep_until() { const MIN_5: u64 = 5 * 60 * 1000; let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Create a `Delay` that elapses in the future - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(MIN_5))); + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(MIN_5))); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); let cascades = &[262_144, 262_144 + 9 * 4096, 262_144 + 9 * 4096 + 15 * 64]; @@ -288,21 +286,21 @@ fn sorta_long_delay_until() { assert_ok!(driver.park()); assert_eq!(clock.now() - start, ms(MIN_5)); - // The delay has elapsed. + // The sleep has elapsed. assert_ready_ok!(poll!(e)); } #[test] -fn very_long_delay() { +fn very_long_sleep() { const MO_5: u64 = 5 * 30 * 24 * 60 * 60 * 1000; let (mut driver, clock, handle) = setup(); let start = clock.now(); - // Create a `Delay` that elapses in the future - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(MO_5))); + // Create a `Sleep` that elapses in the future + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(MO_5))); - // The delay has not elapsed. + // The sleep has not elapsed. assert_pending!(poll!(e)); let cascades = &[ @@ -322,10 +320,10 @@ fn very_long_delay() { // Turn the timer, but not enough time will go by. assert_ok!(driver.park()); - // The time has advanced to the point of the delay elapsing. + // The time has advanced to the point of the sleep elapsing. assert_eq!(clock.now() - start, ms(MO_5)); - // The delay has elapsed. + // The sleep has elapsed. assert_ready_ok!(poll!(e)); } @@ -353,6 +351,8 @@ fn unpark_is_delayed() { self.0.advance(ms(436)); Ok(()) } + + fn shutdown(&mut self) {} } impl Unpark for MockUnpark { @@ -365,9 +365,9 @@ fn unpark_is_delayed() { let mut driver = Driver::new(MockPark(clock.clone()), clock.clone()); let handle = driver.handle(); - let mut e1 = task::spawn(delay_until(&handle, clock.now() + ms(100))); - let mut e2 = task::spawn(delay_until(&handle, clock.now() + ms(101))); - let mut e3 = task::spawn(delay_until(&handle, clock.now() + ms(200))); + let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(100))); + let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(101))); + let mut e3 = task::spawn(sleep_until(&handle, clock.now() + ms(200))); assert_pending!(poll!(e1)); assert_pending!(poll!(e2)); @@ -394,7 +394,7 @@ fn set_timeout_at_deadline_greater_than_max_timer() { assert_ok!(driver.park_timeout(ms(YR_1))); } - let mut e = task::spawn(delay_until(&handle, clock.now() + ms(1))); + let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(1))); assert_pending!(poll!(e)); assert_ok!(driver.park_timeout(ms(1000))); @@ -412,7 +412,7 @@ fn setup() -> (Driver<MockPark>, Clock, Handle) { (driver, clock, handle) } -fn delay_until(handle: &Handle, when: Instant) -> Arc<Entry> { +fn sleep_until(handle: &Handle, when: Instant) -> Arc<Entry> { Entry::new(&handle, when, ms(0)) } @@ -436,6 +436,8 @@ impl Park for MockPark { self.0.advance(duration); Ok(()) } + + fn shutdown(&mut self) {} } impl Unpark for MockUnpark { diff --git a/src/time/timeout.rs b/src/time/timeout.rs index efc3dc5..cf09b07 100644 --- a/src/time/timeout.rs +++ b/src/time/timeout.rs @@ -4,10 +4,9 @@ //! //! [`Timeout`]: struct@Timeout -use crate::time::{delay_until, Delay, Duration, Instant}; +use crate::time::{error::Elapsed, sleep_until, Duration, Instant, Sleep}; use pin_project_lite::pin_project; -use std::fmt; use std::future::Future; use std::pin::Pin; use std::task::{self, Poll}; @@ -50,7 +49,7 @@ pub fn timeout<T>(duration: Duration, future: T) -> Timeout<T> where T: Future, { - let delay = Delay::new_timeout(Instant::now() + duration, duration); + let delay = Sleep::new_timeout(Instant::now() + duration, duration); Timeout::new_with_delay(future, delay) } @@ -92,7 +91,7 @@ pub fn timeout_at<T>(deadline: Instant, future: T) -> Timeout<T> where T: Future, { - let delay = delay_until(deadline); + let delay = sleep_until(deadline); Timeout { value: future, @@ -108,24 +107,12 @@ pin_project! { #[pin] value: T, #[pin] - delay: Delay, - } -} - -/// Error returned by `Timeout`. -#[derive(Debug, PartialEq)] -pub struct Elapsed(()); - -impl Elapsed { - // Used on StreamExt::timeout - #[allow(unused)] - pub(crate) fn new() -> Self { - Elapsed(()) + delay: Sleep, } } impl<T> Timeout<T> { - pub(crate) fn new_with_delay(value: T, delay: Delay) -> Timeout<T> { + pub(crate) fn new_with_delay(value: T, delay: Sleep) -> Timeout<T> { Timeout { value, delay } } @@ -161,24 +148,8 @@ where // Now check the timer match me.delay.poll(cx) { - Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))), + Poll::Ready(()) => Poll::Ready(Err(Elapsed::new())), Poll::Pending => Poll::Pending, } } } - -// ===== impl Elapsed ===== - -impl fmt::Display for Elapsed { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - "deadline has elapsed".fmt(fmt) - } -} - -impl std::error::Error for Elapsed {} - -impl From<Elapsed> for std::io::Error { - fn from(_err: Elapsed) -> std::io::Error { - std::io::ErrorKind::TimedOut.into() - } -} diff --git a/src/time/wheel/level.rs b/src/time/wheel/level.rs index 49f9bfb..d51d26a 100644 --- a/src/time/wheel/level.rs +++ b/src/time/wheel/level.rs @@ -1,9 +1,10 @@ +use super::{Item, OwnedItem}; use crate::time::wheel::Stack; use std::fmt; /// Wheel for a single level in the timer. This wheel contains 64 slots. -pub(crate) struct Level<T> { +pub(crate) struct Level { level: usize, /// Bit field tracking which slots currently contain entries. @@ -16,7 +17,7 @@ pub(crate) struct Level<T> { occupied: u64, /// Slots - slot: [T; LEVEL_MULT], + slot: [Stack; LEVEL_MULT], } /// Indicates when a slot must be processed next. @@ -37,87 +38,90 @@ pub(crate) struct Expiration { /// Being a power of 2 is very important. const LEVEL_MULT: usize = 64; -impl<T: Stack> Level<T> { - pub(crate) fn new(level: usize) -> Level<T> { - // Rust's derived implementations for arrays require that the value - // contained by the array be `Copy`. So, here we have to manually - // initialize every single slot. - macro_rules! s { - () => { - T::default() - }; - }; +impl Level { + pub(crate) fn new(level: usize) -> Level { + // A value has to be Copy in order to use syntax like: + // let stack = Stack::default(); + // ... + // slots: [stack; 64], + // + // Alternatively, since Stack is Default one can + // use syntax like: + // let slots: [Stack; 64] = Default::default(); + // + // However, that is only supported for arrays of size + // 32 or fewer. So in our case we have to explicitly + // invoke the constructor for each array element. + let ctor = Stack::default; Level { level, occupied: 0, slot: [ - // It does not look like the necessary traits are - // derived for [T; 64]. - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), - s!(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), + ctor(), ], } } @@ -173,17 +177,17 @@ impl<T: Stack> Level<T> { Some(slot) } - pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) { + pub(crate) fn add_entry(&mut self, when: u64, item: OwnedItem) { let slot = slot_for(when, self.level); - self.slot[slot].push(item, store); + self.slot[slot].push(item); self.occupied |= occupied_bit(slot); } - pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) { + pub(crate) fn remove_entry(&mut self, when: u64, item: &Item) { let slot = slot_for(when, self.level); - self.slot[slot].remove(item, store); + self.slot[slot].remove(item); if self.slot[slot].is_empty() { // The bit is currently set @@ -194,8 +198,8 @@ impl<T: Stack> Level<T> { } } - pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> { - let ret = self.slot[slot].pop(store); + pub(crate) fn pop_entry_slot(&mut self, slot: usize) -> Option<OwnedItem> { + let ret = self.slot[slot].pop(); if ret.is_some() && self.slot[slot].is_empty() { // The bit is currently set @@ -208,7 +212,7 @@ impl<T: Stack> Level<T> { } } -impl<T> fmt::Debug for Level<T> { +impl fmt::Debug for Level { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Level") .field("occupied", &self.occupied) diff --git a/src/time/wheel/mod.rs b/src/time/wheel/mod.rs index a2ef27f..85ed2f1 100644 --- a/src/time/wheel/mod.rs +++ b/src/time/wheel/mod.rs @@ -1,3 +1,5 @@ +use crate::time::{driver::Entry, error::InsertError}; + mod level; pub(crate) use self::level::Expiration; use self::level::Level; @@ -5,9 +7,12 @@ use self::level::Level; mod stack; pub(crate) use self::stack::Stack; -use std::borrow::Borrow; +use std::sync::Arc; use std::usize; +pub(super) type Item = Entry; +pub(super) type OwnedItem = Arc<Item>; + /// Timing wheel implementation. /// /// This type provides the hashed timing wheel implementation that backs `Timer` @@ -20,7 +25,7 @@ use std::usize; /// /// See `Timer` documentation for some implementation notes. #[derive(Debug)] -pub(crate) struct Wheel<T> { +pub(crate) struct Wheel { /// The number of milliseconds elapsed since the wheel started. elapsed: u64, @@ -34,7 +39,7 @@ pub(crate) struct Wheel<T> { /// * ~ 4 min slots / ~ 4 hr range /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range - levels: Vec<Level<T>>, + levels: Vec<Level>, } /// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots @@ -42,28 +47,12 @@ pub(crate) struct Wheel<T> { /// precision of 1 millisecond. const NUM_LEVELS: usize = 6; -/// The maximum duration of a delay +/// The maximum duration of a `Sleep` const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; -#[derive(Debug)] -pub(crate) enum InsertError { - Elapsed, - Invalid, -} - -/// Poll expirations from the wheel -#[derive(Debug, Default)] -pub(crate) struct Poll { - now: u64, - expiration: Option<Expiration>, -} - -impl<T> Wheel<T> -where - T: Stack, -{ +impl Wheel { /// Create a new timing wheel - pub(crate) fn new() -> Wheel<T> { + pub(crate) fn new() -> Wheel { let levels = (0..NUM_LEVELS).map(Level::new).collect(); Wheel { elapsed: 0, levels } @@ -99,9 +88,8 @@ where pub(crate) fn insert( &mut self, when: u64, - item: T::Owned, - store: &mut T::Store, - ) -> Result<(), (T::Owned, InsertError)> { + item: OwnedItem, + ) -> Result<(), (OwnedItem, InsertError)> { if when <= self.elapsed { return Err((item, InsertError::Elapsed)); } else if when - self.elapsed > MAX_DURATION { @@ -111,7 +99,7 @@ where // Get the level at which the entry should be stored let level = self.level_for(when); - self.levels[level].add_entry(when, item, store); + self.levels[level].add_entry(when, item); debug_assert!({ self.levels[level] @@ -124,11 +112,11 @@ where } /// Remove `item` from thee timing wheel. - pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) { - let when = T::when(item, store); + pub(crate) fn remove(&mut self, item: &Item) { + let when = item.when(); let level = self.level_for(when); - self.levels[level].remove_entry(when, item, store); + self.levels[level].remove_entry(when, item); } /// Instant at which to poll @@ -136,33 +124,35 @@ where self.next_expiration().map(|expiration| expiration.deadline) } - pub(crate) fn poll(&mut self, poll: &mut Poll, store: &mut T::Store) -> Option<T::Owned> { + /// Advances the timer up to the instant represented by `now`. + pub(crate) fn poll(&mut self, now: u64) -> Option<OwnedItem> { loop { - if poll.expiration.is_none() { - poll.expiration = self.next_expiration().and_then(|expiration| { - if expiration.deadline > poll.now { - None - } else { - Some(expiration) - } - }); - } + // under what circumstances is poll.expiration Some vs. None? + let expiration = self.next_expiration().and_then(|expiration| { + if expiration.deadline > now { + None + } else { + Some(expiration) + } + }); - match poll.expiration { + match expiration { Some(ref expiration) => { - if let Some(item) = self.poll_expiration(expiration, store) { + if let Some(item) = self.poll_expiration(expiration) { return Some(item); } self.set_elapsed(expiration.deadline); } None => { - self.set_elapsed(poll.now); + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); return None; } } - - poll.expiration = None; } } @@ -197,22 +187,22 @@ where res } - pub(crate) fn poll_expiration( - &mut self, - expiration: &Expiration, - store: &mut T::Store, - ) -> Option<T::Owned> { - while let Some(item) = self.pop_entry(expiration, store) { + /// iteratively find entries that are between the wheel's current + /// time and the expiration time. for each in that population either + /// return it for notification (in the case of the last level) or tier + /// it down to the next level (in all other cases). + pub(crate) fn poll_expiration(&mut self, expiration: &Expiration) -> Option<OwnedItem> { + while let Some(item) = self.pop_entry(expiration) { if expiration.level == 0 { - debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline); + debug_assert_eq!(item.when(), expiration.deadline); return Some(item); } else { - let when = T::when(item.borrow(), store); + let when = item.when(); let next_level = expiration.level - 1; - self.levels[next_level].add_entry(when, item, store); + self.levels[next_level].add_entry(when, item); } } @@ -232,8 +222,8 @@ where } } - fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> { - self.levels[expiration.level].pop_entry_slot(expiration.slot, store) + fn pop_entry(&mut self, expiration: &Expiration) -> Option<OwnedItem> { + self.levels[expiration.level].pop_entry_slot(expiration.slot) } fn level_for(&self, when: u64) -> usize { @@ -251,15 +241,6 @@ fn level_for(elapsed: u64, when: u64) -> usize { significant / 6 } -impl Poll { - pub(crate) fn new(now: u64) -> Poll { - Poll { - now, - expiration: None, - } - } -} - #[cfg(all(test, not(loom)))] mod test { use super::*; diff --git a/src/time/wheel/stack.rs b/src/time/wheel/stack.rs index 6e55c38..e7ed137 100644 --- a/src/time/wheel/stack.rs +++ b/src/time/wheel/stack.rs @@ -1,26 +1,112 @@ -use std::borrow::Borrow; +use super::{Item, OwnedItem}; +use crate::time::driver::Entry; -/// Abstracts the stack operations needed to track timeouts. -pub(crate) trait Stack: Default { - /// Type of the item stored in the stack - type Owned: Borrow<Self::Borrowed>; +use std::ptr; - /// Borrowed item - type Borrowed; +/// A doubly linked stack +#[derive(Debug)] +pub(crate) struct Stack { + head: Option<OwnedItem>, +} + +impl Default for Stack { + fn default() -> Stack { + Stack { head: None } + } +} + +impl Stack { + pub(crate) fn is_empty(&self) -> bool { + self.head.is_none() + } + + pub(crate) fn push(&mut self, entry: OwnedItem) { + // Get a pointer to the entry to for the prev link + let ptr: *const Entry = &*entry as *const _; + + // Remove the old head entry + let old = self.head.take(); + + unsafe { + // Ensure the entry is not already in a stack. + debug_assert!((*entry.next_stack.get()).is_none()); + debug_assert!((*entry.prev_stack.get()).is_null()); + + if let Some(ref entry) = old.as_ref() { + debug_assert!({ + // The head is not already set to the entry + ptr != &***entry as *const _ + }); + + // Set the previous link on the old head + *entry.prev_stack.get() = ptr; + } + + // Set this entry's next pointer + *entry.next_stack.get() = old; + } + + // Update the head pointer + self.head = Some(entry); + } + + /// Pops an item from the stack + pub(crate) fn pop(&mut self) -> Option<OwnedItem> { + let entry = self.head.take(); + + unsafe { + if let Some(entry) = entry.as_ref() { + self.head = (*entry.next_stack.get()).take(); + + if let Some(entry) = self.head.as_ref() { + *entry.prev_stack.get() = ptr::null(); + } + + *entry.prev_stack.get() = ptr::null(); + } + } + + entry + } + + pub(crate) fn remove(&mut self, entry: &Item) { + unsafe { + // Ensure that the entry is in fact contained by the stack + debug_assert!({ + // This walks the full linked list even if an entry is found. + let mut next = self.head.as_ref(); + let mut contains = false; + + while let Some(n) = next { + if entry as *const _ == &**n as *const _ { + debug_assert!(!contains); + contains = true; + } + + next = (*n.next_stack.get()).as_ref(); + } - /// Item storage, this allows a slab to be used instead of just the heap - type Store; + contains + }); - /// Returns `true` if the stack is empty - fn is_empty(&self) -> bool; + // Unlink `entry` from the next node + let next = (*entry.next_stack.get()).take(); - /// Push an item onto the stack - fn push(&mut self, item: Self::Owned, store: &mut Self::Store); + if let Some(next) = next.as_ref() { + (*next.prev_stack.get()) = *entry.prev_stack.get(); + } - /// Pop an item from the stack - fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>; + // Unlink `entry` from the prev node - fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store); + if let Some(prev) = (*entry.prev_stack.get()).as_ref() { + *prev.next_stack.get() = next; + } else { + // It is the head + self.head = next; + } - fn when(item: &Self::Borrowed, store: &Self::Store) -> u64; + // Unset the prev pointer + *entry.prev_stack.get() = ptr::null(); + } + } } diff --git a/src/util/bit.rs b/src/util/bit.rs index e61ac21..392a0e8 100644 --- a/src/util/bit.rs +++ b/src/util/bit.rs @@ -1,22 +1,12 @@ use std::fmt; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq)] pub(crate) struct Pack { mask: usize, shift: u32, } impl Pack { - /// Value is packed in the `width` most-significant bits. - pub(crate) const fn most_significant(width: u32) -> Pack { - let mask = mask_for(width).reverse_bits(); - - Pack { - mask, - shift: mask.trailing_zeros(), - } - } - /// Value is packed in the `width` least-significant bits. pub(crate) const fn least_significant(width: u32) -> Pack { let mask = mask_for(width); @@ -32,12 +22,6 @@ impl Pack { Pack { mask, shift } } - /// Mask used to unpack value - #[cfg(all(test, loom))] - pub(crate) const fn mask(&self) -> usize { - self.mask - } - /// Width, in bits, dedicated to storing the value. pub(crate) const fn width(&self) -> u32 { pointer_width() - (self.mask >> self.shift).leading_zeros() @@ -53,6 +37,14 @@ impl Pack { (base & !self.mask) | (value << self.shift) } + /// Packs the value with `base`, losing any bits of `value` that fit. + /// + /// If `value` is larger than the max value that can be represented by the + /// allotted width, the most significant bits are truncated. + pub(crate) fn pack_lossy(&self, value: usize, base: usize) -> usize { + self.pack(value & self.max_value(), base) + } + pub(crate) fn unpack(&self, src: usize) -> usize { unpack(src, self.mask, self.shift) } diff --git a/src/util/intrusive_double_linked_list.rs b/src/util/intrusive_double_linked_list.rs deleted file mode 100644 index 083fa31..0000000 --- a/src/util/intrusive_double_linked_list.rs +++ /dev/null @@ -1,788 +0,0 @@ -//! An intrusive double linked list of data - -#![allow(dead_code, unreachable_pub)] - -use core::{ - marker::PhantomPinned, - ops::{Deref, DerefMut}, - ptr::NonNull, -}; - -/// A node which carries data of type `T` and is stored in an intrusive list -#[derive(Debug)] -pub struct ListNode<T> { - /// The previous node in the list. `None` if there is no previous node. - prev: Option<NonNull<ListNode<T>>>, - /// The next node in the list. `None` if there is no previous node. - next: Option<NonNull<ListNode<T>>>, - /// The data which is associated to this list item - data: T, - /// Prevents `ListNode`s from being `Unpin`. They may never be moved, since - /// the list semantics require addresses to be stable. - _pin: PhantomPinned, -} - -impl<T> ListNode<T> { - /// Creates a new node with the associated data - pub fn new(data: T) -> ListNode<T> { - Self { - prev: None, - next: None, - data, - _pin: PhantomPinned, - } - } -} - -impl<T> Deref for ListNode<T> { - type Target = T; - - fn deref(&self) -> &T { - &self.data - } -} - -impl<T> DerefMut for ListNode<T> { - fn deref_mut(&mut self) -> &mut T { - &mut self.data - } -} - -/// An intrusive linked list of nodes, where each node carries associated data -/// of type `T`. -#[derive(Debug)] -pub struct LinkedList<T> { - head: Option<NonNull<ListNode<T>>>, - tail: Option<NonNull<ListNode<T>>>, -} - -impl<T> LinkedList<T> { - /// Creates an empty linked list - pub fn new() -> Self { - LinkedList::<T> { - head: None, - tail: None, - } - } - - /// Adds a node at the front of the linked list. - /// Safety: This function is only safe as long as `node` is guaranteed to - /// get removed from the list before it gets moved or dropped. - /// In addition to this `node` may not be added to another other list before - /// it is removed from the current one. - pub unsafe fn add_front(&mut self, node: &mut ListNode<T>) { - node.next = self.head; - node.prev = None; - if let Some(mut head) = self.head { - head.as_mut().prev = Some(node.into()) - }; - self.head = Some(node.into()); - if self.tail.is_none() { - self.tail = Some(node.into()); - } - } - - /// Inserts a node into the list in a way that the list keeps being sorted. - /// Safety: This function is only safe as long as `node` is guaranteed to - /// get removed from the list before it gets moved or dropped. - /// In addition to this `node` may not be added to another other list before - /// it is removed from the current one. - pub unsafe fn add_sorted(&mut self, node: &mut ListNode<T>) - where - T: PartialOrd, - { - if self.head.is_none() { - // First node in the list - self.head = Some(node.into()); - self.tail = Some(node.into()); - return; - } - - let mut prev: Option<NonNull<ListNode<T>>> = None; - let mut current = self.head; - - while let Some(mut current_node) = current { - if node.data < current_node.as_ref().data { - // Need to insert before the current node - current_node.as_mut().prev = Some(node.into()); - match prev { - Some(mut prev) => { - prev.as_mut().next = Some(node.into()); - } - None => { - // We are inserting at the beginning of the list - self.head = Some(node.into()); - } - } - node.next = current; - node.prev = prev; - return; - } - prev = current; - current = current_node.as_ref().next; - } - - // We looped through the whole list and the nodes data is bigger or equal - // than everything we found up to now. - // Insert at the end. Since we checked before that the list isn't empty, - // tail always has a value. - node.prev = self.tail; - node.next = None; - self.tail.as_mut().unwrap().as_mut().next = Some(node.into()); - self.tail = Some(node.into()); - } - - /// Returns the first node in the linked list without removing it from the list - /// The function is only safe as long as valid pointers are stored inside - /// the linked list. - /// The returned pointer is only guaranteed to be valid as long as the list - /// is not mutated - pub fn peek_first(&self) -> Option<&mut ListNode<T>> { - // Safety: When the node was inserted it was promised that it is alive - // until it gets removed from the list. - // The returned node has a pointer which constrains it to the lifetime - // of the list. This is ok, since the Node is supposed to outlive - // its insertion in the list. - unsafe { - self.head - .map(|mut node| &mut *(node.as_mut() as *mut ListNode<T>)) - } - } - - /// Returns the last node in the linked list without removing it from the list - /// The function is only safe as long as valid pointers are stored inside - /// the linked list. - /// The returned pointer is only guaranteed to be valid as long as the list - /// is not mutated - pub fn peek_last(&self) -> Option<&mut ListNode<T>> { - // Safety: When the node was inserted it was promised that it is alive - // until it gets removed from the list. - // The returned node has a pointer which constrains it to the lifetime - // of the list. This is ok, since the Node is supposed to outlive - // its insertion in the list. - unsafe { - self.tail - .map(|mut node| &mut *(node.as_mut() as *mut ListNode<T>)) - } - } - - /// Removes the first node from the linked list - pub fn remove_first(&mut self) -> Option<&mut ListNode<T>> { - #![allow(clippy::debug_assert_with_mut_call)] - - // Safety: When the node was inserted it was promised that it is alive - // until it gets removed from the list - unsafe { - let mut head = self.head?; - self.head = head.as_mut().next; - - let first_ref = head.as_mut(); - match first_ref.next { - None => { - // This was the only node in the list - debug_assert_eq!(Some(first_ref.into()), self.tail); - self.tail = None; - } - Some(mut next) => { - next.as_mut().prev = None; - } - } - - first_ref.prev = None; - first_ref.next = None; - Some(&mut *(first_ref as *mut ListNode<T>)) - } - } - - /// Removes the last node from the linked list and returns it - pub fn remove_last(&mut self) -> Option<&mut ListNode<T>> { - #![allow(clippy::debug_assert_with_mut_call)] - - // Safety: When the node was inserted it was promised that it is alive - // until it gets removed from the list - unsafe { - let mut tail = self.tail?; - self.tail = tail.as_mut().prev; - - let last_ref = tail.as_mut(); - match last_ref.prev { - None => { - // This was the last node in the list - debug_assert_eq!(Some(last_ref.into()), self.head); - self.head = None; - } - Some(mut prev) => { - prev.as_mut().next = None; - } - } - - last_ref.prev = None; - last_ref.next = None; - Some(&mut *(last_ref as *mut ListNode<T>)) - } - } - - /// Returns whether the linked list doesn not contain any node - pub fn is_empty(&self) -> bool { - if self.head.is_some() { - return false; - } - - debug_assert!(self.tail.is_none()); - true - } - - /// Removes the given `node` from the linked list. - /// Returns whether the `node` was removed. - /// It is also only safe if it is known that the `node` is either part of this - /// list, or of no list at all. If `node` is part of another list, the - /// behavior is undefined. - pub unsafe fn remove(&mut self, node: &mut ListNode<T>) -> bool { - #![allow(clippy::debug_assert_with_mut_call)] - - match node.prev { - None => { - // This might be the first node in the list. If it is not, the - // node is not in the list at all. Since our precondition is that - // the node must either be in this list or in no list, we check that - // the node is really in no list. - if self.head != Some(node.into()) { - debug_assert!(node.next.is_none()); - return false; - } - self.head = node.next; - } - Some(mut prev) => { - debug_assert_eq!(prev.as_ref().next, Some(node.into())); - prev.as_mut().next = node.next; - } - } - - match node.next { - None => { - // This must be the last node in our list. Otherwise the list - // is inconsistent. - debug_assert_eq!(self.tail, Some(node.into())); - self.tail = node.prev; - } - Some(mut next) => { - debug_assert_eq!(next.as_mut().prev, Some(node.into())); - next.as_mut().prev = node.prev; - } - } - - node.next = None; - node.prev = None; - - true - } - - /// Drains the list iby calling a callback on each list node - /// - /// The method does not return an iterator since stopping or deferring - /// draining the list is not permitted. If the method would push nodes to - /// an iterator we could not guarantee that the nodes do not get utilized - /// after having been removed from the list anymore. - pub fn drain<F>(&mut self, mut func: F) - where - F: FnMut(&mut ListNode<T>), - { - let mut current = self.head; - self.head = None; - self.tail = None; - - while let Some(mut node) = current { - // Safety: The nodes have not been removed from the list yet and must - // therefore contain valid data. The nodes can also not be added to - // the list again during iteration, since the list is mutably borrowed. - unsafe { - let node_ref = node.as_mut(); - current = node_ref.next; - - node_ref.next = None; - node_ref.prev = None; - - // Note: We do not reset the pointers from the next element in the - // list to the current one since we will iterate over the whole - // list anyway, and therefore clean up all pointers. - - func(node_ref); - } - } - } - - /// Drains the list in reverse order by calling a callback on each list node - /// - /// The method does not return an iterator since stopping or deferring - /// draining the list is not permitted. If the method would push nodes to - /// an iterator we could not guarantee that the nodes do not get utilized - /// after having been removed from the list anymore. - pub fn reverse_drain<F>(&mut self, mut func: F) - where - F: FnMut(&mut ListNode<T>), - { - let mut current = self.tail; - self.head = None; - self.tail = None; - - while let Some(mut node) = current { - // Safety: The nodes have not been removed from the list yet and must - // therefore contain valid data. The nodes can also not be added to - // the list again during iteration, since the list is mutably borrowed. - unsafe { - let node_ref = node.as_mut(); - current = node_ref.prev; - - node_ref.next = None; - node_ref.prev = None; - - // Note: We do not reset the pointers from the next element in the - // list to the current one since we will iterate over the whole - // list anyway, and therefore clean up all pointers. - - func(node_ref); - } - } - } -} - -#[cfg(all(test, feature = "std"))] // Tests make use of Vec at the moment -mod tests { - use super::*; - - fn collect_list<T: Copy>(mut list: LinkedList<T>) -> Vec<T> { - let mut result = Vec::new(); - list.drain(|node| { - result.push(**node); - }); - result - } - - fn collect_reverse_list<T: Copy>(mut list: LinkedList<T>) -> Vec<T> { - let mut result = Vec::new(); - list.reverse_drain(|node| { - result.push(**node); - }); - result - } - - unsafe fn add_nodes(list: &mut LinkedList<i32>, nodes: &mut [&mut ListNode<i32>]) { - for node in nodes.iter_mut() { - list.add_front(node); - } - } - - unsafe fn assert_clean<T>(node: &mut ListNode<T>) { - assert!(node.next.is_none()); - assert!(node.prev.is_none()); - } - - #[test] - fn insert_and_iterate() { - unsafe { - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - - let mut setup = |list: &mut LinkedList<i32>| { - assert_eq!(true, list.is_empty()); - list.add_front(&mut c); - assert_eq!(31, **list.peek_first().unwrap()); - assert_eq!(false, list.is_empty()); - list.add_front(&mut b); - assert_eq!(7, **list.peek_first().unwrap()); - list.add_front(&mut a); - assert_eq!(5, **list.peek_first().unwrap()); - }; - - let mut list = LinkedList::new(); - setup(&mut list); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7, 31].to_vec(), items); - - let mut list = LinkedList::new(); - setup(&mut list); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([31, 7, 5].to_vec(), items); - } - } - - #[test] - fn add_sorted() { - unsafe { - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - let mut d = ListNode::new(99); - - let mut list = LinkedList::new(); - list.add_sorted(&mut a); - let items: Vec<i32> = collect_list(list); - assert_eq!([5].to_vec(), items); - - let mut list = LinkedList::new(); - list.add_sorted(&mut a); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); - list.add_sorted(&mut a); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7, 31, 99].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); - list.add_sorted(&mut a); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([99, 31, 7, 5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); - list.add_sorted(&mut b); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7, 31, 99].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); - list.add_sorted(&mut b); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([99, 31, 7, 5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); - list.add_sorted(&mut c); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7, 31, 99].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); - list.add_sorted(&mut c); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([99, 31, 7, 5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - list.add_sorted(&mut d); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7, 31, 99].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - list.add_sorted(&mut d); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([99, 31, 7, 5].to_vec(), items); - } - } - - #[test] - fn drain_and_collect() { - unsafe { - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - - let taken_items: Vec<i32> = collect_list(list); - assert_eq!([5, 7, 31].to_vec(), taken_items); - } - } - - #[test] - fn peek_last() { - unsafe { - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - - let last = list.peek_last(); - assert_eq!(31, **last.unwrap()); - list.remove_last(); - - let last = list.peek_last(); - assert_eq!(7, **last.unwrap()); - list.remove_last(); - - let last = list.peek_last(); - assert_eq!(5, **last.unwrap()); - list.remove_last(); - - let last = list.peek_last(); - assert!(last.is_none()); - } - } - - #[test] - fn remove_first() { - unsafe { - // We iterate forward and backwards through the manipulated lists - // to make sure pointers in both directions are still ok. - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - let removed = list.remove_first().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_list(list); - assert_eq!([7, 31].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - let removed = list.remove_first().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([31, 7].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - let removed = list.remove_first().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_list(list); - assert_eq!([7].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - let removed = list.remove_first().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([7].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut a]); - let removed = list.remove_first().unwrap(); - assert_clean(removed); - assert!(list.is_empty()); - let items: Vec<i32> = collect_list(list); - assert!(items.is_empty()); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut a]); - let removed = list.remove_first().unwrap(); - assert_clean(removed); - assert!(list.is_empty()); - let items: Vec<i32> = collect_reverse_list(list); - assert!(items.is_empty()); - } - } - - #[test] - fn remove_last() { - unsafe { - // We iterate forward and backwards through the manipulated lists - // to make sure pointers in both directions are still ok. - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - let removed = list.remove_last().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - let removed = list.remove_last().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([7, 5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - let removed = list.remove_last().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_list(list); - assert_eq!([5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - let removed = list.remove_last().unwrap(); - assert_clean(removed); - assert!(!list.is_empty()); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut a]); - let removed = list.remove_last().unwrap(); - assert_clean(removed); - assert!(list.is_empty()); - let items: Vec<i32> = collect_list(list); - assert!(items.is_empty()); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut a]); - let removed = list.remove_last().unwrap(); - assert_clean(removed); - assert!(list.is_empty()); - let items: Vec<i32> = collect_reverse_list(list); - assert!(items.is_empty()); - } - } - - #[test] - fn remove_by_address() { - unsafe { - let mut a = ListNode::new(5); - let mut b = ListNode::new(7); - let mut c = ListNode::new(31); - - { - // Remove first - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - assert_eq!(true, list.remove(&mut a)); - assert_clean((&mut a).into()); - // a should be no longer there and can't be removed twice - assert_eq!(false, list.remove(&mut a)); - assert_eq!(Some((&mut b).into()), list.head); - assert_eq!(Some((&mut c).into()), b.next); - assert_eq!(Some((&mut b).into()), c.prev); - let items: Vec<i32> = collect_list(list); - assert_eq!([7, 31].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - assert_eq!(true, list.remove(&mut a)); - assert_clean((&mut a).into()); - // a should be no longer there and can't be removed twice - assert_eq!(false, list.remove(&mut a)); - assert_eq!(Some((&mut c).into()), b.next); - assert_eq!(Some((&mut b).into()), c.prev); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([31, 7].to_vec(), items); - } - - { - // Remove middle - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - assert_eq!(true, list.remove(&mut b)); - assert_clean((&mut b).into()); - assert_eq!(Some((&mut c).into()), a.next); - assert_eq!(Some((&mut a).into()), c.prev); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 31].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - assert_eq!(true, list.remove(&mut b)); - assert_clean((&mut b).into()); - assert_eq!(Some((&mut c).into()), a.next); - assert_eq!(Some((&mut a).into()), c.prev); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([31, 5].to_vec(), items); - } - - { - // Remove last - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - assert_eq!(true, list.remove(&mut c)); - assert_clean((&mut c).into()); - assert!(b.next.is_none()); - assert_eq!(Some((&mut b).into()), list.tail); - let items: Vec<i32> = collect_list(list); - assert_eq!([5, 7].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); - assert_eq!(true, list.remove(&mut c)); - assert_clean((&mut c).into()); - assert!(b.next.is_none()); - assert_eq!(Some((&mut b).into()), list.tail); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([7, 5].to_vec(), items); - } - - { - // Remove first of two - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - assert_eq!(true, list.remove(&mut a)); - assert_clean((&mut a).into()); - // a should be no longer there and can't be removed twice - assert_eq!(false, list.remove(&mut a)); - assert_eq!(Some((&mut b).into()), list.head); - assert_eq!(Some((&mut b).into()), list.tail); - assert!(b.next.is_none()); - assert!(b.prev.is_none()); - let items: Vec<i32> = collect_list(list); - assert_eq!([7].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - assert_eq!(true, list.remove(&mut a)); - assert_clean((&mut a).into()); - // a should be no longer there and can't be removed twice - assert_eq!(false, list.remove(&mut a)); - assert_eq!(Some((&mut b).into()), list.head); - assert_eq!(Some((&mut b).into()), list.tail); - assert!(b.next.is_none()); - assert!(b.prev.is_none()); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([7].to_vec(), items); - } - - { - // Remove last of two - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - assert_eq!(true, list.remove(&mut b)); - assert_clean((&mut b).into()); - assert_eq!(Some((&mut a).into()), list.head); - assert_eq!(Some((&mut a).into()), list.tail); - assert!(a.next.is_none()); - assert!(a.prev.is_none()); - let items: Vec<i32> = collect_list(list); - assert_eq!([5].to_vec(), items); - - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut b, &mut a]); - assert_eq!(true, list.remove(&mut b)); - assert_clean((&mut b).into()); - assert_eq!(Some((&mut a).into()), list.head); - assert_eq!(Some((&mut a).into()), list.tail); - assert!(a.next.is_none()); - assert!(a.prev.is_none()); - let items: Vec<i32> = collect_reverse_list(list); - assert_eq!([5].to_vec(), items); - } - - { - // Remove last item - let mut list = LinkedList::new(); - add_nodes(&mut list, &mut [&mut a]); - assert_eq!(true, list.remove(&mut a)); - assert_clean((&mut a).into()); - assert!(list.head.is_none()); - assert!(list.tail.is_none()); - let items: Vec<i32> = collect_list(list); - assert!(items.is_empty()); - } - - { - // Remove missing - let mut list = LinkedList::new(); - list.add_front(&mut b); - list.add_front(&mut a); - assert_eq!(false, list.remove(&mut c)); - } - } - } -} diff --git a/src/util/linked_list.rs b/src/util/linked_list.rs index aa3ce77..4681276 100644 --- a/src/util/linked_list.rs +++ b/src/util/linked_list.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "full"), allow(dead_code))] + //! An intrusive double linked list of data //! //! The data structure supports tracking pinned nodes. Most of the data @@ -5,6 +7,7 @@ //! specified node is actually contained by the list. use core::fmt; +use core::marker::PhantomData; use core::mem::ManuallyDrop; use core::ptr::NonNull; @@ -12,16 +15,19 @@ use core::ptr::NonNull; /// /// Currently, the list is not emptied on drop. It is the caller's /// responsibility to ensure the list is empty before dropping it. -pub(crate) struct LinkedList<T: Link> { +pub(crate) struct LinkedList<L, T> { /// Linked list head - head: Option<NonNull<T::Target>>, + head: Option<NonNull<T>>, /// Linked list tail - tail: Option<NonNull<T::Target>>, + tail: Option<NonNull<T>>, + + /// Node type marker. + _marker: PhantomData<*const L>, } -unsafe impl<T: Link> Send for LinkedList<T> where T::Target: Send {} -unsafe impl<T: Link> Sync for LinkedList<T> where T::Target: Sync {} +unsafe impl<L: Link> Send for LinkedList<L, L::Target> where L::Target: Send {} +unsafe impl<L: Link> Sync for LinkedList<L, L::Target> where L::Target: Sync {} /// Defines how a type is tracked within a linked list. /// @@ -66,27 +72,30 @@ unsafe impl<T: Sync> Sync for Pointers<T> {} // ===== impl LinkedList ===== -impl<T: Link> LinkedList<T> { - /// Creates an empty linked list - pub(crate) fn new() -> LinkedList<T> { +impl<L, T> LinkedList<L, T> { + /// Creates an empty linked list. + pub(crate) const fn new() -> LinkedList<L, T> { LinkedList { head: None, tail: None, + _marker: PhantomData, } } +} +impl<L: Link> LinkedList<L, L::Target> { /// Adds an element first in the list. - pub(crate) fn push_front(&mut self, val: T::Handle) { + pub(crate) fn push_front(&mut self, val: L::Handle) { // The value should not be dropped, it is being inserted into the list let val = ManuallyDrop::new(val); - let ptr = T::as_raw(&*val); + let ptr = L::as_raw(&*val); assert_ne!(self.head, Some(ptr)); unsafe { - T::pointers(ptr).as_mut().next = self.head; - T::pointers(ptr).as_mut().prev = None; + L::pointers(ptr).as_mut().next = self.head; + L::pointers(ptr).as_mut().prev = None; if let Some(head) = self.head { - T::pointers(head).as_mut().prev = Some(ptr); + L::pointers(head).as_mut().prev = Some(ptr); } self.head = Some(ptr); @@ -99,21 +108,21 @@ impl<T: Link> LinkedList<T> { /// Removes the last element from a list and returns it, or None if it is /// empty. - pub(crate) fn pop_back(&mut self) -> Option<T::Handle> { + pub(crate) fn pop_back(&mut self) -> Option<L::Handle> { unsafe { let last = self.tail?; - self.tail = T::pointers(last).as_ref().prev; + self.tail = L::pointers(last).as_ref().prev; - if let Some(prev) = T::pointers(last).as_ref().prev { - T::pointers(prev).as_mut().next = None; + if let Some(prev) = L::pointers(last).as_ref().prev { + L::pointers(prev).as_mut().next = None; } else { self.head = None } - T::pointers(last).as_mut().prev = None; - T::pointers(last).as_mut().next = None; + L::pointers(last).as_mut().prev = None; + L::pointers(last).as_mut().next = None; - Some(T::from_raw(last)) + Some(L::from_raw(last)) } } @@ -133,38 +142,38 @@ impl<T: Link> LinkedList<T> { /// /// The caller **must** ensure that `node` is currently contained by /// `self` or not contained by any other list. - pub(crate) unsafe fn remove(&mut self, node: NonNull<T::Target>) -> Option<T::Handle> { - if let Some(prev) = T::pointers(node).as_ref().prev { - debug_assert_eq!(T::pointers(prev).as_ref().next, Some(node)); - T::pointers(prev).as_mut().next = T::pointers(node).as_ref().next; + pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> { + if let Some(prev) = L::pointers(node).as_ref().prev { + debug_assert_eq!(L::pointers(prev).as_ref().next, Some(node)); + L::pointers(prev).as_mut().next = L::pointers(node).as_ref().next; } else { if self.head != Some(node) { return None; } - self.head = T::pointers(node).as_ref().next; + self.head = L::pointers(node).as_ref().next; } - if let Some(next) = T::pointers(node).as_ref().next { - debug_assert_eq!(T::pointers(next).as_ref().prev, Some(node)); - T::pointers(next).as_mut().prev = T::pointers(node).as_ref().prev; + if let Some(next) = L::pointers(node).as_ref().next { + debug_assert_eq!(L::pointers(next).as_ref().prev, Some(node)); + L::pointers(next).as_mut().prev = L::pointers(node).as_ref().prev; } else { // This might be the last item in the list if self.tail != Some(node) { return None; } - self.tail = T::pointers(node).as_ref().prev; + self.tail = L::pointers(node).as_ref().prev; } - T::pointers(node).as_mut().next = None; - T::pointers(node).as_mut().prev = None; + L::pointers(node).as_mut().next = None; + L::pointers(node).as_mut().prev = None; - Some(T::from_raw(node)) + Some(L::from_raw(node)) } } -impl<T: Link> fmt::Debug for LinkedList<T> { +impl<L: Link> fmt::Debug for LinkedList<L, L::Target> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LinkedList") .field("head", &self.head) @@ -173,27 +182,35 @@ impl<T: Link> fmt::Debug for LinkedList<T> { } } -cfg_sync! { - impl<T: Link> LinkedList<T> { - pub(crate) fn last(&self) -> Option<&T::Target> { - let tail = self.tail.as_ref()?; - unsafe { - Some(&*tail.as_ptr()) - } - } +#[cfg(any( + feature = "fs", + all(unix, feature = "process"), + feature = "signal", + feature = "sync", +))] +impl<L: Link> LinkedList<L, L::Target> { + pub(crate) fn last(&self) -> Option<&L::Target> { + let tail = self.tail.as_ref()?; + unsafe { Some(&*tail.as_ptr()) } + } +} + +impl<L: Link> Default for LinkedList<L, L::Target> { + fn default() -> Self { + Self::new() } } // ===== impl Iter ===== -cfg_rt_threaded! { +cfg_rt_multi_thread! { pub(crate) struct Iter<'a, T: Link> { curr: Option<NonNull<T::Target>>, _p: core::marker::PhantomData<&'a T>, } - impl<T: Link> LinkedList<T> { - pub(crate) fn iter(&self) -> Iter<'_, T> { + impl<L: Link> LinkedList<L, L::Target> { + pub(crate) fn iter(&self) -> Iter<'_, L> { Iter { curr: self.head, _p: core::marker::PhantomData, @@ -215,6 +232,52 @@ cfg_rt_threaded! { } } +// ===== impl DrainFilter ===== + +cfg_io_readiness! { + pub(crate) struct DrainFilter<'a, T: Link, F> { + list: &'a mut LinkedList<T, T::Target>, + filter: F, + curr: Option<NonNull<T::Target>>, + } + + impl<T: Link> LinkedList<T, T::Target> { + pub(crate) fn drain_filter<F>(&mut self, filter: F) -> DrainFilter<'_, T, F> + where + F: FnMut(&mut T::Target) -> bool, + { + let curr = self.head; + DrainFilter { + curr, + filter, + list: self, + } + } + } + + impl<'a, T, F> Iterator for DrainFilter<'a, T, F> + where + T: Link, + F: FnMut(&mut T::Target) -> bool, + { + type Item = T::Handle; + + fn next(&mut self) -> Option<Self::Item> { + while let Some(curr) = self.curr { + // safety: the pointer references data contained by the list + self.curr = unsafe { T::pointers(curr).as_ref() }.next; + + // safety: the value is still owned by the linked list. + if (self.filter)(unsafe { &mut *curr.as_ptr() }) { + return unsafe { self.list.remove(curr) }; + } + } + + None + } + } +} + // ===== impl Pointers ===== impl<T> Pointers<T> { @@ -277,7 +340,7 @@ mod tests { r.as_ref().get_ref().into() } - fn collect_list(list: &mut LinkedList<&'_ Entry>) -> Vec<i32> { + fn collect_list(list: &mut LinkedList<&'_ Entry, <&'_ Entry as Link>::Target>) -> Vec<i32> { let mut ret = vec![]; while let Some(entry) = list.pop_back() { @@ -287,7 +350,10 @@ mod tests { ret } - fn push_all<'a>(list: &mut LinkedList<&'a Entry>, entries: &[Pin<&'a Entry>]) { + fn push_all<'a>( + list: &mut LinkedList<&'a Entry, <&'_ Entry as Link>::Target>, + entries: &[Pin<&'a Entry>], + ) { for entry in entries.iter() { list.push_front(*entry); } @@ -308,6 +374,11 @@ mod tests { } #[test] + fn const_new() { + const _: LinkedList<&Entry, <&Entry as Link>::Target> = LinkedList::new(); + } + + #[test] fn push_and_drain() { let a = entry(5); let b = entry(7); @@ -332,7 +403,7 @@ mod tests { let a = entry(5); let b = entry(7); - let mut list = LinkedList::<&Entry>::new(); + let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); list.push_front(a.as_ref()); @@ -489,7 +560,7 @@ mod tests { unsafe { // Remove missing - let mut list = LinkedList::<&Entry>::new(); + let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); list.push_front(b.as_ref()); list.push_front(a.as_ref()); @@ -503,7 +574,7 @@ mod tests { let a = entry(5); let b = entry(7); - let mut list = LinkedList::<&Entry>::new(); + let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); assert_eq!(0, list.iter().count()); @@ -543,7 +614,7 @@ mod tests { }) .collect::<Vec<_>>(); - let mut ll = LinkedList::<&Entry>::new(); + let mut ll = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); let mut reference = VecDeque::new(); let entries: Vec<_> = (0..ops.len()).map(|i| entry(i as i32)).collect(); diff --git a/src/util/mod.rs b/src/util/mod.rs index 6dda08c..b2043dd 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -3,16 +3,26 @@ cfg_io_driver! { pub(crate) mod slab; } -#[cfg(any(feature = "sync", feature = "rt-core"))] +#[cfg(any( + feature = "fs", + feature = "net", + feature = "process", + feature = "rt", + feature = "sync", + feature = "signal", +))] pub(crate) mod linked_list; -#[cfg(any(feature = "rt-threaded", feature = "macros", feature = "stream"))] +#[cfg(any(feature = "rt-multi-thread", feature = "macros", feature = "stream"))] mod rand; -mod wake; -pub(crate) use wake::{waker_ref, Wake}; +cfg_rt! { + mod wake; + pub(crate) use wake::WakerRef; + pub(crate) use wake::{waker_ref, Wake}; +} -cfg_rt_threaded! { +cfg_rt_multi_thread! { pub(crate) use rand::FastRand; mod try_lock; @@ -24,5 +34,3 @@ pub(crate) mod trace; #[cfg(any(feature = "macros", feature = "stream"))] #[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] pub use rand::thread_rng_n; - -pub(crate) mod intrusive_double_linked_list; diff --git a/src/util/slab.rs b/src/util/slab.rs new file mode 100644 index 0000000..efc72e1 --- /dev/null +++ b/src/util/slab.rs @@ -0,0 +1,841 @@ +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::{AtomicBool, AtomicUsize}; +use crate::loom::sync::{Arc, Mutex}; +use crate::util::bit; +use std::fmt; +use std::mem; +use std::ops; +use std::ptr; +use std::sync::atomic::Ordering::Relaxed; + +/// Amortized allocation for homogeneous data types. +/// +/// The slab pre-allocates chunks of memory to store values. It uses a similar +/// growing strategy as `Vec`. When new capacity is needed, the slab grows by +/// 2x. +/// +/// # Pages +/// +/// Unlike `Vec`, growing does not require moving existing elements. Instead of +/// being a continuous chunk of memory for all elements, `Slab` is an array of +/// arrays. The top-level array is an array of pages. Each page is 2x bigger +/// than the previous one. When the slab grows, a new page is allocated. +/// +/// Pages are lazily initialized. +/// +/// # Allocating +/// +/// When allocating an object, first previously used slots are reused. If no +/// previously used slot is available, a new slot is initialized in an existing +/// page. If all pages are full, then a new page is allocated. +/// +/// When an allocated object is released, it is pushed into it's page's free +/// list. Allocating scans all pages for a free slot. +/// +/// # Indexing +/// +/// The slab is able to index values using an address. Even when the indexed +/// object has been released, it is still safe to index. This is a key ability +/// for using the slab with the I/O driver. Addresses are registered with the +/// OS's selector and I/O resources can be released without synchronizing with +/// the OS. +/// +/// # Compaction +/// +/// `Slab::compact` will release pages that have been allocated but are no +/// longer used. This is done by scanning the pages and finding pages with no +/// allocated objects. These pages are then freed. +/// +/// # Synchronization +/// +/// The `Slab` structure is able to provide (mostly) unsynchronized reads to +/// values stored in the slab. Insertions and removals are synchronized. Reading +/// objects via `Ref` is fully unsynchronized. Indexing objects uses amortized +/// synchronization. +/// +pub(crate) struct Slab<T> { + /// Array of pages. Each page is synchronized. + pages: [Arc<Page<T>>; NUM_PAGES], + + /// Caches the array pointer & number of initialized slots. + cached: [CachedPage<T>; NUM_PAGES], +} + +/// Allocate values in the associated slab. +pub(crate) struct Allocator<T> { + /// Pages in the slab. The first page has a capacity of 16 elements. Each + /// following page has double the capacity of the previous page. + /// + /// Each returned `Ref` holds a reference count to this `Arc`. + pages: [Arc<Page<T>>; NUM_PAGES], +} + +/// References a slot in the slab. Indexing a slot using an `Address` is memory +/// safe even if the slot has been released or the page has been deallocated. +/// However, it is not guaranteed that the slot has not been reused and is now +/// represents a different value. +/// +/// The I/O driver uses a counter to track the slot's generation. Once accessing +/// the slot, the generations are compared. If they match, the value matches the +/// address. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct Address(usize); + +/// An entry in the slab. +pub(crate) trait Entry: Default { + /// Reset the entry's value and track the generation. + fn reset(&self); +} + +/// A reference to a value stored in the slab +pub(crate) struct Ref<T> { + value: *const Value<T>, +} + +/// Maximum number of pages a slab can contain. +const NUM_PAGES: usize = 19; + +/// Minimum number of slots a page can contain. +const PAGE_INITIAL_SIZE: usize = 32; +const PAGE_INDEX_SHIFT: u32 = PAGE_INITIAL_SIZE.trailing_zeros() + 1; + +/// A page in the slab +struct Page<T> { + /// Slots + slots: Mutex<Slots<T>>, + + // Number of slots currently being used. This is not guaranteed to be up to + // date and should only be used as a hint. + used: AtomicUsize, + + // Set to `true` when the page has been allocated. + allocated: AtomicBool, + + // The number of slots the page can hold. + len: usize, + + // Length of all previous pages combined + prev_len: usize, +} + +struct CachedPage<T> { + /// Pointer to the page's slots. + slots: *const Slot<T>, + + /// Number of initialized slots. + init: usize, +} + +/// Page state +struct Slots<T> { + /// Slots + slots: Vec<Slot<T>>, + + head: usize, + + /// Number of slots currently in use. + used: usize, +} + +unsafe impl<T: Sync> Sync for Page<T> {} +unsafe impl<T: Sync> Send for Page<T> {} +unsafe impl<T: Sync> Sync for CachedPage<T> {} +unsafe impl<T: Sync> Send for CachedPage<T> {} +unsafe impl<T: Sync> Sync for Ref<T> {} +unsafe impl<T: Sync> Send for Ref<T> {} + +/// A slot in the slab. Contains slot-specific metadata. +/// +/// `#[repr(C)]` guarantees that the struct starts w/ `value`. We use pointer +/// math to map a value pointer to an index in the page. +#[repr(C)] +struct Slot<T> { + /// Pointed to by `Ref`. + value: UnsafeCell<Value<T>>, + + /// Next entry in the free list. + next: u32, +} + +/// Value paired with a reference to the page +struct Value<T> { + /// Value stored in the value + value: T, + + /// Pointer to the page containing the slot. + /// + /// A raw pointer is used as this creates a ref cycle. + page: *const Page<T>, +} + +impl<T> Slab<T> { + /// Create a new, empty, slab + pub(crate) fn new() -> Slab<T> { + // Initializing arrays is a bit annoying. Instead of manually writing + // out an array and every single entry, `Default::default()` is used to + // initialize the array, then the array is iterated and each value is + // initialized. + let mut slab = Slab { + pages: Default::default(), + cached: Default::default(), + }; + + let mut len = PAGE_INITIAL_SIZE; + let mut prev_len: usize = 0; + + for page in &mut slab.pages { + let page = Arc::get_mut(page).unwrap(); + page.len = len; + page.prev_len = prev_len; + len *= 2; + prev_len += page.len; + + // Ensure we don't exceed the max address space. + debug_assert!( + page.len - 1 + page.prev_len < (1 << 24), + "max = {:b}", + page.len - 1 + page.prev_len + ); + } + + slab + } + + /// Returns a new `Allocator`. + /// + /// The `Allocator` supports concurrent allocation of objects. + pub(crate) fn allocator(&self) -> Allocator<T> { + Allocator { + pages: self.pages.clone(), + } + } + + /// Returns a reference to the value stored at the given address. + /// + /// `&mut self` is used as the call may update internal cached state. + pub(crate) fn get(&mut self, addr: Address) -> Option<&T> { + let page_idx = addr.page(); + let slot_idx = self.pages[page_idx].slot(addr); + + // If the address references a slot that was last seen as uninitialized, + // the `CachedPage` is updated. This requires acquiring the page lock + // and updating the slot pointer and initialized offset. + if self.cached[page_idx].init <= slot_idx { + self.cached[page_idx].refresh(&self.pages[page_idx]); + } + + // If the address **still** references an uninitialized slot, then the + // address is invalid and `None` is returned. + if self.cached[page_idx].init <= slot_idx { + return None; + } + + // Get a reference to the value. The lifetime of the returned reference + // is bound to `&self`. The only way to invalidate the underlying memory + // is to call `compact()`. The lifetimes prevent calling `compact()` + // while references to values are outstanding. + // + // The referenced data is never mutated. Only `&self` references are + // used and the data is `Sync`. + Some(self.cached[page_idx].get(slot_idx)) + } + + /// Calls the given function with a reference to each slot in the slab. The + /// slot may not be in-use. + /// + /// This is used by the I/O driver during the shutdown process to notify + /// each pending task. + pub(crate) fn for_each(&mut self, mut f: impl FnMut(&T)) { + for page_idx in 0..self.pages.len() { + // It is required to avoid holding the lock when calling the + // provided function. The function may attempt to acquire the lock + // itself. If we hold the lock here while calling `f`, a deadlock + // situation is possible. + // + // Instead of iterating the slots directly in `page`, which would + // require holding the lock, the cache is updated and the slots are + // iterated from the cache. + self.cached[page_idx].refresh(&self.pages[page_idx]); + + for slot_idx in 0..self.cached[page_idx].init { + f(self.cached[page_idx].get(slot_idx)); + } + } + } + + // Release memory back to the allocator. + // + // If pages are empty, the underlying memory is released back to the + // allocator. + pub(crate) fn compact(&mut self) { + // Iterate each page except the very first one. The very first page is + // never freed. + for (idx, page) in self.pages.iter().enumerate().skip(1) { + if page.used.load(Relaxed) != 0 || !page.allocated.load(Relaxed) { + // If the page has slots in use or the memory has not been + // allocated then it cannot be compacted. + continue; + } + + let mut slots = match page.slots.try_lock() { + Some(slots) => slots, + // If the lock cannot be acquired due to being held by another + // thread, don't try to compact the page. + _ => continue, + }; + + if slots.used > 0 || slots.slots.capacity() == 0 { + // The page is in use or it has not yet been allocated. Either + // way, there is no more work to do. + continue; + } + + page.allocated.store(false, Relaxed); + + // Remove the slots vector from the page. This is done so that the + // freeing process is done outside of the lock's critical section. + let vec = mem::replace(&mut slots.slots, vec![]); + slots.head = 0; + + // Drop the lock so we can drop the vector outside the lock below. + drop(slots); + + debug_assert!( + self.cached[idx].slots.is_null() || self.cached[idx].slots == vec.as_ptr(), + "cached = {:?}; actual = {:?}", + self.cached[idx].slots, + vec.as_ptr(), + ); + + // Clear cache + self.cached[idx].slots = ptr::null(); + self.cached[idx].init = 0; + + drop(vec); + } + } +} + +impl<T> fmt::Debug for Slab<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + debug(fmt, "Slab", &self.pages[..]) + } +} + +impl<T: Entry> Allocator<T> { + /// Allocate a new entry and return a handle to the entry. + /// + /// Scans pages from smallest to biggest, stopping when a slot is found. + /// Pages are allocated if necessary. + /// + /// Returns `None` if the slab is full. + pub(crate) fn allocate(&self) -> Option<(Address, Ref<T>)> { + // Find the first available slot. + for page in &self.pages[..] { + if let Some((addr, val)) = Page::allocate(page) { + return Some((addr, val)); + } + } + + None + } +} + +impl<T> fmt::Debug for Allocator<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + debug(fmt, "slab::Allocator", &self.pages[..]) + } +} + +impl<T> ops::Deref for Ref<T> { + type Target = T; + + fn deref(&self) -> &T { + // Safety: `&mut` is never handed out to the underlying value. The page + // is not freed until all `Ref` values are dropped. + unsafe { &(*self.value).value } + } +} + +impl<T> Drop for Ref<T> { + fn drop(&mut self) { + // Safety: `&mut` is never handed out to the underlying value. The page + // is not freed until all `Ref` values are dropped. + let _ = unsafe { (*self.value).release() }; + } +} + +impl<T: fmt::Debug> fmt::Debug for Ref<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(fmt) + } +} + +impl<T: Entry> Page<T> { + // Allocates an object, returns the ref and address. + // + // `self: &Arc<Page<T>>` is avoided here as this would not work with the + // loom `Arc`. + fn allocate(me: &Arc<Page<T>>) -> Option<(Address, Ref<T>)> { + // Before acquiring the lock, use the `used` hint. + if me.used.load(Relaxed) == me.len { + return None; + } + + // Allocating objects requires synchronization + let mut locked = me.slots.lock(); + + if locked.head < locked.slots.len() { + // Re-use an already initialized slot. + // + // Help out the borrow checker + let locked = &mut *locked; + + // Get the index of the slot at the head of the free stack. This is + // the slot that will be reused. + let idx = locked.head; + let slot = &locked.slots[idx]; + + // Update the free stack head to point to the next slot. + locked.head = slot.next as usize; + + // Increment the number of used slots + locked.used += 1; + me.used.store(locked.used, Relaxed); + + // Reset the slot + slot.value.with(|ptr| unsafe { (*ptr).value.reset() }); + + // Return a reference to the slot + Some((me.addr(idx), slot.gen_ref(me))) + } else if me.len == locked.slots.len() { + // The page is full + None + } else { + // No initialized slots are available, but the page has more + // capacity. Initialize a new slot. + let idx = locked.slots.len(); + + if idx == 0 { + // The page has not yet been allocated. Allocate the storage for + // all page slots. + locked.slots.reserve_exact(me.len); + } + + // Initialize a new slot + locked.slots.push(Slot { + value: UnsafeCell::new(Value { + value: Default::default(), + page: &**me as *const _, + }), + next: 0, + }); + + // Increment the head to indicate the free stack is empty + locked.head += 1; + + // Increment the number of used slots + locked.used += 1; + me.used.store(locked.used, Relaxed); + me.allocated.store(true, Relaxed); + + debug_assert_eq!(locked.slots.len(), locked.head); + + Some((me.addr(idx), locked.slots[idx].gen_ref(me))) + } + } +} + +impl<T> Page<T> { + /// Returns the slot index within the current page referenced by the given + /// address. + fn slot(&self, addr: Address) -> usize { + addr.0 - self.prev_len + } + + /// Returns the address for the given slot + fn addr(&self, slot: usize) -> Address { + Address(slot + self.prev_len) + } +} + +impl<T> Default for Page<T> { + fn default() -> Page<T> { + Page { + used: AtomicUsize::new(0), + allocated: AtomicBool::new(false), + slots: Mutex::new(Slots { + slots: Vec::new(), + head: 0, + used: 0, + }), + len: 0, + prev_len: 0, + } + } +} + +impl<T> Page<T> { + /// Release a slot into the page's free list + fn release(&self, value: *const Value<T>) { + let mut locked = self.slots.lock(); + + let idx = locked.index_for(value); + locked.slots[idx].next = locked.head as u32; + locked.head = idx; + locked.used -= 1; + + self.used.store(locked.used, Relaxed); + } +} + +impl<T> CachedPage<T> { + /// Refresh the cache + fn refresh(&mut self, page: &Page<T>) { + let slots = page.slots.lock(); + + if !slots.slots.is_empty() { + self.slots = slots.slots.as_ptr(); + self.init = slots.slots.len(); + } + } + + // Get a value by index + fn get(&self, idx: usize) -> &T { + assert!(idx < self.init); + + // Safety: Pages are allocated concurrently, but are only ever + // **deallocated** by `Slab`. `Slab` will always have a more + // conservative view on the state of the slot array. Once `CachedPage` + // sees a slot pointer and initialized offset, it will remain valid + // until `compact()` is called. The `compact()` function also updates + // `CachedPage`. + unsafe { + let slot = self.slots.add(idx); + let value = slot as *const Value<T>; + + &(*value).value + } + } +} + +impl<T> Default for CachedPage<T> { + fn default() -> CachedPage<T> { + CachedPage { + slots: ptr::null(), + init: 0, + } + } +} + +impl<T> Slots<T> { + /// Maps a slot pointer to an offset within the current page. + /// + /// The pointer math removes the `usize` index from the `Ref` struct, + /// shrinking the struct to a single pointer size. The contents of the + /// function is safe, the resulting `usize` is bounds checked before being + /// used. + /// + /// # Panics + /// + /// panics if the provided slot pointer is not contained by the page. + fn index_for(&self, slot: *const Value<T>) -> usize { + use std::mem; + + let base = &self.slots[0] as *const _ as usize; + + assert!(base != 0, "page is unallocated"); + + let slot = slot as usize; + let width = mem::size_of::<Slot<T>>(); + + assert!(slot >= base, "unexpected pointer"); + + let idx = (slot - base) / width; + assert!(idx < self.slots.len() as usize); + + idx + } +} + +impl<T: Entry> Slot<T> { + /// Generates a `Ref` for the slot. This involves bumping the page's ref count. + fn gen_ref(&self, page: &Arc<Page<T>>) -> Ref<T> { + // The ref holds a ref on the page. The `Arc` is forgotten here and is + // resurrected in `release` when the `Ref` is dropped. By avoiding to + // hold on to an explicit `Arc` value, the struct size of `Ref` is + // reduced. + mem::forget(page.clone()); + let slot = self as *const Slot<T>; + let value = slot as *const Value<T>; + + Ref { value } + } +} + +impl<T> Value<T> { + // Release the slot, returning the `Arc<Page<T>>` logically owned by the ref. + fn release(&self) -> Arc<Page<T>> { + // Safety: called by `Ref`, which owns an `Arc<Page<T>>` instance. + let page = unsafe { Arc::from_raw(self.page) }; + page.release(self as *const _); + page + } +} + +impl Address { + fn page(self) -> usize { + // Since every page is twice as large as the previous page, and all page + // sizes are powers of two, we can determine the page index that + // contains a given address by shifting the address down by the smallest + // page size and looking at how many twos places necessary to represent + // that number, telling us what power of two page size it fits inside + // of. We can determine the number of twos places by counting the number + // of leading zeros (unused twos places) in the number's binary + // representation, and subtracting that count from the total number of + // bits in a word. + let slot_shifted = (self.0 + PAGE_INITIAL_SIZE) >> PAGE_INDEX_SHIFT; + (bit::pointer_width() - slot_shifted.leading_zeros()) as usize + } + + pub(crate) const fn as_usize(self) -> usize { + self.0 + } + + pub(crate) fn from_usize(src: usize) -> Address { + Address(src) + } +} + +fn debug<T>(fmt: &mut fmt::Formatter<'_>, name: &str, pages: &[Arc<Page<T>>]) -> fmt::Result { + let mut capacity = 0; + let mut len = 0; + + for page in pages { + if page.allocated.load(Relaxed) { + capacity += page.len; + len += page.used.load(Relaxed); + } + } + + fmt.debug_struct(name) + .field("len", &len) + .field("capacity", &capacity) + .finish() +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering::SeqCst; + + struct Foo { + cnt: AtomicUsize, + id: AtomicUsize, + } + + impl Default for Foo { + fn default() -> Foo { + Foo { + cnt: AtomicUsize::new(0), + id: AtomicUsize::new(0), + } + } + } + + impl Entry for Foo { + fn reset(&self) { + self.cnt.fetch_add(1, SeqCst); + } + } + + #[test] + fn insert_remove() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + + let (addr1, foo1) = alloc.allocate().unwrap(); + foo1.id.store(1, SeqCst); + assert_eq!(0, foo1.cnt.load(SeqCst)); + + let (addr2, foo2) = alloc.allocate().unwrap(); + foo2.id.store(2, SeqCst); + assert_eq!(0, foo2.cnt.load(SeqCst)); + + assert_eq!(1, slab.get(addr1).unwrap().id.load(SeqCst)); + assert_eq!(2, slab.get(addr2).unwrap().id.load(SeqCst)); + + drop(foo1); + + assert_eq!(1, slab.get(addr1).unwrap().id.load(SeqCst)); + + let (addr3, foo3) = alloc.allocate().unwrap(); + assert_eq!(addr3, addr1); + assert_eq!(1, foo3.cnt.load(SeqCst)); + foo3.id.store(3, SeqCst); + assert_eq!(3, slab.get(addr3).unwrap().id.load(SeqCst)); + + drop(foo2); + drop(foo3); + + slab.compact(); + + // The first page is never released + assert!(slab.get(addr1).is_some()); + assert!(slab.get(addr2).is_some()); + assert!(slab.get(addr3).is_some()); + } + + #[test] + fn insert_many() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for i in 0..10_000 { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + entries.push((addr, val)); + } + + for (i, (addr, v)) in entries.iter().enumerate() { + assert_eq!(i, v.id.load(SeqCst)); + assert_eq!(i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + + entries.clear(); + + for i in 0..10_000 { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(10_000 - i, SeqCst); + entries.push((addr, val)); + } + + for (i, (addr, v)) in entries.iter().enumerate() { + assert_eq!(10_000 - i, v.id.load(SeqCst)); + assert_eq!(10_000 - i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + + #[test] + fn insert_drop_reverse() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for i in 0..10_000 { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + entries.push((addr, val)); + } + + for _ in 0..10 { + // Drop 1000 in reverse + for _ in 0..1_000 { + entries.pop(); + } + + // Check remaining + for (i, (addr, v)) in entries.iter().enumerate() { + assert_eq!(i, v.id.load(SeqCst)); + assert_eq!(i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + } + + #[test] + fn no_compaction_if_page_still_in_use() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries1 = vec![]; + let mut entries2 = vec![]; + + for i in 0..10_000 { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + + if i % 2 == 0 { + entries1.push((addr, val, i)); + } else { + entries2.push(val); + } + } + + drop(entries2); + + for (addr, _, i) in &entries1 { + assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + + #[test] + fn compact_all() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for _ in 0..2 { + entries.clear(); + + for i in 0..10_000 { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + + entries.push((addr, val)); + } + + let mut addrs = vec![]; + + for (addr, _) in entries.drain(..) { + addrs.push(addr); + } + + slab.compact(); + + // The first page is never freed + for addr in &addrs[PAGE_INITIAL_SIZE..] { + assert!(slab.get(*addr).is_none()); + } + } + } + + #[test] + fn issue_3014() { + let mut slab = Slab::<Foo>::new(); + let alloc = slab.allocator(); + let mut entries = vec![]; + + for _ in 0..5 { + entries.clear(); + + // Allocate a few pages + 1 + for i in 0..(32 + 64 + 128 + 1) { + let (addr, val) = alloc.allocate().unwrap(); + val.id.store(i, SeqCst); + + entries.push((addr, val, i)); + } + + for (addr, val, i) in &entries { + assert_eq!(*i, val.id.load(SeqCst)); + assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + + // Release the last entry + entries.pop(); + + // Compact + slab.compact(); + + // Check all the addresses + + for (addr, val, i) in &entries { + assert_eq!(*i, val.id.load(SeqCst)); + assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst)); + } + } + } +} diff --git a/src/util/slab/addr.rs b/src/util/slab/addr.rs deleted file mode 100644 index c14e32e..0000000 --- a/src/util/slab/addr.rs +++ /dev/null @@ -1,154 +0,0 @@ -//! Tracks the location of an entry in a slab. -//! -//! # Index packing -//! -//! A slab index consists of multiple indices packed into a single `usize` value -//! that correspond to different parts of the slab. -//! -//! The least significant `MAX_PAGES + INITIAL_PAGE_SIZE.trailing_zeros() + 1` -//! bits store the address within a shard, starting at 0 for the first slot on -//! the first page. To index a slot within a shard, we first find the index of -//! the page that the address falls on, and then the offset of the slot within -//! that page. -//! -//! Since every page is twice as large as the previous page, and all page sizes -//! are powers of two, we can determine the page index that contains a given -//! address by shifting the address down by the smallest page size and looking -//! at how many twos places necessary to represent that number, telling us what -//! power of two page size it fits inside of. We can determine the number of -//! twos places by counting the number of leading zeros (unused twos places) in -//! the number's binary representation, and subtracting that count from the -//! total number of bits in a word. -//! -//! Once we know what page contains an address, we can subtract the size of all -//! previous pages from the address to determine the offset within the page. -//! -//! After the page address, the next `MAX_THREADS.trailing_zeros() + 1` least -//! significant bits are the thread ID. These are used to index the array of -//! shards to find which shard a slot belongs to. If an entry is being removed -//! and the thread ID of its index matches that of the current thread, we can -//! use the `remove_local` fast path; otherwise, we have to use the synchronized -//! `remove_remote` path. -//! -//! Finally, a generation value is packed into the index. The `RESERVED_BITS` -//! most significant bits are left unused, and the remaining bits between the -//! last bit of the thread ID and the first reserved bit are used to store the -//! generation. The generation is used as part of an atomic read-modify-write -//! loop every time a `ScheduledIo`'s readiness is modified, or when the -//! resource is removed, to guard against the ABA problem. -//! -//! Visualized: -//! -//! ```text -//! ┌──────────┬───────────────┬──────────────────┬──────────────────────────┐ -//! │ reserved │ generation │ thread ID │ address │ -//! └▲─────────┴▲──────────────┴▲─────────────────┴▲────────────────────────▲┘ -//! │ │ │ │ │ -//! bits(usize) │ bits(MAX_THREADS) │ 0 -//! │ │ -//! bits(usize) - RESERVED MAX_PAGES + bits(INITIAL_PAGE_SIZE) -//! ``` - -use crate::util::bit; -use crate::util::slab::{Generation, INITIAL_PAGE_SIZE, MAX_PAGES, MAX_THREADS}; - -use std::usize; - -/// References the location at which an entry is stored in a slab. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub(crate) struct Address(usize); - -const PAGE_INDEX_SHIFT: u32 = INITIAL_PAGE_SIZE.trailing_zeros() + 1; - -/// Address in the shard -const SLOT: bit::Pack = bit::Pack::least_significant(MAX_PAGES as u32 + PAGE_INDEX_SHIFT); - -/// Masks the thread identifier -const THREAD: bit::Pack = SLOT.then(MAX_THREADS.trailing_zeros() + 1); - -/// Masks the generation -const GENERATION: bit::Pack = THREAD - .then(bit::pointer_width().wrapping_sub(RESERVED.width() + THREAD.width() + SLOT.width())); - -// Chosen arbitrarily -const RESERVED: bit::Pack = bit::Pack::most_significant(5); - -impl Address { - /// Represents no entry, picked to avoid collision with Mio's internals. - /// This value should not be passed to mio. - pub(crate) const NULL: usize = usize::MAX >> 1; - - /// Re-exported by `Generation`. - pub(super) const GENERATION_WIDTH: u32 = GENERATION.width(); - - pub(super) fn new(shard_index: usize, generation: Generation) -> Address { - let mut repr = 0; - - repr = SLOT.pack(shard_index, repr); - repr = GENERATION.pack(generation.to_usize(), repr); - - Address(repr) - } - - /// Convert from a `usize` representation. - pub(crate) fn from_usize(src: usize) -> Address { - assert_ne!(src, Self::NULL); - - Address(src) - } - - /// Convert to a `usize` representation - pub(crate) fn to_usize(self) -> usize { - self.0 - } - - pub(crate) fn generation(self) -> Generation { - Generation::new(GENERATION.unpack(self.0)) - } - - /// Returns the page index - pub(super) fn page(self) -> usize { - // Since every page is twice as large as the previous page, and all page - // sizes are powers of two, we can determine the page index that - // contains a given address by shifting the address down by the smallest - // page size and looking at how many twos places necessary to represent - // that number, telling us what power of two page size it fits inside - // of. We can determine the number of twos places by counting the number - // of leading zeros (unused twos places) in the number's binary - // representation, and subtracting that count from the total number of - // bits in a word. - let slot_shifted = (self.slot() + INITIAL_PAGE_SIZE) >> PAGE_INDEX_SHIFT; - (bit::pointer_width() - slot_shifted.leading_zeros()) as usize - } - - /// Returns the slot index - pub(super) fn slot(self) -> usize { - SLOT.unpack(self.0) - } -} - -#[cfg(test)] -cfg_not_loom! { - use proptest::proptest; - - #[test] - fn test_pack_format() { - assert_eq!(5, RESERVED.width()); - assert_eq!(0b11111, RESERVED.max_value()); - } - - proptest! { - #[test] - fn address_roundtrips( - slot in 0usize..SLOT.max_value(), - generation in 0usize..Generation::MAX, - ) { - let address = Address::new(slot, Generation::new(generation)); - // Round trip - let address = Address::from_usize(address.to_usize()); - - assert_eq!(address.slot(), slot); - assert_eq!(address.generation().to_usize(), generation); - } - } -} diff --git a/src/util/slab/entry.rs b/src/util/slab/entry.rs deleted file mode 100644 index 2e0b10b..0000000 --- a/src/util/slab/entry.rs +++ /dev/null @@ -1,7 +0,0 @@ -use crate::util::slab::Generation; - -pub(crate) trait Entry: Default { - fn generation(&self) -> Generation; - - fn reset(&self, generation: Generation) -> bool; -} diff --git a/src/util/slab/generation.rs b/src/util/slab/generation.rs deleted file mode 100644 index 4b16b2c..0000000 --- a/src/util/slab/generation.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::util::bit; -use crate::util::slab::Address; - -/// An mutation identifier for a slot in the slab. The generation helps prevent -/// accessing an entry with an outdated token. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)] -pub(crate) struct Generation(usize); - -impl Generation { - pub(crate) const WIDTH: u32 = Address::GENERATION_WIDTH; - - pub(super) const MAX: usize = bit::mask_for(Address::GENERATION_WIDTH); - - /// Create a new generation - /// - /// # Panics - /// - /// Panics if `value` is greater than max generation. - pub(crate) fn new(value: usize) -> Generation { - assert!(value <= Self::MAX); - Generation(value) - } - - /// Returns the next generation value - pub(crate) fn next(self) -> Generation { - Generation((self.0 + 1) & Self::MAX) - } - - pub(crate) fn to_usize(self) -> usize { - self.0 - } -} diff --git a/src/util/slab/mod.rs b/src/util/slab/mod.rs deleted file mode 100644 index 5082970..0000000 --- a/src/util/slab/mod.rs +++ /dev/null @@ -1,107 +0,0 @@ -//! A lock-free concurrent slab. - -mod addr; -pub(crate) use addr::Address; - -mod entry; -pub(crate) use entry::Entry; - -mod generation; -pub(crate) use generation::Generation; - -mod page; - -mod shard; -use shard::Shard; - -mod slot; -use slot::Slot; - -mod stack; -use stack::TransferStack; - -#[cfg(all(loom, test))] -mod tests; - -use crate::loom::sync::Mutex; -use crate::util::bit; - -use std::fmt; - -#[cfg(target_pointer_width = "64")] -const MAX_THREADS: usize = 4096; - -#[cfg(target_pointer_width = "32")] -const MAX_THREADS: usize = 2048; - -/// Max number of pages per slab -const MAX_PAGES: usize = bit::pointer_width() as usize / 4; - -cfg_not_loom! { - /// Size of first page - const INITIAL_PAGE_SIZE: usize = 32; -} - -cfg_loom! { - const INITIAL_PAGE_SIZE: usize = 2; -} - -/// A sharded slab. -pub(crate) struct Slab<T> { - // Signal shard for now. Eventually there will be more. - shard: Shard<T>, - local: Mutex<()>, -} - -unsafe impl<T: Send> Send for Slab<T> {} -unsafe impl<T: Sync> Sync for Slab<T> {} - -impl<T: Entry> Slab<T> { - /// Returns a new slab with the default configuration parameters. - pub(crate) fn new() -> Slab<T> { - Slab { - shard: Shard::new(), - local: Mutex::new(()), - } - } - - /// allocs a value into the slab, returning a key that can be used to - /// access it. - /// - /// If this function returns `None`, then the shard for the current thread - /// is full and no items can be added until some are removed, or the maximum - /// number of shards has been reached. - pub(crate) fn alloc(&self) -> Option<Address> { - // we must lock the slab to alloc an item. - let _local = self.local.lock().unwrap(); - self.shard.alloc() - } - - /// Removes the value associated with the given key from the slab. - pub(crate) fn remove(&self, idx: Address) { - // try to lock the slab so that we can use `remove_local`. - let lock = self.local.try_lock(); - - // if we were able to lock the slab, we are "local" and can use the fast - // path; otherwise, we will use `remove_remote`. - if lock.is_ok() { - self.shard.remove_local(idx) - } else { - self.shard.remove_remote(idx) - } - } - - /// Return a reference to the value associated with the given key. - /// - /// If the slab does not contain a value for the given key, `None` is - /// returned instead. - pub(crate) fn get(&self, token: Address) -> Option<&T> { - self.shard.get(token) - } -} - -impl<T> fmt::Debug for Slab<T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Slab").field("shard", &self.shard).finish() - } -} diff --git a/src/util/slab/page.rs b/src/util/slab/page.rs deleted file mode 100644 index 0000e93..0000000 --- a/src/util/slab/page.rs +++ /dev/null @@ -1,187 +0,0 @@ -use crate::loom::cell::UnsafeCell; -use crate::util::slab::{Address, Entry, Slot, TransferStack, INITIAL_PAGE_SIZE}; - -use std::fmt; - -/// Data accessed only by the thread that owns the shard. -pub(crate) struct Local { - head: UnsafeCell<usize>, -} - -/// Data accessed by any thread. -pub(crate) struct Shared<T> { - remote: TransferStack, - size: usize, - prev_sz: usize, - slab: UnsafeCell<Option<Box<[Slot<T>]>>>, -} - -/// Returns the size of the page at index `n` -pub(super) fn size(n: usize) -> usize { - INITIAL_PAGE_SIZE << n -} - -impl Local { - pub(crate) fn new() -> Self { - Self { - head: UnsafeCell::new(0), - } - } - - fn head(&self) -> usize { - self.head.with(|head| unsafe { *head }) - } - - fn set_head(&self, new_head: usize) { - self.head.with_mut(|head| unsafe { - *head = new_head; - }) - } -} - -impl<T: Entry> Shared<T> { - pub(crate) fn new(size: usize, prev_sz: usize) -> Shared<T> { - Self { - prev_sz, - size, - remote: TransferStack::new(), - slab: UnsafeCell::new(None), - } - } - - /// Allocates storage for this page if it does not allready exist. - /// - /// This requires unique access to the page (e.g. it is called from the - /// thread that owns the page, or, in the case of `SingleShard`, while the - /// lock is held). In order to indicate this, a reference to the page's - /// `Local` data is taken by this function; the `Local` argument is not - /// actually used, but requiring it ensures that this is only called when - /// local access is held. - #[cold] - fn alloc_page(&self, _: &Local) { - debug_assert!(self.slab.with(|s| unsafe { (*s).is_none() })); - - let mut slab = Vec::with_capacity(self.size); - slab.extend((1..self.size).map(Slot::new)); - slab.push(Slot::new(Address::NULL)); - - self.slab.with_mut(|s| { - // this mut access is safe — it only occurs to initially - // allocate the page, which only happens on this thread; if the - // page has not yet been allocated, other threads will not try - // to access it yet. - unsafe { - *s = Some(slab.into_boxed_slice()); - } - }); - } - - pub(crate) fn alloc(&self, local: &Local) -> Option<Address> { - let head = local.head(); - - // are there any items on the local free list? (fast path) - let head = if head < self.size { - head - } else { - // if the local free list is empty, pop all the items on the remote - // free list onto the local free list. - self.remote.pop_all()? - }; - - // if the head is still null, both the local and remote free lists are - // empty --- we can't fit any more items on this page. - if head == Address::NULL { - return None; - } - - // do we need to allocate storage for this page? - let page_needs_alloc = self.slab.with(|s| unsafe { (*s).is_none() }); - if page_needs_alloc { - self.alloc_page(local); - } - - let gen = self.slab.with(|slab| { - let slab = unsafe { &*(slab) } - .as_ref() - .expect("page must have been allocated to alloc!"); - - let slot = &slab[head]; - - local.set_head(slot.next()); - slot.generation() - }); - - let index = head + self.prev_sz; - - Some(Address::new(index, gen)) - } - - pub(crate) fn get(&self, addr: Address) -> Option<&T> { - let page_offset = addr.slot() - self.prev_sz; - - self.slab - .with(|slab| unsafe { &*slab }.as_ref()?.get(page_offset)) - .map(|slot| slot.get()) - } - - pub(crate) fn remove_local(&self, local: &Local, addr: Address) { - let offset = addr.slot() - self.prev_sz; - - self.slab.with(|slab| { - let slab = unsafe { &*slab }.as_ref(); - - let slot = if let Some(slot) = slab.and_then(|slab| slab.get(offset)) { - slot - } else { - return; - }; - - if slot.reset(addr.generation()) { - slot.set_next(local.head()); - local.set_head(offset); - } - }) - } - - pub(crate) fn remove_remote(&self, addr: Address) { - let offset = addr.slot() - self.prev_sz; - - self.slab.with(|slab| { - let slab = unsafe { &*slab }.as_ref(); - - let slot = if let Some(slot) = slab.and_then(|slab| slab.get(offset)) { - slot - } else { - return; - }; - - if !slot.reset(addr.generation()) { - return; - } - - self.remote.push(offset, |next| slot.set_next(next)); - }) - } -} - -impl fmt::Debug for Local { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.head.with(|head| { - let head = unsafe { *head }; - f.debug_struct("Local") - .field("head", &format_args!("{:#0x}", head)) - .finish() - }) - } -} - -impl<T> fmt::Debug for Shared<T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Shared") - .field("remote", &self.remote) - .field("prev_sz", &self.prev_sz) - .field("size", &self.size) - // .field("slab", &self.slab) - .finish() - } -} diff --git a/src/util/slab/shard.rs b/src/util/slab/shard.rs deleted file mode 100644 index eaca6f6..0000000 --- a/src/util/slab/shard.rs +++ /dev/null @@ -1,105 +0,0 @@ -use crate::util::slab::{page, Address, Entry, MAX_PAGES}; - -use std::fmt; - -// ┌─────────────┐ ┌────────┐ -// │ page 1 │ │ │ -// ├─────────────┤ ┌───▶│ next──┼─┐ -// │ page 2 │ │ ├────────┤ │ -// │ │ │ │XXXXXXXX│ │ -// │ local_free──┼─┘ ├────────┤ │ -// │ global_free─┼─┐ │ │◀┘ -// ├─────────────┤ └───▶│ next──┼─┐ -// │ page 3 │ ├────────┤ │ -// └─────────────┘ │XXXXXXXX│ │ -// ... ├────────┤ │ -// ┌─────────────┐ │XXXXXXXX│ │ -// │ page n │ ├────────┤ │ -// └─────────────┘ │ │◀┘ -// │ next──┼───▶ -// ├────────┤ -// │XXXXXXXX│ -// └────────┘ -// ... -pub(super) struct Shard<T> { - /// The local free list for each page. - /// - /// These are only ever accessed from this shard's thread, so they are - /// stored separately from the shared state for the page that can be - /// accessed concurrently, to minimize false sharing. - local: Box<[page::Local]>, - /// The shared state for each page in this shard. - /// - /// This consists of the page's metadata (size, previous size), remote free - /// list, and a pointer to the actual array backing that page. - shared: Box<[page::Shared<T>]>, -} - -impl<T: Entry> Shard<T> { - pub(super) fn new() -> Shard<T> { - let mut total_sz = 0; - let shared = (0..MAX_PAGES) - .map(|page_num| { - let sz = page::size(page_num); - let prev_sz = total_sz; - total_sz += sz; - page::Shared::new(sz, prev_sz) - }) - .collect(); - - let local = (0..MAX_PAGES).map(|_| page::Local::new()).collect(); - - Shard { local, shared } - } - - pub(super) fn alloc(&self) -> Option<Address> { - // Can we fit the value into an existing page? - for (page_idx, page) in self.shared.iter().enumerate() { - let local = self.local(page_idx); - - if let Some(page_offset) = page.alloc(local) { - return Some(page_offset); - } - } - - None - } - - pub(super) fn get(&self, addr: Address) -> Option<&T> { - let page_idx = addr.page(); - - if page_idx > self.shared.len() { - return None; - } - - self.shared[page_idx].get(addr) - } - - /// Remove an item on the shard's local thread. - pub(super) fn remove_local(&self, addr: Address) { - let page_idx = addr.page(); - - if let Some(page) = self.shared.get(page_idx) { - page.remove_local(self.local(page_idx), addr); - } - } - - /// Remove an item, while on a different thread from the shard's local thread. - pub(super) fn remove_remote(&self, addr: Address) { - if let Some(page) = self.shared.get(addr.page()) { - page.remove_remote(addr); - } - } - - fn local(&self, i: usize) -> &page::Local { - &self.local[i] - } -} - -impl<T> fmt::Debug for Shard<T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Shard") - .field("shared", &self.shared) - .finish() - } -} diff --git a/src/util/slab/slot.rs b/src/util/slab/slot.rs deleted file mode 100644 index 0608b26..0000000 --- a/src/util/slab/slot.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::loom::cell::UnsafeCell; -use crate::util::slab::{Entry, Generation}; - -/// Stores an entry in the slab. -pub(super) struct Slot<T> { - next: UnsafeCell<usize>, - entry: T, -} - -impl<T: Entry> Slot<T> { - /// Initialize a new `Slot` linked to `next`. - /// - /// The entry is initialized to a default value. - pub(super) fn new(next: usize) -> Slot<T> { - Slot { - next: UnsafeCell::new(next), - entry: T::default(), - } - } - - pub(super) fn get(&self) -> &T { - &self.entry - } - - pub(super) fn generation(&self) -> Generation { - self.entry.generation() - } - - pub(super) fn reset(&self, generation: Generation) -> bool { - self.entry.reset(generation) - } - - pub(super) fn next(&self) -> usize { - self.next.with(|next| unsafe { *next }) - } - - pub(super) fn set_next(&self, next: usize) { - self.next.with_mut(|n| unsafe { - (*n) = next; - }) - } -} diff --git a/src/util/slab/stack.rs b/src/util/slab/stack.rs deleted file mode 100644 index 0ae0d71..0000000 --- a/src/util/slab/stack.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::loom::sync::atomic::AtomicUsize; -use crate::util::slab::Address; - -use std::fmt; -use std::sync::atomic::Ordering; -use std::usize; - -pub(super) struct TransferStack { - head: AtomicUsize, -} - -impl TransferStack { - pub(super) fn new() -> Self { - Self { - head: AtomicUsize::new(Address::NULL), - } - } - - pub(super) fn pop_all(&self) -> Option<usize> { - let val = self.head.swap(Address::NULL, Ordering::Acquire); - - if val == Address::NULL { - None - } else { - Some(val) - } - } - - pub(super) fn push(&self, value: usize, before: impl Fn(usize)) { - let mut next = self.head.load(Ordering::Relaxed); - - loop { - before(next); - - match self - .head - .compare_exchange(next, value, Ordering::AcqRel, Ordering::Acquire) - { - // lost the race! - Err(actual) => next = actual, - Ok(_) => return, - } - } - } -} - -impl fmt::Debug for TransferStack { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Loom likes to dump all its internal state in `fmt::Debug` impls, so - // we override this to just print the current value in tests. - f.debug_struct("TransferStack") - .field( - "head", - &format_args!("{:#x}", self.head.load(Ordering::Relaxed)), - ) - .finish() - } -} diff --git a/src/util/slab/tests/loom_slab.rs b/src/util/slab/tests/loom_slab.rs deleted file mode 100644 index 48e94f0..0000000 --- a/src/util/slab/tests/loom_slab.rs +++ /dev/null @@ -1,327 +0,0 @@ -use crate::io::driver::ScheduledIo; -use crate::util::slab::{Address, Slab}; - -use loom::sync::{Arc, Condvar, Mutex}; -use loom::thread; - -#[test] -fn local_remove() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - - let s = slab.clone(); - let t1 = thread::spawn(move || { - let idx = store_val(&s, 1); - assert_eq!(get_val(&s, idx), Some(1)); - s.remove(idx); - assert_eq!(get_val(&s, idx), None); - let idx = store_val(&s, 2); - assert_eq!(get_val(&s, idx), Some(2)); - s.remove(idx); - assert_eq!(get_val(&s, idx), None); - }); - - let s = slab.clone(); - let t2 = thread::spawn(move || { - let idx = store_val(&s, 3); - assert_eq!(get_val(&s, idx), Some(3)); - s.remove(idx); - assert_eq!(get_val(&s, idx), None); - let idx = store_val(&s, 4); - s.remove(idx); - assert_eq!(get_val(&s, idx), None); - }); - - let s = slab; - let idx1 = store_val(&s, 5); - assert_eq!(get_val(&s, idx1), Some(5)); - let idx2 = store_val(&s, 6); - assert_eq!(get_val(&s, idx2), Some(6)); - s.remove(idx1); - assert_eq!(get_val(&s, idx1), None); - assert_eq!(get_val(&s, idx2), Some(6)); - s.remove(idx2); - assert_eq!(get_val(&s, idx2), None); - - t1.join().expect("thread 1 should not panic"); - t2.join().expect("thread 2 should not panic"); - }); -} - -#[test] -fn remove_remote() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - - let idx1 = store_val(&slab, 1); - assert_eq!(get_val(&slab, idx1), Some(1)); - - let idx2 = store_val(&slab, 2); - assert_eq!(get_val(&slab, idx2), Some(2)); - - let idx3 = store_val(&slab, 3); - assert_eq!(get_val(&slab, idx3), Some(3)); - - let s = slab.clone(); - let t1 = thread::spawn(move || { - assert_eq!(get_val(&s, idx2), Some(2)); - s.remove(idx2); - assert_eq!(get_val(&s, idx2), None); - }); - - let s = slab.clone(); - let t2 = thread::spawn(move || { - assert_eq!(get_val(&s, idx3), Some(3)); - s.remove(idx3); - assert_eq!(get_val(&s, idx3), None); - }); - - t1.join().expect("thread 1 should not panic"); - t2.join().expect("thread 2 should not panic"); - - assert_eq!(get_val(&slab, idx1), Some(1)); - assert_eq!(get_val(&slab, idx2), None); - assert_eq!(get_val(&slab, idx3), None); - }); -} - -#[test] -fn remove_remote_and_reuse() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - - let idx1 = store_val(&slab, 1); - let idx2 = store_val(&slab, 2); - - assert_eq!(get_val(&slab, idx1), Some(1)); - assert_eq!(get_val(&slab, idx2), Some(2)); - - let s = slab.clone(); - let t1 = thread::spawn(move || { - s.remove(idx1); - let value = get_val(&s, idx1); - - // We may or may not see the new value yet, depending on when - // this occurs, but we must either see the new value or `None`; - // the old value has been removed! - assert!(value == None || value == Some(3)); - }); - - let idx3 = store_when_free(&slab, 3); - t1.join().expect("thread 1 should not panic"); - - assert_eq!(get_val(&slab, idx3), Some(3)); - assert_eq!(get_val(&slab, idx2), Some(2)); - }); -} - -#[test] -fn concurrent_alloc_remove() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - let pair = Arc::new((Mutex::new(None), Condvar::new())); - - let slab2 = slab.clone(); - let pair2 = pair.clone(); - let remover = thread::spawn(move || { - let (lock, cvar) = &*pair2; - for _ in 0..2 { - let mut next = lock.lock().unwrap(); - while next.is_none() { - next = cvar.wait(next).unwrap(); - } - let key = next.take().unwrap(); - slab2.remove(key); - assert_eq!(get_val(&slab2, key), None); - cvar.notify_one(); - } - }); - - let (lock, cvar) = &*pair; - for i in 0..2 { - let key = store_val(&slab, i); - - let mut next = lock.lock().unwrap(); - *next = Some(key); - cvar.notify_one(); - - // Wait for the item to be removed. - while next.is_some() { - next = cvar.wait(next).unwrap(); - } - - assert_eq!(get_val(&slab, key), None); - } - - remover.join().unwrap(); - }) -} - -#[test] -fn concurrent_remove_remote_and_reuse() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - - let idx1 = store_val(&slab, 1); - let idx2 = store_val(&slab, 2); - - assert_eq!(get_val(&slab, idx1), Some(1)); - assert_eq!(get_val(&slab, idx2), Some(2)); - - let s = slab.clone(); - let s2 = slab.clone(); - let t1 = thread::spawn(move || { - s.remove(idx1); - }); - - let t2 = thread::spawn(move || { - s2.remove(idx2); - }); - - let idx3 = store_when_free(&slab, 3); - t1.join().expect("thread 1 should not panic"); - t2.join().expect("thread 1 should not panic"); - - assert!(get_val(&slab, idx1).is_none()); - assert!(get_val(&slab, idx2).is_none()); - assert_eq!(get_val(&slab, idx3), Some(3)); - }); -} - -#[test] -fn alloc_remove_get() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - let pair = Arc::new((Mutex::new(None), Condvar::new())); - - let slab2 = slab.clone(); - let pair2 = pair.clone(); - let t1 = thread::spawn(move || { - let slab = slab2; - let (lock, cvar) = &*pair2; - // allocate one entry just so that we have to use the final one for - // all future allocations. - let _key0 = store_val(&slab, 0); - let key = store_val(&slab, 1); - - let mut next = lock.lock().unwrap(); - *next = Some(key); - cvar.notify_one(); - // remove the second entry - slab.remove(key); - // store a new readiness at the same location (since the slab - // already has an entry in slot 0) - store_val(&slab, 2); - }); - - let (lock, cvar) = &*pair; - // wait for the second entry to be stored... - let mut next = lock.lock().unwrap(); - while next.is_none() { - next = cvar.wait(next).unwrap(); - } - let key = next.unwrap(); - - // our generation will be stale when the second store occurs at that - // index, we must not see the value of that store. - let val = get_val(&slab, key); - assert_ne!(val, Some(2), "generation must have advanced!"); - - t1.join().unwrap(); - }) -} - -#[test] -fn alloc_remove_set() { - loom::model(|| { - let slab = Arc::new(Slab::new()); - let pair = Arc::new((Mutex::new(None), Condvar::new())); - - let slab2 = slab.clone(); - let pair2 = pair.clone(); - let t1 = thread::spawn(move || { - let slab = slab2; - let (lock, cvar) = &*pair2; - // allocate one entry just so that we have to use the final one for - // all future allocations. - let _key0 = store_val(&slab, 0); - let key = store_val(&slab, 1); - - let mut next = lock.lock().unwrap(); - *next = Some(key); - cvar.notify_one(); - - slab.remove(key); - // remove the old entry and insert a new one, with a new generation. - let key2 = slab.alloc().expect("store key 2"); - // after the remove, we must not see the value written with the - // stale index. - assert_eq!( - get_val(&slab, key), - None, - "stale set must no longer be visible" - ); - assert_eq!(get_val(&slab, key2), Some(0)); - key2 - }); - - let (lock, cvar) = &*pair; - - // wait for the second entry to be stored. the index we get from the - // other thread may become stale after a write. - let mut next = lock.lock().unwrap(); - while next.is_none() { - next = cvar.wait(next).unwrap(); - } - let key = next.unwrap(); - - // try to write to the index with our generation - slab.get(key).map(|val| val.set_readiness(key, |_| 2)); - - let key2 = t1.join().unwrap(); - // after the remove, we must not see the value written with the - // stale index either. - assert_eq!( - get_val(&slab, key), - None, - "stale set must no longer be visible" - ); - assert_eq!(get_val(&slab, key2), Some(0)); - }); -} - -fn get_val(slab: &Arc<Slab<ScheduledIo>>, address: Address) -> Option<usize> { - slab.get(address).and_then(|s| s.get_readiness(address)) -} - -fn store_val(slab: &Arc<Slab<ScheduledIo>>, readiness: usize) -> Address { - let key = slab.alloc().expect("allocate slot"); - - if let Some(slot) = slab.get(key) { - slot.set_readiness(key, |_| readiness) - .expect("generation should still be valid!"); - } else { - panic!("slab did not contain a value for {:?}", key); - } - - key -} - -fn store_when_free(slab: &Arc<Slab<ScheduledIo>>, readiness: usize) -> Address { - let key = loop { - if let Some(key) = slab.alloc() { - break key; - } - - thread::yield_now(); - }; - - if let Some(slot) = slab.get(key) { - slot.set_readiness(key, |_| readiness) - .expect("generation should still be valid!"); - } else { - panic!("slab did not contain a value for {:?}", key); - } - - key -} diff --git a/src/util/slab/tests/loom_stack.rs b/src/util/slab/tests/loom_stack.rs deleted file mode 100644 index 47ad46d..0000000 --- a/src/util/slab/tests/loom_stack.rs +++ /dev/null @@ -1,88 +0,0 @@ -use crate::util::slab::TransferStack; - -use loom::cell::UnsafeCell; -use loom::sync::Arc; -use loom::thread; - -#[test] -fn transfer_stack() { - loom::model(|| { - let causalities = [UnsafeCell::new(None), UnsafeCell::new(None)]; - let shared = Arc::new((causalities, TransferStack::new())); - let shared1 = shared.clone(); - let shared2 = shared.clone(); - - // Spawn two threads that both try to push to the stack. - let t1 = thread::spawn(move || { - let (causalities, stack) = &*shared1; - stack.push(0, |prev| { - causalities[0].with_mut(|c| unsafe { - *c = Some(prev); - }); - }); - }); - - let t2 = thread::spawn(move || { - let (causalities, stack) = &*shared2; - stack.push(1, |prev| { - causalities[1].with_mut(|c| unsafe { - *c = Some(prev); - }); - }); - }); - - let (causalities, stack) = &*shared; - - // Try to pop from the stack... - let mut idx = stack.pop_all(); - while idx == None { - idx = stack.pop_all(); - thread::yield_now(); - } - let idx = idx.unwrap(); - - let saw_both = causalities[idx].with(|val| { - let val = unsafe { *val }; - assert!( - val.is_some(), - "UnsafeCell write must happen-before index is pushed to the stack!", - ); - // were there two entries in the stack? if so, check that - // both saw a write. - if let Some(c) = causalities.get(val.unwrap()) { - c.with(|val| { - let val = unsafe { *val }; - assert!( - val.is_some(), - "UnsafeCell write must happen-before index is pushed to the stack!", - ); - }); - true - } else { - false - } - }); - - // We only saw one push. Ensure that the other push happens too. - if !saw_both { - // Try to pop from the stack... - let mut idx = stack.pop_all(); - while idx == None { - idx = stack.pop_all(); - thread::yield_now(); - } - let idx = idx.unwrap(); - - causalities[idx].with(|val| { - let val = unsafe { *val }; - assert!( - val.is_some(), - "UnsafeCell write must happen-before index is pushed to the stack!", - ); - }); - } - - t1.join().unwrap(); - t2.join().unwrap(); - }); -} diff --git a/src/util/slab/tests/mod.rs b/src/util/slab/tests/mod.rs deleted file mode 100644 index 7f79354..0000000 --- a/src/util/slab/tests/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod loom_slab; -mod loom_stack; diff --git a/src/util/trace.rs b/src/util/trace.rs index d8c6120..18956a3 100644 --- a/src/util/trace.rs +++ b/src/util/trace.rs @@ -1,5 +1,5 @@ cfg_trace! { - cfg_rt_core! { + cfg_rt! { use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -47,7 +47,7 @@ cfg_trace! { } cfg_not_trace! { - cfg_rt_core! { + cfg_rt! { #[inline] pub(crate) fn task<F>(task: F, _: &'static str) -> F { // nop |