diff options
author | Joel Galenson <jgalenson@google.com> | 2021-08-17 08:33:38 -0700 |
---|---|---|
committer | Joel Galenson <jgalenson@google.com> | 2021-08-17 08:40:48 -0700 |
commit | 642961436a727d51930e5839e3dbfee04ba4af95 (patch) | |
tree | 9da006d6d1c0e4667e8d848673b13cc7d2bb62ca | |
parent | 1c33108b3901dd464f81acf08b5268ec294b3876 (diff) | |
download | tokio-642961436a727d51930e5839e3dbfee04ba4af95.tar.gz |
Upgrade rust/crates/tokio to 1.10.0
Test: make
Change-Id: I4ec984178af20297aae0ed51f0b1c6410876a51b
153 files changed, 8579 insertions, 3065 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index a170659..179bec2 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,5 +1,5 @@ { "git": { - "sha1": "7601dc6d2a5c2902e78e220f44646960f910f38f" + "sha1": "c0974bad94a06aaf04b62ba1397a8cfe5eb2fcb6" } } @@ -22,6 +22,8 @@ rust_library { name: "libtokio", host_supported: true, crate_name: "tokio", + cargo_env_compat: true, + cargo_pkg_version: "1.10.0", srcs: ["src/lib.rs"], edition: "2018", features: [ diff --git a/CHANGELOG.md b/CHANGELOG.md index 19d59cc..55b120f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,155 @@ +# 1.10.0 (August 12, 2021) + +### Added + + - io: add `(read|write)_f(32|64)[_le]` methods ([#4022]) + - io: add `fill_buf` and `consume` to `AsyncBufReadExt` ([#3991]) + - process: add `Child::raw_handle()` on windows ([#3998]) + +### Fixed + + - doc: fix non-doc builds with `--cfg docsrs` ([#4020]) + - io: flush eagerly in `io::copy` ([#4001]) + - runtime: a debug assert was sometimes triggered during shutdown ([#4005]) + - sync: use `spin_loop_hint` instead of `yield_now` in mpsc ([#4037]) + - tokio: the test-util feature depends on rt, sync, and time ([#4036]) + +### Changes + + - runtime: reorganize parts of the runtime ([#3979], [#4005]) + - signal: make windows docs for signal module show up on unix builds ([#3770]) + - task: quickly send task to heap on debug mode ([#4009]) + +### Documented + + - io: document cancellation safety of `AsyncBufReadExt` ([#3997]) + - sync: document when `watch::send` fails ([#4021]) + +[#3770]: https://github.com/tokio-rs/tokio/pull/3770 +[#3979]: https://github.com/tokio-rs/tokio/pull/3979 +[#3991]: https://github.com/tokio-rs/tokio/pull/3991 +[#3997]: https://github.com/tokio-rs/tokio/pull/3997 +[#3998]: https://github.com/tokio-rs/tokio/pull/3998 +[#4001]: https://github.com/tokio-rs/tokio/pull/4001 +[#4005]: https://github.com/tokio-rs/tokio/pull/4005 +[#4009]: https://github.com/tokio-rs/tokio/pull/4009 +[#4020]: https://github.com/tokio-rs/tokio/pull/4020 +[#4021]: https://github.com/tokio-rs/tokio/pull/4021 +[#4022]: https://github.com/tokio-rs/tokio/pull/4022 +[#4036]: https://github.com/tokio-rs/tokio/pull/4036 +[#4037]: https://github.com/tokio-rs/tokio/pull/4037 + +# 1.9.0 (July 22, 2021) + +### Added + + - net: allow customized I/O operations for `TcpStream` ([#3888]) + - sync: add getter for the mutex from a guard ([#3928]) + - task: expose nameable future for `TaskLocal::scope` ([#3273]) + +### Fixed + + - Fix leak if output of future panics on drop ([#3967]) + - Fix leak in `LocalSet` ([#3978]) + +### Changes + + - runtime: reorganize parts of the runtime ([#3909], [#3939], [#3950], [#3955], [#3980]) + - sync: clean up `OnceCell` ([#3945]) + - task: remove mutex in `JoinError` ([#3959]) + +[#3273]: https://github.com/tokio-rs/tokio/pull/3273 +[#3888]: https://github.com/tokio-rs/tokio/pull/3888 +[#3909]: https://github.com/tokio-rs/tokio/pull/3909 +[#3928]: https://github.com/tokio-rs/tokio/pull/3928 +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 +[#3939]: https://github.com/tokio-rs/tokio/pull/3939 +[#3945]: https://github.com/tokio-rs/tokio/pull/3945 +[#3950]: https://github.com/tokio-rs/tokio/pull/3950 +[#3955]: https://github.com/tokio-rs/tokio/pull/3955 +[#3959]: https://github.com/tokio-rs/tokio/pull/3959 +[#3967]: https://github.com/tokio-rs/tokio/pull/3967 +[#3978]: https://github.com/tokio-rs/tokio/pull/3978 +[#3980]: https://github.com/tokio-rs/tokio/pull/3980 + +# 1.8.3 (July 26, 2021) + +This release backports two fixes from 1.9.0 + +### Fixed + + - Fix leak if output of future panics on drop ([#3967]) + - Fix leak in `LocalSet` ([#3978]) + +[#3967]: https://github.com/tokio-rs/tokio/pull/3967 +[#3978]: https://github.com/tokio-rs/tokio/pull/3978 + +# 1.8.2 (July 19, 2021) + +Fixes a missed edge case from 1.8.1. + +### Fixed + +- runtime: drop canceled future on next poll (#3965) + +# 1.8.1 (July 6, 2021) + +Forward ports 1.5.1 fixes. + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + +# 1.8.0 (July 2, 2021) + +### Added + +- io: add `get_{ref,mut}` methods to `AsyncFdReadyGuard` and `AsyncFdReadyMutGuard` ([#3807]) +- io: efficient implementation of vectored writes for `BufWriter` ([#3163]) +- net: add ready/try methods to `NamedPipe{Client,Server}` ([#3866], [#3899]) +- sync: add `watch::Receiver::borrow_and_update` ([#3813]) +- sync: implement `From<T>` for `OnceCell<T>` ([#3877]) +- time: allow users to specify Interval behaviour when delayed ([#3721]) + +### Added (unstable) + +- rt: add `tokio::task::Builder` ([#3881]) + +### Fixed + +- net: handle HUP event with `UnixStream` ([#3898]) + +### Documented + +- doc: document cancellation safety ([#3900]) +- time: add wait alias to sleep ([#3897]) +- time: document auto-advancing behaviour of runtime ([#3763]) + +[#3163]: https://github.com/tokio-rs/tokio/pull/3163 +[#3721]: https://github.com/tokio-rs/tokio/pull/3721 +[#3763]: https://github.com/tokio-rs/tokio/pull/3763 +[#3807]: https://github.com/tokio-rs/tokio/pull/3807 +[#3813]: https://github.com/tokio-rs/tokio/pull/3813 +[#3866]: https://github.com/tokio-rs/tokio/pull/3866 +[#3877]: https://github.com/tokio-rs/tokio/pull/3877 +[#3881]: https://github.com/tokio-rs/tokio/pull/3881 +[#3897]: https://github.com/tokio-rs/tokio/pull/3897 +[#3898]: https://github.com/tokio-rs/tokio/pull/3898 +[#3899]: https://github.com/tokio-rs/tokio/pull/3899 +[#3900]: https://github.com/tokio-rs/tokio/pull/3900 + +# 1.7.2 (July 6, 2021) + +Forward ports 1.5.1 fixes. + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + # 1.7.1 (June 18, 2021) ### Fixed @@ -40,6 +192,16 @@ [#3840]: https://github.com/tokio-rs/tokio/pull/3840 [#3850]: https://github.com/tokio-rs/tokio/pull/3850 +# 1.6.3 (July 6, 2021) + +Forward ports 1.5.1 fixes. + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + # 1.6.2 (June 14, 2021) ### Fixes @@ -102,6 +264,14 @@ a kernel bug. ([#3803]) [#3775]: https://github.com/tokio-rs/tokio/pull/3775 [#3780]: https://github.com/tokio-rs/tokio/pull/3780 +# 1.5.1 (July 6, 2021) + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + # 1.5.0 (April 12, 2021) ### Added @@ -13,11 +13,11 @@ [package] edition = "2018" name = "tokio" -version = "1.7.1" +version = "1.10.0" authors = ["Tokio Contributors <team@tokio.rs>"] description = "An event-driven, non-blocking I/O platform for writing asynchronous I/O\nbacked applications.\n" homepage = "https://tokio.rs" -documentation = "https://docs.rs/tokio/1.7.1/tokio/" +documentation = "https://docs.rs/tokio/1.10.0/tokio/" readme = "README.md" keywords = ["io", "async", "non-blocking", "futures"] categories = ["asynchronous", "network-programming"] @@ -66,6 +66,9 @@ version = "0.3" version = "0.3.0" features = ["async-await"] +[dev-dependencies.mockall] +version = "0.10.2" + [dev-dependencies.proptest] version = "1" @@ -99,7 +102,7 @@ rt = [] rt-multi-thread = ["num_cpus", "rt"] signal = ["once_cell", "libc", "mio/os-poll", "mio/uds", "mio/os-util", "signal-hook-registry", "winapi/consoleapi"] sync = [] -test-util = [] +test-util = ["rt", "sync", "time"] time = [] [target."cfg(loom)".dev-dependencies.loom] version = "0.5" @@ -120,7 +123,7 @@ optional = true version = "0.2.42" [target."cfg(unix)".dev-dependencies.nix] -version = "0.19.0" +version = "0.22.0" [target."cfg(windows)".dependencies.winapi] version = "0.3.8" optional = true diff --git a/Cargo.toml.orig b/Cargo.toml.orig index fb9f546..da8e4b6 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -7,12 +7,12 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v1.0.x" git tag. -version = "1.7.1" +version = "1.10.0" edition = "2018" authors = ["Tokio Contributors <team@tokio.rs>"] license = "MIT" readme = "README.md" -documentation = "https://docs.rs/tokio/1.7.1/tokio/" +documentation = "https://docs.rs/tokio/1.10.0/tokio/" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" description = """ @@ -82,7 +82,7 @@ signal = [ "winapi/consoleapi", ] sync = [] -test-util = [] +test-util = ["rt", "sync", "time"] time = [] [dependencies] @@ -109,7 +109,7 @@ signal-hook-registry = { version = "1.1.1", optional = true } [target.'cfg(unix)'.dev-dependencies] libc = { version = "0.2.42" } -nix = { version = "0.19.0" } +nix = { version = "0.22.0" } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.8" @@ -123,6 +123,7 @@ version = "0.3.6" tokio-test = { version = "0.4.0", path = "../tokio-test" } tokio-stream = { version = "0.1", path = "../tokio-stream" } futures = { version = "0.3.0", features = ["async-await"] } +mockall = "0.10.2" proptest = "1" rand = "0.8.0" tempfile = "3.1.0" @@ -7,13 +7,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/tokio/tokio-1.7.1.crate" + value: "https://static.crates.io/crates/tokio/tokio-1.10.0.crate" } - version: "1.7.1" + version: "1.10.0" license_type: NOTICE last_upgrade_date { year: 2021 - month: 6 - day: 22 + month: 8 + day: 17 } } @@ -50,7 +50,15 @@ an asynchronous application. ## Example -A basic TCP echo server with Tokio: +A basic TCP echo server with Tokio. + +Make sure you activated the full features of the tokio crate on Cargo.toml: + +```toml +[dependencies] +tokio = { version = "1.10.0", features = ["full"] } +``` +Then, on your main.rs: ```rust,no_run use tokio::net::TcpListener; @@ -58,7 +66,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { - let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + let listener = TcpListener::bind("127.0.0.1:8080").await?; loop { let (mut socket, _) = listener.accept().await?; @@ -132,7 +140,7 @@ several other libraries, including: * [`tower`]: A library of modular and reusable components for building robust networking clients and servers. -* [`tracing`]: A framework for application-level tracing and async-aware diagnostics. +* [`tracing`] (formerly `tokio-trace`): A framework for application-level tracing and async-aware diagnostics. * [`rdbc`]: A Rust database connectivity library for MySQL, Postgres and SQLite. @@ -155,9 +163,35 @@ several other libraries, including: ## Supported Rust Versions -Tokio is built against the latest stable release. The minimum supported version is 1.45. -The current Tokio version is not guaranteed to build on Rust versions earlier than the -minimum supported version. +Tokio is built against the latest stable release. The minimum supported version +is 1.45. The current Tokio version is not guaranteed to build on Rust versions +earlier than the minimum supported version. + +## Release schedule + +Tokio doesn't follow a fixed release schedule, but we typically make one to two +new minor releases each month. We make patch releases for bugfixes as necessary. + +## Bug patching policy + +For the purposes of making patch releases with bugfixes, we have designated +certain minor releases as LTS (long term support) releases. Whenever a bug +warrants a patch release with a fix for the bug, it will be backported and +released as a new patch release for each LTS minor version. Our current LTS +releases are: + + * `1.8.x` - LTS release until February 2022. + +Each LTS release will continue to receive backported fixes for at least half a +year. If you wish to use a fixed minor release in your project, we recommend +that you use an LTS release. + +To use a fixed minor version, you can specify the version with a tilde. For +example, to specify that you wish to use the newest `1.8.x` patch release, you +can use the following dependency specification: +```text +tokio = { version = "~1.8", features = [...] } +``` ## License diff --git a/docs/reactor-refactor.md b/docs/reactor-refactor.md index a0b5447..1c9ace1 100644 --- a/docs/reactor-refactor.md +++ b/docs/reactor-refactor.md @@ -228,7 +228,7 @@ It is only possible to implement `AsyncRead` and `AsyncWrite` for resource types themselves and not for `&Resource`. Implementing the traits for `&Resource` would permit concurrent operations to the resource. Because only a single waker is stored per direction, any concurrent usage would result in deadlocks. An -alterate implementation would call for a `Vec<Waker>` but this would result in +alternate implementation would call for a `Vec<Waker>` but this would result in memory leaks. ## Enabling reads and writes for `&TcpStream` @@ -268,9 +268,9 @@ select! { } ``` -It is also possible to sotre a `TcpStream` in an `Arc`. +It is also possible to store a `TcpStream` in an `Arc`. ```rust let arc_stream = Arc::new(my_tcp_stream); let n = arc_stream.by_ref().read(buf).await?; -```
\ No newline at end of file +``` diff --git a/patches/task_abort.patch b/patches/task_abort.patch new file mode 100644 index 0000000..df05ccb --- /dev/null +++ b/patches/task_abort.patch @@ -0,0 +1,20 @@ +diff --git a/tests/task_abort.rs b/tests/task_abort.rs +index cdaa405..ec0eed7 100644 +--- a/tests/task_abort.rs ++++ b/tests/task_abort.rs +@@ -180,6 +180,7 @@ fn test_abort_wakes_task_3964() { + /// Checks that aborting a task whose destructor panics does not allow the + /// panic to escape the task. + #[test] ++#[cfg(not(target_os = "android"))] + fn test_abort_task_that_panics_on_drop_contained() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + +@@ -204,6 +205,7 @@ fn test_abort_task_that_panics_on_drop_contained() { + + /// Checks that aborting a task whose destructor panics has the expected result. + #[test] ++#[cfg(not(target_os = "android"))] + fn test_abort_task_that_panics_on_drop_returned() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + diff --git a/src/fs/file.rs b/src/fs/file.rs index 5c06e73..5286e6c 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -3,7 +3,7 @@ //! [`File`]: File use self::State::*; -use crate::fs::{asyncify, sys}; +use crate::fs::asyncify; use crate::io::blocking::Buf; use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use crate::sync::Mutex; @@ -19,6 +19,19 @@ use std::task::Context; use std::task::Poll; use std::task::Poll::*; +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(test)] +use super::mocks::MockFile as StdFile; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; +#[cfg(not(test))] +use std::fs::File as StdFile; + /// A reference to an open file on the filesystem. /// /// This is a specialized version of [`std::fs::File`][std] for usage from the @@ -78,7 +91,7 @@ use std::task::Poll::*; /// # } /// ``` pub struct File { - std: Arc<sys::File>, + std: Arc<StdFile>, inner: Mutex<Inner>, } @@ -96,7 +109,7 @@ struct Inner { #[derive(Debug)] enum State { Idle(Option<Buf>), - Busy(sys::Blocking<(Operation, Buf)>), + Busy(JoinHandle<(Operation, Buf)>), } #[derive(Debug)] @@ -142,7 +155,7 @@ impl File { /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn open(path: impl AsRef<Path>) -> io::Result<File> { let path = path.as_ref().to_owned(); - let std = asyncify(|| sys::File::open(path)).await?; + let std = asyncify(|| StdFile::open(path)).await?; Ok(File::from_std(std)) } @@ -182,7 +195,7 @@ impl File { /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn create(path: impl AsRef<Path>) -> io::Result<File> { let path = path.as_ref().to_owned(); - let std_file = asyncify(move || sys::File::create(path)).await?; + let std_file = asyncify(move || StdFile::create(path)).await?; Ok(File::from_std(std_file)) } @@ -199,7 +212,7 @@ impl File { /// let std_file = std::fs::File::open("foo.txt").unwrap(); /// let file = tokio::fs::File::from_std(std_file); /// ``` - pub fn from_std(std: sys::File) -> File { + pub fn from_std(std: StdFile) -> File { File { std: Arc::new(std), inner: Mutex::new(Inner { @@ -323,7 +336,7 @@ impl File { let std = self.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| std.set_len(size)) } else { @@ -409,7 +422,7 @@ impl File { /// # Ok(()) /// # } /// ``` - pub async fn into_std(mut self) -> sys::File { + pub async fn into_std(mut self) -> StdFile { self.inner.get_mut().complete_inflight().await; Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed") } @@ -434,7 +447,7 @@ impl File { /// # Ok(()) /// # } /// ``` - pub fn try_into_std(mut self) -> Result<sys::File, Self> { + pub fn try_into_std(mut self) -> Result<StdFile, Self> { match Arc::try_unwrap(self.std) { Ok(file) => Ok(file), Err(std_file_arc) => { @@ -502,7 +515,7 @@ impl AsyncRead for File { buf.ensure_capacity_for(dst); let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = buf.read_from(&mut &*std); (Operation::Read(res), buf) })); @@ -569,7 +582,7 @@ impl AsyncSeek for File { let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = (&*std).seek(pos); (Operation::Seek(res), buf) })); @@ -636,7 +649,7 @@ impl AsyncWrite for File { let n = buf.copy_from(src); let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) } else { @@ -685,8 +698,8 @@ impl AsyncWrite for File { } } -impl From<sys::File> for File { - fn from(std: sys::File) -> Self { +impl From<StdFile> for File { + fn from(std: StdFile) -> Self { Self::from_std(std) } } @@ -709,7 +722,7 @@ impl std::os::unix::io::AsRawFd for File { #[cfg(unix)] impl std::os::unix::io::FromRawFd for File { unsafe fn from_raw_fd(fd: std::os::unix::io::RawFd) -> Self { - sys::File::from_raw_fd(fd).into() + StdFile::from_raw_fd(fd).into() } } @@ -723,7 +736,7 @@ impl std::os::windows::io::AsRawHandle for File { #[cfg(windows)] impl std::os::windows::io::FromRawHandle for File { unsafe fn from_raw_handle(handle: std::os::windows::io::RawHandle) -> Self { - sys::File::from_raw_handle(handle).into() + StdFile::from_raw_handle(handle).into() } } @@ -756,3 +769,6 @@ impl Inner { } } } + +#[cfg(test)] +mod tests; diff --git a/tests/fs_file_mocked.rs b/src/fs/file/tests.rs index 7771532..28b5ffe 100644 --- a/tests/fs_file_mocked.rs +++ b/src/fs/file/tests.rs @@ -1,80 +1,21 @@ -#![warn(rust_2018_idioms)] -#![cfg(feature = "full")] - -macro_rules! ready { - ($e:expr $(,)?) => { - match $e { - std::task::Poll::Ready(t) => t, - std::task::Poll::Pending => return std::task::Poll::Pending, - } - }; -} - -#[macro_export] -macro_rules! cfg_fs { - ($($item:item)*) => { $($item)* } -} - -#[macro_export] -macro_rules! cfg_io_std { - ($($item:item)*) => { $($item)* } -} - -use futures::future; - -// Load source -#[allow(warnings)] -#[path = "../src/fs/file.rs"] -mod file; -use file::File; - -#[allow(warnings)] -#[path = "../src/io/blocking.rs"] -mod blocking; - -// Load mocked types -mod support { - pub(crate) mod mock_file; - pub(crate) mod mock_pool; -} -pub(crate) use support::mock_pool as pool; - -// Place them where the source expects them -pub(crate) mod io { - pub(crate) use tokio::io::*; - - pub(crate) use crate::blocking; - - pub(crate) mod sys { - pub(crate) use crate::support::mock_pool::{run, Blocking}; - } -} -pub(crate) mod fs { - pub(crate) mod sys { - pub(crate) use crate::support::mock_file::File; - pub(crate) use crate::support::mock_pool::{run, Blocking}; - } - - pub(crate) use crate::support::mock_pool::asyncify; -} -pub(crate) mod sync { - pub(crate) use tokio::sync::Mutex; -} -use fs::sys; - -use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; -use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; - -use std::io::SeekFrom; +use super::*; +use crate::{ + fs::mocks::*, + io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; +use mockall::{predicate::eq, Sequence}; +use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task}; const HELLO: &[u8] = b"hello world..."; const FOO: &[u8] = b"foo bar baz..."; #[test] fn open_read() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); - + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); let mut buf = [0; 1024]; @@ -83,12 +24,10 @@ fn open_read() { assert_eq!(0, pool::len()); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); assert_eq!(1, pool::len()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); let n = assert_ready_ok!(t.poll()); @@ -98,9 +37,11 @@ fn open_read() { #[test] fn read_twice_before_dispatch() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); - + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); let mut buf = [0; 1024]; @@ -120,8 +61,11 @@ fn read_twice_before_dispatch() { #[test] fn read_with_smaller_buf() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -153,8 +97,22 @@ fn read_with_smaller_buf() { #[test] fn read_with_bigger_buf() { - let (mock, file) = sys::File::mock(); - mock.read(&HELLO[..4]).read(&HELLO[4..]); + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..4].copy_from_slice(&HELLO[..4]); + Ok(4) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len() - 4].copy_from_slice(&HELLO[4..]); + Ok(HELLO.len() - 4) + }); let mut file = File::from_std(file); @@ -194,8 +152,19 @@ fn read_with_bigger_buf() { #[test] fn read_err_then_read_success() { - let (mock, file) = sys::File::mock(); - mock.read_err().read(&HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -225,8 +194,11 @@ fn read_err_then_read_success() { #[test] fn open_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO); + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -235,12 +207,10 @@ fn open_write() { assert_eq!(0, pool::len()); assert_ready_ok!(t.poll()); - assert_eq!(1, mock.remaining()); assert_eq!(1, pool::len()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(!t.is_woken()); let mut t = task::spawn(file.flush()); @@ -249,7 +219,7 @@ fn open_write() { #[test] fn flush_while_idle() { - let (_mock, file) = sys::File::mock(); + let file = MockFile::default(); let mut file = File::from_std(file); @@ -271,13 +241,42 @@ fn read_with_buffer_larger_than_max() { for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } - - let (mock, file) = sys::File::mock(); - mock.read(&data[0..chunk_a]) - .read(&data[chunk_a..chunk_b]) - .read(&data[chunk_b..chunk_c]) - .read(&data[chunk_c..]); - + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[0..chunk_a].copy_from_slice(&d0[0..chunk_a]); + Ok(chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d1[chunk_a..chunk_b]); + Ok(chunk_b - chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d2[chunk_b..chunk_c]); + Ok(chunk_c - chunk_b) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a - 1].copy_from_slice(&d3[chunk_c..]); + Ok(chunk_a - 1) + }); let mut file = File::from_std(file); let mut actual = vec![0; chunk_d]; @@ -296,8 +295,7 @@ fn read_with_buffer_larger_than_max() { pos += n; } - assert_eq!(mock.remaining(), 0); - assert_eq!(data, &actual[..data.len()]); + assert_eq!(&data[..], &actual[..data.len()]); } #[test] @@ -314,12 +312,34 @@ fn write_with_buffer_larger_than_max() { for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } - - let (mock, file) = sys::File::mock(); - mock.write(&data[0..chunk_a]) - .write(&data[chunk_a..chunk_b]) - .write(&data[chunk_b..chunk_c]) - .write(&data[chunk_c..]); + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d0[0..chunk_a]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d1[chunk_a..chunk_b]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d2[chunk_b..chunk_c]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d3[chunk_c..chunk_d - 1]) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -344,14 +364,22 @@ fn write_with_buffer_larger_than_max() { } pool::run_one(); - - assert_eq!(mock.remaining(), 0); } #[test] fn write_twice_before_dispatch() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -380,10 +408,24 @@ fn write_twice_before_dispatch() { #[test] fn incomplete_read_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-(HELLO.len() as i64), 0) - .write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -406,8 +448,25 @@ fn incomplete_read_followed_by_write() { #[test] fn incomplete_partial_read_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO).seek_current_ok(-10, 0).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-10))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -433,10 +492,25 @@ fn incomplete_partial_read_followed_by_write() { #[test] fn incomplete_read_followed_by_flush() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-(HELLO.len() as i64), 0) - .write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -458,8 +532,18 @@ fn incomplete_read_followed_by_flush() { #[test] fn incomplete_flush_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -484,8 +568,10 @@ fn incomplete_flush_followed_by_write() { #[test] fn read_err() { - let (mock, file) = sys::File::mock(); - mock.read_err(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); @@ -502,8 +588,10 @@ fn read_err() { #[test] fn write_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err(); + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); @@ -518,8 +606,19 @@ fn write_write_err() { #[test] fn write_read_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().read(HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -541,8 +640,19 @@ fn write_read_write_err() { #[test] fn write_read_flush_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().read(HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -564,8 +674,17 @@ fn write_read_flush_err() { #[test] fn write_seek_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().seek_start_ok(0); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); let mut file = File::from_std(file); @@ -587,8 +706,17 @@ fn write_seek_write_err() { #[test] fn write_seek_flush_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().seek_start_ok(0); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); let mut file = File::from_std(file); @@ -610,8 +738,14 @@ fn write_seek_flush_err() { #[test] fn sync_all_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_all(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all().once().returning(|| Ok(())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -635,8 +769,16 @@ fn sync_all_ordered_after_write() { #[test] fn sync_all_err_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_all_err(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -660,8 +802,14 @@ fn sync_all_err_ordered_after_write() { #[test] fn sync_data_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_data(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data().once().returning(|| Ok(())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -685,8 +833,16 @@ fn sync_data_ordered_after_write() { #[test] fn sync_data_err_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_data_err(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -710,17 +866,15 @@ fn sync_data_err_ordered_after_write() { #[test] fn open_set_len_ok() { - let (mock, file) = sys::File::mock(); - mock.set_len(123); + let mut file = MockFile::default(); + file.expect_set_len().with(eq(123)).returning(|_| Ok(())); let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); assert_ready_ok!(t.poll()); @@ -728,17 +882,17 @@ fn open_set_len_ok() { #[test] fn open_set_len_err() { - let (mock, file) = sys::File::mock(); - mock.set_len_err(123); + let mut file = MockFile::default(); + file.expect_set_len() + .with(eq(123)) + .returning(|_| Err(io::ErrorKind::Other.into())); let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); assert_ready_err!(t.poll()); @@ -746,11 +900,32 @@ fn open_set_len_err() { #[test] fn partial_read_set_len_ok() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-14, 0) - .set_len(123) - .read(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_set_len() + .once() + .in_sequence(&mut seq) + .with(eq(123)) + .returning(|_| Ok(())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..FOO.len()].copy_from_slice(FOO); + Ok(FOO.len()) + }); let mut buf = [0; 32]; let mut file = File::from_std(file); diff --git a/src/fs/mocks.rs b/src/fs/mocks.rs new file mode 100644 index 0000000..68ef4f3 --- /dev/null +++ b/src/fs/mocks.rs @@ -0,0 +1,136 @@ +//! Mock version of std::fs::File; +use mockall::mock; + +use crate::sync::oneshot; +use std::{ + cell::RefCell, + collections::VecDeque, + fs::{Metadata, Permissions}, + future::Future, + io::{self, Read, Seek, SeekFrom, Write}, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; + +mock! { + #[derive(Debug)] + pub File { + pub fn create(pb: PathBuf) -> io::Result<Self>; + // These inner_ methods exist because std::fs::File has two + // implementations for each of these methods: one on "&mut self" and + // one on "&&self". Defining both of those in terms of an inner_ method + // allows us to specify the expectation the same way, regardless of + // which method is used. + pub fn inner_flush(&self) -> io::Result<()>; + pub fn inner_read(&self, dst: &mut [u8]) -> io::Result<usize>; + pub fn inner_seek(&self, pos: SeekFrom) -> io::Result<u64>; + pub fn inner_write(&self, src: &[u8]) -> io::Result<usize>; + pub fn metadata(&self) -> io::Result<Metadata>; + pub fn open(pb: PathBuf) -> io::Result<Self>; + pub fn set_len(&self, size: u64) -> io::Result<()>; + pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>; + pub fn sync_all(&self) -> io::Result<()>; + pub fn sync_data(&self) -> io::Result<()>; + pub fn try_clone(&self) -> io::Result<Self>; + } + #[cfg(windows)] + impl std::os::windows::io::AsRawHandle for File { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle; + } + #[cfg(windows)] + impl std::os::windows::io::FromRawHandle for File { + unsafe fn from_raw_handle(h: std::os::windows::io::RawHandle) -> Self; + } + #[cfg(unix)] + impl std::os::unix::io::AsRawFd for File { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd; + } + + #[cfg(unix)] + impl std::os::unix::io::FromRawFd for File { + unsafe fn from_raw_fd(h: std::os::unix::io::RawFd) -> Self; + } +} + +impl Read for MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner_read(dst) + } +} + +impl Read for &'_ MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner_read(dst) + } +} + +impl Seek for &'_ MockFile { + fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> { + self.inner_seek(pos) + } +} + +impl Write for &'_ MockFile { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + self.inner_write(src) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner_flush() + } +} + +thread_local! { + static QUEUE: RefCell<VecDeque<Box<dyn FnOnce() + Send>>> = RefCell::new(VecDeque::new()) +} + +#[derive(Debug)] +pub(super) struct JoinHandle<T> { + rx: oneshot::Receiver<T>, +} + +pub(super) fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let task = Box::new(move || { + let _ = tx.send(f()); + }); + + QUEUE.with(|cell| cell.borrow_mut().push_back(task)); + + JoinHandle { rx } +} + +impl<T> Future for JoinHandle<T> { + type Output = Result<T, io::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + use std::task::Poll::*; + + match Pin::new(&mut self.rx).poll(cx) { + Ready(Ok(v)) => Ready(Ok(v)), + Ready(Err(e)) => panic!("error = {:?}", e), + Pending => Pending, + } + } +} + +pub(super) mod pool { + use super::*; + + pub(in super::super) fn len() -> usize { + QUEUE.with(|cell| cell.borrow().len()) + } + + pub(in super::super) fn run_one() { + let task = QUEUE + .with(|cell| cell.borrow_mut().pop_front()) + .expect("expected task to run, but none ready"); + + task(); + } +} diff --git a/src/fs/mod.rs b/src/fs/mod.rs index d4f0074..ca0264b 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -84,6 +84,9 @@ pub use self::write::write; mod copy; pub use self::copy::copy; +#[cfg(test)] +mod mocks; + feature! { #![unix] @@ -103,12 +106,17 @@ feature! { use std::io; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(test)] +use mocks::spawn_blocking; + pub(crate) async fn asyncify<F, T>(f: F) -> io::Result<T> where F: FnOnce() -> io::Result<T> + Send + 'static, T: Send + 'static, { - match sys::run(f).await { + match spawn_blocking(f).await { Ok(res) => res, Err(_) => Err(io::Error::new( io::ErrorKind::Other, @@ -116,12 +124,3 @@ where )), } } - -/// Types in this module can be mocked out in tests. -mod sys { - pub(crate) use std::fs::File; - - // TODO: don't rename - pub(crate) use crate::blocking::spawn_blocking as run; - pub(crate) use crate::blocking::JoinHandle as Blocking; -} diff --git a/src/fs/open_options.rs b/src/fs/open_options.rs index fa37a60..3e73529 100644 --- a/src/fs/open_options.rs +++ b/src/fs/open_options.rs @@ -3,6 +3,13 @@ use crate::fs::{asyncify, File}; use std::io; use std::path::Path; +#[cfg(test)] +mod mock_open_options; +#[cfg(test)] +use mock_open_options::MockOpenOptions as StdOpenOptions; +#[cfg(not(test))] +use std::fs::OpenOptions as StdOpenOptions; + /// Options and flags which can be used to configure how a file is opened. /// /// This builder exposes the ability to configure how a [`File`] is opened and @@ -69,7 +76,7 @@ use std::path::Path; /// } /// ``` #[derive(Clone, Debug)] -pub struct OpenOptions(std::fs::OpenOptions); +pub struct OpenOptions(StdOpenOptions); impl OpenOptions { /// Creates a blank new set of options ready for configuration. @@ -89,7 +96,7 @@ impl OpenOptions { /// let future = options.read(true).open("foo.txt"); /// ``` pub fn new() -> OpenOptions { - OpenOptions(std::fs::OpenOptions::new()) + OpenOptions(StdOpenOptions::new()) } /// Sets the option for read access. @@ -384,7 +391,7 @@ impl OpenOptions { } /// Returns a mutable reference to the underlying `std::fs::OpenOptions` - pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions { + pub(super) fn as_inner_mut(&mut self) -> &mut StdOpenOptions { &mut self.0 } } @@ -645,8 +652,8 @@ feature! { } } -impl From<std::fs::OpenOptions> for OpenOptions { - fn from(options: std::fs::OpenOptions) -> OpenOptions { +impl From<StdOpenOptions> for OpenOptions { + fn from(options: StdOpenOptions) -> OpenOptions { OpenOptions(options) } } diff --git a/src/fs/open_options/mock_open_options.rs b/src/fs/open_options/mock_open_options.rs new file mode 100644 index 0000000..cbbda0e --- /dev/null +++ b/src/fs/open_options/mock_open_options.rs @@ -0,0 +1,38 @@ +//! Mock version of std::fs::OpenOptions; +use mockall::mock; + +use crate::fs::mocks::MockFile; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(windows)] +use std::os::windows::fs::OpenOptionsExt; +use std::{io, path::Path}; + +mock! { + #[derive(Debug)] + pub OpenOptions { + pub fn append(&mut self, append: bool) -> &mut Self; + pub fn create(&mut self, create: bool) -> &mut Self; + pub fn create_new(&mut self, create_new: bool) -> &mut Self; + pub fn open<P: AsRef<Path> + 'static>(&self, path: P) -> io::Result<MockFile>; + pub fn read(&mut self, read: bool) -> &mut Self; + pub fn truncate(&mut self, truncate: bool) -> &mut Self; + pub fn write(&mut self, write: bool) -> &mut Self; + } + impl Clone for OpenOptions { + fn clone(&self) -> Self; + } + #[cfg(unix)] + impl OpenOptionsExt for OpenOptions { + fn custom_flags(&mut self, flags: i32) -> &mut Self; + fn mode(&mut self, mode: u32) -> &mut Self; + } + #[cfg(windows)] + impl OpenOptionsExt for OpenOptions { + fn access_mode(&mut self, access: u32) -> &mut Self; + fn share_mode(&mut self, val: u32) -> &mut Self; + fn custom_flags(&mut self, flags: u32) -> &mut Self; + fn attributes(&mut self, val: u32) -> &mut Self; + fn security_qos_flags(&mut self, flags: u32) -> &mut Self; + } +} diff --git a/src/fs/read.rs b/src/fs/read.rs index 2d80eb5..ada5ba3 100644 --- a/src/fs/read.rs +++ b/src/fs/read.rs @@ -13,8 +13,12 @@ use std::{io, path::Path}; /// buffer based on the file size when available, so it is generally faster than /// reading into a vector created with `Vec::new()`. /// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// /// [`File::open`]: super::File::open /// [`read_to_end`]: crate::io::AsyncReadExt::read_to_end +/// [`spawn_blocking`]: crate::task::spawn_blocking /// /// # Errors /// diff --git a/src/fs/read_dir.rs b/src/fs/read_dir.rs index aedaf7b..514d59c 100644 --- a/src/fs/read_dir.rs +++ b/src/fs/read_dir.rs @@ -1,4 +1,4 @@ -use crate::fs::{asyncify, sys}; +use crate::fs::asyncify; use std::ffi::OsString; use std::fs::{FileType, Metadata}; @@ -10,9 +10,23 @@ use std::sync::Arc; use std::task::Context; use std::task::Poll; +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; + /// Returns a stream over the entries within a directory. /// /// This is an async version of [`std::fs::read_dir`](std::fs::read_dir) +/// +/// This operation is implemented by running the equivalent blocking +/// operation on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking pub async fn read_dir(path: impl AsRef<Path>) -> io::Result<ReadDir> { let path = path.as_ref().to_owned(); let std = asyncify(|| std::fs::read_dir(path)).await?; @@ -45,11 +59,15 @@ pub struct ReadDir(State); #[derive(Debug)] enum State { Idle(Option<std::fs::ReadDir>), - Pending(sys::Blocking<(Option<io::Result<std::fs::DirEntry>>, std::fs::ReadDir)>), + Pending(JoinHandle<(Option<io::Result<std::fs::DirEntry>>, std::fs::ReadDir)>), } impl ReadDir { /// Returns the next entry in the directory stream. + /// + /// # Cancel safety + /// + /// This method is cancellation safe. pub async fn next_entry(&mut self) -> io::Result<Option<DirEntry>> { use crate::future::poll_fn; poll_fn(|cx| self.poll_next_entry(cx)).await @@ -79,7 +97,7 @@ impl ReadDir { State::Idle(ref mut std) => { let mut std = std.take().unwrap(); - self.0 = State::Pending(sys::run(move || { + self.0 = State::Pending(spawn_blocking(move || { let ret = std.next(); (ret, std) })); diff --git a/src/fs/read_to_string.rs b/src/fs/read_to_string.rs index 4f37986..26228d9 100644 --- a/src/fs/read_to_string.rs +++ b/src/fs/read_to_string.rs @@ -7,6 +7,10 @@ use std::{io, path::Path}; /// /// This is the async equivalent of [`std::fs::read_to_string`][std]. /// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking /// [std]: fn@std::fs::read_to_string /// /// # Examples diff --git a/src/fs/write.rs b/src/fs/write.rs index 0ed9082..28606fb 100644 --- a/src/fs/write.rs +++ b/src/fs/write.rs @@ -7,6 +7,10 @@ use std::{io, path::Path}; /// /// This is the async equivalent of [`std::fs::write`][std]. /// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking /// [std]: fn@std::fs::write /// /// # Examples diff --git a/src/io/async_fd.rs b/src/io/async_fd.rs index 5a68d30..fa5bec5 100644 --- a/src/io/async_fd.rs +++ b/src/io/async_fd.rs @@ -540,6 +540,16 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { result => Ok(result), } } + + /// Returns a shared reference to the inner [`AsyncFd`]. + pub fn get_ref(&self) -> &AsyncFd<Inner> { + self.async_fd + } + + /// Returns a shared reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner(&self) -> &Inner { + self.get_ref().get_ref() + } } impl<'a, Inner: AsRawFd> AsyncFdReadyMutGuard<'a, Inner> { @@ -601,6 +611,26 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyMutGuard<'a, Inner> { result => Ok(result), } } + + /// Returns a shared reference to the inner [`AsyncFd`]. + pub fn get_ref(&self) -> &AsyncFd<Inner> { + self.async_fd + } + + /// Returns a mutable reference to the inner [`AsyncFd`]. + pub fn get_mut(&mut self) -> &mut AsyncFd<Inner> { + self.async_fd + } + + /// Returns a shared reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner(&self) -> &Inner { + self.get_ref().get_ref() + } + + /// Returns a mutable reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner_mut(&mut self) -> &mut Inner { + self.get_mut().get_mut() + } } impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyGuard<'a, T> { diff --git a/src/io/driver/interest.rs b/src/io/driver/interest.rs index 9eead08..36951cf 100644 --- a/src/io/driver/interest.rs +++ b/src/io/driver/interest.rs @@ -58,7 +58,7 @@ impl Interest { self.0.is_writable() } - /// Add together two `Interst` values. + /// Add together two `Interest` values. /// /// This function works from a `const` context. /// diff --git a/src/io/driver/mod.rs b/src/io/driver/mod.rs index 52451c6..3aa0cfb 100644 --- a/src/io/driver/mod.rs +++ b/src/io/driver/mod.rs @@ -96,7 +96,7 @@ const ADDRESS: bit::Pack = bit::Pack::least_significant(24); // // The generation prevents a race condition where a slab slot is reused for a // new socket while the I/O driver is about to apply a readiness event. The -// generaton value is checked when setting new readiness. If the generation do +// generation value is checked when setting new readiness. If the generation do // not match, then the readiness event is discarded. const GENERATION: bit::Pack = ADDRESS.then(7); diff --git a/src/io/driver/scheduled_io.rs b/src/io/driver/scheduled_io.rs index 2626b40..5178010 100644 --- a/src/io/driver/scheduled_io.rs +++ b/src/io/driver/scheduled_io.rs @@ -84,9 +84,9 @@ cfg_io_readiness! { // The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness. // -// | reserved | generation | driver tick | readinesss | -// |----------+------------+--------------+------------| -// | 1 bit | 7 bits + 8 bits + 16 bits | +// | reserved | generation | driver tick | readiness | +// |----------+------------+--------------+-----------| +// | 1 bit | 7 bits + 8 bits + 16 bits | const READINESS: bit::Pack = bit::Pack::least_significant(16); diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs index a31e6db..9872574 100644 --- a/src/io/poll_evented.rs +++ b/src/io/poll_evented.rs @@ -40,9 +40,8 @@ cfg_io_driver! { /// [`poll_read_ready`] again will also indicate read readiness. /// /// When the operation is attempted and is unable to succeed due to the I/O - /// resource not being ready, the caller must call `clear_read_ready` or - /// `clear_write_ready`. This clears the readiness state until a new - /// readiness event is received. + /// resource not being ready, the caller must call `clear_readiness`. + /// This clears the readiness state until a new readiness event is received. /// /// This allows the caller to implement additional functions. For example, /// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs index 38e857d..ad58cbe 100644 --- a/src/io/read_buf.rs +++ b/src/io/read_buf.rs @@ -45,7 +45,7 @@ impl<'a> ReadBuf<'a> { /// Creates a new `ReadBuf` from a fully uninitialized buffer. /// - /// Use `assume_init` if part of the buffer is known to be already inintialized. + /// Use `assume_init` if part of the buffer is known to be already initialized. #[inline] pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> { ReadBuf { @@ -85,7 +85,7 @@ impl<'a> ReadBuf<'a> { #[inline] pub fn take(&mut self, n: usize) -> ReadBuf<'_> { let max = std::cmp::min(self.remaining(), n); - // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. + // Safety: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) } } @@ -217,7 +217,7 @@ impl<'a> ReadBuf<'a> { /// /// # Panics /// - /// Panics if the filled region of the buffer would become larger than the intialized region. + /// Panics if the filled region of the buffer would become larger than the initialized region. #[inline] pub fn set_filled(&mut self, n: usize) { assert!( diff --git a/src/io/split.rs b/src/io/split.rs index 732eb3b..f35273f 100644 --- a/src/io/split.rs +++ b/src/io/split.rs @@ -63,7 +63,7 @@ impl<T> ReadHalf<T> { /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same /// stream. pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { - other.is_pair_of(&self) + other.is_pair_of(self) } /// Reunites with a previously split `WriteHalf`. diff --git a/src/io/stdio_common.rs b/src/io/stdio_common.rs index d21c842..56c4520 100644 --- a/src/io/stdio_common.rs +++ b/src/io/stdio_common.rs @@ -52,10 +52,10 @@ where buf = &buf[..crate::io::blocking::MAX_BUF]; - // Now there are two possibilites. + // Now there are two possibilities. // If caller gave is binary buffer, we **should not** shrink it // anymore, because excessive shrinking hits performance. - // If caller gave as binary buffer, we **must** additionaly + // If caller gave as binary buffer, we **must** additionally // shrink it to strip incomplete char at the end of buffer. // that's why check we will perform now is allowed to have // false-positive. diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs index 233ac31..b241e35 100644 --- a/src/io/util/async_buf_read_ext.rs +++ b/src/io/util/async_buf_read_ext.rs @@ -1,3 +1,4 @@ +use crate::io::util::fill_buf::{fill_buf, FillBuf}; use crate::io::util::lines::{lines, Lines}; use crate::io::util::read_line::{read_line, ReadLine}; use crate::io::util::read_until::{read_until, ReadUntil}; @@ -36,6 +37,18 @@ cfg_io_util! { /// [`fill_buf`]: AsyncBufRead::poll_fill_buf /// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted /// + /// # Cancel safety + /// + /// If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then some data may have been partially read. Any + /// partially read bytes are appended to `buf`, and the method can be + /// called again to continue reading until `byte`. + /// + /// This method returns the total number of bytes read. If you cancel + /// the call to `read_until` and then call it again to continue reading, + /// the counter is reset. + /// /// # Examples /// /// [`std::io::Cursor`][`Cursor`] is a type that implements `BufRead`. In @@ -114,6 +127,30 @@ cfg_io_util! { /// /// [`read_until`]: AsyncBufReadExt::read_until /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If the method is used as the + /// event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then some data may have been partially + /// read, and this data is lost. There are no guarantees regarding the + /// contents of `buf` when the call is cancelled. The current + /// implementation replaces `buf` with the empty string, but this may + /// change in the future. + /// + /// This function does not behave like [`read_until`] because of the + /// requirement that a string contains only valid utf-8. If you need a + /// cancellation safe `read_line`, there are three options: + /// + /// * Call [`read_until`] with a newline character and manually perform the utf-8 check. + /// * The stream returned by [`lines`] has a cancellation safe + /// [`next_line`] method. + /// * Use [`tokio_util::codec::LinesCodec`][LinesCodec]. + /// + /// [LinesCodec]: https://docs.rs/tokio-util/0.6/tokio_util/codec/struct.LinesCodec.html + /// [`read_until`]: Self::read_until + /// [`lines`]: Self::lines + /// [`next_line`]: crate::io::Lines::next_line + /// /// # Examples /// /// [`std::io::Cursor`][`Cursor`] is a type that implements @@ -173,10 +210,11 @@ cfg_io_util! { /// [`BufRead::split`](std::io::BufRead::split). /// /// The stream returned from this function will yield instances of - /// [`io::Result`]`<`[`Vec<u8>`]`>`. Each vector returned will *not* have + /// [`io::Result`]`<`[`Option`]`<`[`Vec<u8>`]`>>`. Each vector returned will *not* have /// the delimiter byte at the end. /// /// [`io::Result`]: std::io::Result + /// [`Option`]: core::option::Option /// [`Vec<u8>`]: std::vec::Vec /// /// # Errors @@ -206,14 +244,68 @@ cfg_io_util! { split(self, byte) } + /// Returns the contents of the internal buffer, filling it with more + /// data from the inner reader if it is empty. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`consume`] method to function properly. When calling this method, + /// none of the contents will be "read" in the sense that later calling + /// `read` may return the same contents. As such, [`consume`] must be + /// called with the number of bytes that are consumed from this buffer + /// to ensure that the bytes are never returned twice. + /// + /// An empty buffer returned indicates that the stream has reached EOF. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn fill_buf(&mut self) -> io::Result<&[u8]>; + /// ``` + /// + /// # Errors + /// + /// This function will return an I/O error if the underlying reader was + /// read, but returned an error. + /// + /// [`consume`]: crate::io::AsyncBufReadExt::consume + fn fill_buf(&mut self) -> FillBuf<'_, Self> + where + Self: Unpin, + { + fill_buf(self) + } + + /// Tells this buffer that `amt` bytes have been consumed from the + /// buffer, so they should no longer be returned in calls to [`read`]. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`fill_buf`] method to function properly. This function does not + /// perform any I/O, it simply informs this object that some amount of + /// its buffer, returned from [`fill_buf`], has been consumed and should + /// no longer be returned. As such, this function may do odd things if + /// [`fill_buf`] isn't called before calling it. + /// + /// The `amt` must be less than the number of bytes in the buffer + /// returned by [`fill_buf`]. + /// + /// [`read`]: crate::io::AsyncReadExt::read + /// [`fill_buf`]: crate::io::AsyncBufReadExt::fill_buf + fn consume(&mut self, amt: usize) + where + Self: Unpin, + { + std::pin::Pin::new(self).consume(amt) + } + /// Returns a stream over the lines of this reader. /// This method is the async equivalent to [`BufRead::lines`](std::io::BufRead::lines). /// /// The stream returned from this function will yield instances of - /// [`io::Result`]`<`[`String`]`>`. Each string returned will *not* have a newline + /// [`io::Result`]`<`[`Option`]`<`[`String`]`>>`. Each string returned will *not* have a newline /// byte (the 0xA byte) or CRLF (0xD, 0xA bytes) at the end. /// /// [`io::Result`]: std::io::Result + /// [`Option`]: core::option::Option /// [`String`]: String /// /// # Errors diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index 878676f..df5445c 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -2,6 +2,7 @@ use crate::io::util::chain::{chain, Chain}; use crate::io::util::read::{read, Read}; use crate::io::util::read_buf::{read_buf, ReadBuf}; use crate::io::util::read_exact::{read_exact, ReadExact}; +use crate::io::util::read_int::{ReadF32, ReadF32Le, ReadF64, ReadF64Le}; use crate::io::util::read_int::{ ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, }; @@ -105,8 +106,8 @@ cfg_io_util! { /// async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize>; /// ``` /// - /// This function does not provide any guarantees about whether it - /// completes immediately or asynchronously + /// This method does not provide any guarantees about whether it + /// completes immediately or asynchronously. /// /// # Return /// @@ -138,6 +139,12 @@ cfg_io_util! { /// variant will be returned. If an error is returned then it must be /// guaranteed that no bytes were read. /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no data was read. + /// /// # Examples /// /// [`File`][crate::fs::File]s implement `Read`: @@ -177,8 +184,8 @@ cfg_io_util! { /// Usually, only a single `read` syscall is issued, even if there is /// more space in the supplied buffer. /// - /// This function does not provide any guarantees about whether it - /// completes immediately or asynchronously + /// This method does not provide any guarantees about whether it + /// completes immediately or asynchronously. /// /// # Return /// @@ -197,6 +204,12 @@ cfg_io_util! { /// variant will be returned. If an error is returned then it must be /// guaranteed that no bytes were read. /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no data was read. + /// /// # Examples /// /// [`File`] implements `Read` and [`BytesMut`] implements [`BufMut`]: @@ -261,6 +274,13 @@ cfg_io_util! { /// it has read, but it will never read more than would be necessary to /// completely fill the buffer. /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If the method is used as the + /// event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then some data may already have been + /// read into `buf`. + /// /// # Examples /// /// [`File`][crate::fs::File]s implement `Read`: @@ -672,6 +692,82 @@ cfg_io_util! { /// ``` fn read_i128(&mut self) -> ReadI128; + /// Reads an 32-bit floating point type in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f32(&mut self) -> io::Result<f32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 32-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0x7f, 0xff, 0xff]); + /// + /// assert_eq!(f32::MIN, reader.read_f32().await?); + /// Ok(()) + /// } + /// ``` + fn read_f32(&mut self) -> ReadF32; + + /// Reads an 64-bit floating point type in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f64(&mut self) -> io::Result<f64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 64-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0xff, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + /// ]); + /// + /// assert_eq!(f64::MIN, reader.read_f64().await?); + /// Ok(()) + /// } + /// ``` + fn read_f64(&mut self) -> ReadF64; + /// Reads an unsigned 16-bit integer in little-endian order from the /// underlying reader. /// @@ -978,6 +1074,82 @@ cfg_io_util! { /// } /// ``` fn read_i128_le(&mut self) -> ReadI128Le; + + /// Reads an 32-bit floating point type in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f32_le(&mut self) -> io::Result<f32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 32-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7f, 0xff]); + /// + /// assert_eq!(f32::MIN, reader.read_f32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_f32_le(&mut self) -> ReadF32Le; + + /// Reads an 64-bit floating point type in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f64_le(&mut self) -> io::Result<f64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 64-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0xff + /// ]); + /// + /// assert_eq!(f64::MIN, reader.read_f64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_f64_le(&mut self) -> ReadF64Le; } /// Reads all bytes until EOF in this source, placing them into `buf`. diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index 4651a99..a1f77f8 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -4,6 +4,7 @@ use crate::io::util::write::{write, Write}; use crate::io::util::write_all::{write_all, WriteAll}; use crate::io::util::write_all_buf::{write_all_buf, WriteAllBuf}; use crate::io::util::write_buf::{write_buf, WriteBuf}; +use crate::io::util::write_int::{WriteF32, WriteF32Le, WriteF64, WriteF64Le}; use crate::io::util::write_int::{ WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le, WriteI8, @@ -97,6 +98,13 @@ cfg_io_util! { /// It is **not** considered an error if the entire buffer could not be /// written to this writer. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// /// # Examples /// /// ```no_run @@ -129,6 +137,13 @@ cfg_io_util! { /// /// See [`AsyncWrite::poll_write_vectored`] for more details. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// /// # Examples /// /// ```no_run @@ -195,6 +210,13 @@ cfg_io_util! { /// It is **not** considered an error if the entire buffer could not be /// written to this writer. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// /// # Examples /// /// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]: @@ -243,6 +265,7 @@ cfg_io_util! { /// while buf.has_remaining() { /// self.write_buf(&mut buf).await?; /// } + /// Ok(()) /// } /// ``` /// @@ -254,6 +277,15 @@ cfg_io_util! { /// The buffer is advanced after each chunk is successfully written. After failure, /// `src.chunk()` will return the chunk that failed to write. /// + /// # Cancel safety + /// + /// If `write_all_buf` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then the data in the provided buffer may have been + /// partially written. However, it is guaranteed that the provided + /// buffer has been [advanced] by the amount of bytes that have been + /// partially written. + /// /// # Examples /// /// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]: @@ -261,6 +293,7 @@ cfg_io_util! { /// [`File`]: crate::fs::File /// [`Buf`]: bytes::Buf /// [`Cursor`]: std::io::Cursor + /// [advanced]: bytes::Buf::advance /// /// ```no_run /// use tokio::io::{self, AsyncWriteExt}; @@ -300,6 +333,14 @@ cfg_io_util! { /// has been successfully written or such an error occurs. The first /// error generated from this method will be returned. /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If it is used as the event + /// in a [`tokio::select!`](crate::select) statement and some other + /// branch completes first, then the provided buffer may have been + /// partially written, but future calls to `write_all` will start over + /// from the beginning of the buffer. + /// /// # Errors /// /// This function will return the first error that [`write`] returns. @@ -710,6 +751,81 @@ cfg_io_util! { /// ``` fn write_i128(&mut self, n: i128) -> WriteI128; + /// Writes an 32-bit floating point type in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f32(&mut self, n: f32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 32-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f32(f32::MIN).await?; + /// + /// assert_eq!(writer, vec![0xff, 0x7f, 0xff, 0xff]); + /// Ok(()) + /// } + /// ``` + fn write_f32(&mut self, n: f32) -> WriteF32; + + /// Writes an 64-bit floating point type in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f64(&mut self, n: f64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 64-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f64(f64::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0xff, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_f64(&mut self, n: f64) -> WriteF64; /// Writes an unsigned 16-bit integer in little-endian order to the /// underlying writer. @@ -1018,6 +1134,82 @@ cfg_io_util! { /// } /// ``` fn write_i128_le(&mut self, n: i128) -> WriteI128Le; + + /// Writes an 32-bit floating point type in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f32_le(&mut self, n: f32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 32-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f32_le(f32::MIN).await?; + /// + /// assert_eq!(writer, vec![0xff, 0xff, 0x7f, 0xff]); + /// Ok(()) + /// } + /// ``` + fn write_f32_le(&mut self, n: f32) -> WriteF32Le; + + /// Writes an 64-bit floating point type in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f64_le(&mut self, n: f64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 64-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f64_le(f64::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0xff + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_f64_le(&mut self, n: f64) -> WriteF64Le; } /// Flushes this output stream, ensuring that all intermediately buffered diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index c4d6842..7cfd46c 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -2,7 +2,7 @@ use crate::io::util::DEFAULT_BUF_SIZE; use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io::{self, SeekFrom}; +use std::io::{self, IoSlice, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, fmt, mem}; @@ -268,6 +268,18 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { self.get_pin_mut().poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.get_pin_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.get_ref().is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.get_pin_mut().poll_flush(cx) } diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs index ff3d9db..595c142 100644 --- a/src/io/util/buf_stream.rs +++ b/src/io/util/buf_stream.rs @@ -2,7 +2,7 @@ use crate::io::util::{BufReader, BufWriter}; use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io::{self, SeekFrom}; +use std::io::{self, IoSlice, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -127,6 +127,18 @@ impl<RW: AsyncRead + AsyncWrite> AsyncWrite for BufStream<RW> { self.project().inner.poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.project().inner.poll_flush(cx) } diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs index 4e8e493..8dd1bba 100644 --- a/src/io/util/buf_writer.rs +++ b/src/io/util/buf_writer.rs @@ -3,7 +3,7 @@ use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, SeekFrom, Write}; +use std::io::{self, IoSlice, SeekFrom, Write}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -133,6 +133,72 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { } } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + if self.inner.is_write_vectored() { + let total_len = bufs + .iter() + .fold(0usize, |acc, b| acc.saturating_add(b.len())); + if total_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + } + let me = self.as_mut().project(); + if total_len >= me.buf.capacity() { + // It's more efficient to pass the slices directly to the + // underlying writer than to buffer them. + // The case when the total_len calculation saturates at + // usize::MAX is also handled here. + me.inner.poll_write_vectored(cx, bufs) + } else { + bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); + Poll::Ready(Ok(total_len)) + } + } else { + // Remove empty buffers at the beginning of bufs. + while bufs.first().map(|buf| buf.len()) == Some(0) { + bufs = &bufs[1..]; + } + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + // Flush if the first buffer doesn't fit. + let first_len = bufs[0].len(); + if first_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + debug_assert!(self.buf.is_empty()); + } + let me = self.as_mut().project(); + if first_len >= me.buf.capacity() { + // The slice is at least as large as the buffering capacity, + // so it's better to write it directly, bypassing the buffer. + debug_assert!(me.buf.is_empty()); + return me.inner.poll_write(cx, &bufs[0]); + } else { + me.buf.extend_from_slice(&bufs[0]); + bufs = &bufs[1..]; + } + let mut total_written = first_len; + debug_assert!(total_written != 0); + // Append the buffers that fit in the internal buffer. + for buf in bufs { + if buf.len() > me.buf.capacity() - me.buf.len() { + break; + } else { + me.buf.extend_from_slice(buf); + total_written += buf.len(); + } + } + Poll::Ready(Ok(total_written)) + } + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { ready!(self.as_mut().flush_buf(cx))?; self.get_pin_mut().poll_flush(cx) diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs index 3cd425b..fbd77b5 100644 --- a/src/io/util/copy.rs +++ b/src/io/util/copy.rs @@ -8,6 +8,7 @@ use std::task::{Context, Poll}; #[derive(Debug)] pub(super) struct CopyBuffer { read_done: bool, + need_flush: bool, pos: usize, cap: usize, amt: u64, @@ -18,6 +19,7 @@ impl CopyBuffer { pub(super) fn new() -> Self { Self { read_done: false, + need_flush: false, pos: 0, cap: 0, amt: 0, @@ -41,7 +43,22 @@ impl CopyBuffer { if self.pos == self.cap && !self.read_done { let me = &mut *self; let mut buf = ReadBuf::new(&mut me.buf); - ready!(reader.as_mut().poll_read(cx, &mut buf))?; + + match reader.as_mut().poll_read(cx, &mut buf) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + let n = buf.filled().len(); if n == 0 { self.read_done = true; @@ -63,6 +80,7 @@ impl CopyBuffer { } else { self.pos += i; self.amt += i as u64; + self.need_flush = true; } } diff --git a/src/io/util/fill_buf.rs b/src/io/util/fill_buf.rs new file mode 100644 index 0000000..98ae2ea --- /dev/null +++ b/src/io/util/fill_buf.rs @@ -0,0 +1,52 @@ +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`fill_buf`](crate::io::AsyncBufReadExt::fill_buf) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct FillBuf<'a, R: ?Sized> { + reader: Option<&'a mut R>, + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn fill_buf<R>(reader: &mut R) -> FillBuf<'_, R> +where + R: AsyncBufRead + ?Sized + Unpin, +{ + FillBuf { + reader: Some(reader), + _pin: PhantomPinned, + } +} + +impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for FillBuf<'a, R> { + type Output = io::Result<&'a [u8]>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + // Due to a limitation in the borrow-checker, we cannot return the value + // directly on Ready. Once Rust starts using the polonius borrow checker, + // this can be simplified. + let reader = me.reader.take().expect("Polled after completion."); + match Pin::new(&mut *reader).poll_fill_buf(cx) { + Poll::Ready(_) => match Pin::new(reader).poll_fill_buf(cx) { + Poll::Ready(slice) => Poll::Ready(slice), + Poll::Pending => panic!("poll_fill_buf returned Pending while having data"), + }, + Poll::Pending => { + *me.reader = Some(reader); + Poll::Pending + } + } + } +} diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs index d02a453..3fbf5e3 100644 --- a/src/io/util/lines.rs +++ b/src/io/util/lines.rs @@ -47,6 +47,10 @@ where { /// Returns the next line in the stream. /// + /// # Cancel safety + /// + /// This method is cancellation safe. + /// /// # Examples /// /// ``` @@ -102,11 +106,9 @@ where /// /// When the method returns `Poll::Pending`, the `Waker` in the provided /// `Context` is scheduled to receive a wakeup when more bytes become - /// available on the underlying IO resource. - /// - /// Note that on multiple calls to `poll_next_line`, only the `Waker` from - /// the `Context` passed to the most recent call is scheduled to receive a - /// wakeup. + /// available on the underlying IO resource. Note that on multiple calls to + /// `poll_next_line`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. pub fn poll_next_line( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs index fd3dd0d..21199d0 100644 --- a/src/io/util/mod.rs +++ b/src/io/util/mod.rs @@ -49,6 +49,7 @@ cfg_io_util! { mod read_exact; mod read_int; mod read_line; + mod fill_buf; mod read_to_end; mod vec_with_initialized; diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs index 5b9fb7b..164dcf5 100644 --- a/src/io/util/read_int.rs +++ b/src/io/util/read_int.rs @@ -142,6 +142,9 @@ reader!(ReadI32, i32, get_i32); reader!(ReadI64, i64, get_i64); reader!(ReadI128, i128, get_i128); +reader!(ReadF32, f32, get_f32); +reader!(ReadF64, f64, get_f64); + reader!(ReadU16Le, u16, get_u16_le); reader!(ReadU32Le, u32, get_u32_le); reader!(ReadU64Le, u64, get_u64_le); @@ -151,3 +154,6 @@ reader!(ReadI16Le, i16, get_i16_le); reader!(ReadI32Le, i32, get_i32_le); reader!(ReadI64Le, i64, get_i64_le); reader!(ReadI128Le, i128, get_i128_le); + +reader!(ReadF32Le, f32, get_f32_le); +reader!(ReadF64Le, f64, get_f64_le); diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs index 3599cff..90a0e8a 100644 --- a/src/io/util/read_until.rs +++ b/src/io/util/read_until.rs @@ -10,12 +10,12 @@ use std::task::{Context, Poll}; pin_project! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. - /// The delimeter is included in the resulting vector. + /// The delimiter is included in the resulting vector. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadUntil<'a, R: ?Sized> { reader: &'a mut R, - delimeter: u8, + delimiter: u8, buf: &'a mut Vec<u8>, // The number of bytes appended to buf. This can be less than buf.len() if // the buffer was not empty when the operation was started. @@ -28,7 +28,7 @@ pin_project! { pub(crate) fn read_until<'a, R>( reader: &'a mut R, - delimeter: u8, + delimiter: u8, buf: &'a mut Vec<u8>, ) -> ReadUntil<'a, R> where @@ -36,7 +36,7 @@ where { ReadUntil { reader, - delimeter, + delimiter, buf, read: 0, _pin: PhantomPinned, @@ -46,14 +46,14 @@ where pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>( mut reader: Pin<&mut R>, cx: &mut Context<'_>, - delimeter: u8, + delimiter: u8, buf: &mut Vec<u8>, read: &mut usize, ) -> Poll<io::Result<usize>> { loop { let (done, used) = { let available = ready!(reader.as_mut().poll_fill_buf(cx))?; - if let Some(i) = memchr::memchr(delimeter, available) { + if let Some(i) = memchr::memchr(delimiter, available) { buf.extend_from_slice(&available[..=i]); (true, i + 1) } else { @@ -74,6 +74,6 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let me = self.project(); - read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read) + read_until_internal(Pin::new(*me.reader), cx, *me.delimiter, me.buf, me.read) } } diff --git a/src/io/util/split.rs b/src/io/util/split.rs index 9c2bb05..7489c24 100644 --- a/src/io/util/split.rs +++ b/src/io/util/split.rs @@ -95,7 +95,7 @@ where let n = ready!(read_until_internal( me.reader, cx, *me.delim, me.buf, me.read, ))?; - // read_until_internal resets me.read to zero once it finds the delimeter + // read_until_internal resets me.read to zero once it finds the delimiter debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs index 13bc191..63cd491 100644 --- a/src/io/util/write_int.rs +++ b/src/io/util/write_int.rs @@ -135,6 +135,9 @@ writer!(WriteI32, i32, put_i32); writer!(WriteI64, i64, put_i64); writer!(WriteI128, i128, put_i128); +writer!(WriteF32, f32, put_f32); +writer!(WriteF64, f64, put_f64); + writer!(WriteU16Le, u16, put_u16_le); writer!(WriteU32Le, u32, put_u32_le); writer!(WriteU64Le, u64, put_u64_le); @@ -144,3 +147,6 @@ writer!(WriteI16Le, i16, put_i16_le); writer!(WriteI32Le, i32, put_i32_le); writer!(WriteI64Le, i64, put_i64_le); writer!(WriteI128Le, i128, put_i128_le); + +writer!(WriteF32Le, f32, put_f32_le); +writer!(WriteF64Le, f64, put_f64_le); @@ -10,7 +10,7 @@ unreachable_pub )] #![deny(unused_must_use)] -#![cfg_attr(docsrs, deny(broken_intra_doc_links))] +#![cfg_attr(docsrs, deny(rustdoc::broken_intra_doc_links))] #![doc(test( no_crate_inject, attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) diff --git a/src/loom/std/atomic_u64.rs b/src/loom/std/atomic_u64.rs index a86a195..8ea6bd4 100644 --- a/src/loom/std/atomic_u64.rs +++ b/src/loom/std/atomic_u64.rs @@ -2,21 +2,17 @@ //! re-export of `AtomicU64`. On 32 bit platforms, this is implemented using a //! `Mutex`. -pub(crate) use self::imp::AtomicU64; - // `AtomicU64` can only be used on targets with `target_has_atomic` is 64 or greater. // Once `cfg_target_has_atomic` feature is stable, we can replace it with // `#[cfg(target_has_atomic = "64")]`. // Refs: https://github.com/rust-lang/rust/tree/master/src/librustc_target -#[cfg(not(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc")))] -mod imp { +cfg_has_atomic_u64! { pub(crate) use std::sync::atomic::AtomicU64; } -#[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))] -mod imp { +cfg_not_has_atomic_u64! { + use crate::loom::sync::Mutex; use std::sync::atomic::Ordering; - use std::sync::Mutex; #[derive(Debug)] pub(crate) struct AtomicU64 { @@ -31,15 +27,15 @@ mod imp { } pub(crate) fn load(&self, _: Ordering) -> u64 { - *self.inner.lock().unwrap() + *self.inner.lock() } pub(crate) fn store(&self, val: u64, _: Ordering) { - *self.inner.lock().unwrap() = val; + *self.inner.lock() = val; } pub(crate) fn fetch_or(&self, val: u64, _: Ordering) -> u64 { - let mut lock = self.inner.lock().unwrap(); + let mut lock = self.inner.lock(); let prev = *lock; *lock = prev | val; prev @@ -52,7 +48,7 @@ mod imp { _success: Ordering, _failure: Ordering, ) -> Result<u64, u64> { - let mut lock = self.inner.lock().unwrap(); + let mut lock = self.inner.lock(); if *lock == current { *lock = new; diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs index 1e77556..7c87522 100644 --- a/src/macros/cfg.rs +++ b/src/macros/cfg.rs @@ -185,7 +185,7 @@ macro_rules! cfg_net_unix { macro_rules! cfg_net_windows { ($($item:item)*) => { $( - #[cfg(all(any(docsrs, windows), feature = "net"))] + #[cfg(all(any(all(doc, docsrs), windows), feature = "net"))] #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "net"))))] $item )* @@ -384,3 +384,29 @@ macro_rules! cfg_not_coop { )* } } + +macro_rules! cfg_has_atomic_u64 { + ($($item:item)*) => { + $( + #[cfg(not(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "powerpc" + )))] + $item + )* + } +} + +macro_rules! cfg_not_has_atomic_u64 { + ($($item:item)*) => { + $( + #[cfg(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "powerpc" + ))] + $item + )* + } +} diff --git a/src/macros/select.rs b/src/macros/select.rs index 371a3de..a90ee9e 100644 --- a/src/macros/select.rs +++ b/src/macros/select.rs @@ -23,10 +23,10 @@ /// returns the result of evaluating the completed branch's `<handler>` /// expression. /// -/// Additionally, each branch may include an optional `if` precondition. This -/// precondition is evaluated **before** the `<async expression>`. If the -/// precondition returns `false`, the branch is entirely disabled. This -/// capability is useful when using `select!` within a loop. +/// Additionally, each branch may include an optional `if` precondition. If the +/// precondition returns `false`, then the branch is disabled. The provided +/// `<async expression>` is still evaluated but the resulting future is never +/// polled. This capability is useful when using `select!` within a loop. /// /// The complete lifecycle of a `select!` expression is as follows: /// @@ -42,12 +42,10 @@ /// to the provided `<pattern>`, if the pattern matches, evaluate `<handler>` /// and return. If the pattern **does not** match, disable the current branch /// and for the remainder of the current call to `select!`. Continue from step 3. -/// 5. If **all** branches are disabled, evaluate the `else` expression. If none -/// is provided, panic. +/// 5. If **all** branches are disabled, evaluate the `else` expression. If no +/// else branch is provided, panic. /// -/// # Notes -/// -/// ### Runtime characteristics +/// # Runtime characteristics /// /// By running all async expressions on the current task, the expressions are /// able to run **concurrently** but not in **parallel**. This means all @@ -58,76 +56,7 @@ /// /// [`tokio::spawn`]: crate::spawn /// -/// ### Avoid racy `if` preconditions -/// -/// Given that `if` preconditions are used to disable `select!` branches, some -/// caution must be used to avoid missing values. -/// -/// For example, here is **incorrect** usage of `sleep` with `if`. The objective -/// is to repeatedly run an asynchronous task for up to 50 milliseconds. -/// However, there is a potential for the `sleep` completion to be missed. -/// -/// ```no_run -/// use tokio::time::{self, Duration}; -/// -/// async fn some_async_work() { -/// // do work -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let sleep = time::sleep(Duration::from_millis(50)); -/// tokio::pin!(sleep); -/// -/// while !sleep.is_elapsed() { -/// tokio::select! { -/// _ = &mut sleep, if !sleep.is_elapsed() => { -/// println!("operation timed out"); -/// } -/// _ = some_async_work() => { -/// println!("operation completed"); -/// } -/// } -/// } -/// } -/// ``` -/// -/// In the above example, `sleep.is_elapsed()` may return `true` even if -/// `sleep.poll()` never returned `Ready`. This opens up a potential race -/// condition where `sleep` expires between the `while !sleep.is_elapsed()` -/// check and the call to `select!` resulting in the `some_async_work()` call to -/// run uninterrupted despite the sleep having elapsed. -/// -/// One way to write the above example without the race would be: -/// -/// ``` -/// use tokio::time::{self, Duration}; -/// -/// async fn some_async_work() { -/// # time::sleep(Duration::from_millis(10)).await; -/// // do work -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let sleep = time::sleep(Duration::from_millis(50)); -/// tokio::pin!(sleep); -/// -/// loop { -/// tokio::select! { -/// _ = &mut sleep => { -/// println!("operation timed out"); -/// break; -/// } -/// _ = some_async_work() => { -/// println!("operation completed"); -/// } -/// } -/// } -/// } -/// ``` -/// -/// ### Fairness +/// # Fairness /// /// By default, `select!` randomly picks a branch to check first. This provides /// some level of fairness when calling `select!` in a loop with branches that @@ -151,10 +80,60 @@ /// /// # Panics /// -/// `select!` panics if all branches are disabled **and** there is no provided -/// `else` branch. A branch is disabled when the provided `if` precondition -/// returns `false` **or** when the pattern does not match the result of `<async -/// expression>`. +/// The `select!` macro panics if all branches are disabled **and** there is no +/// provided `else` branch. A branch is disabled when the provided `if` +/// precondition returns `false` **or** when the pattern does not match the +/// result of `<async expression>`. +/// +/// # Cancellation safety +/// +/// When using `select!` in a loop to receive messages from multiple sources, +/// you should make sure that the receive call is cancellation safe to avoid +/// losing messages. This section goes through various common methods and +/// describes whether they are cancel safe. The lists in this section are not +/// exhaustive. +/// +/// The following methods are cancellation safe: +/// +/// * [`tokio::sync::mpsc::Receiver::recv`](crate::sync::mpsc::Receiver::recv) +/// * [`tokio::sync::mpsc::UnboundedReceiver::recv`](crate::sync::mpsc::UnboundedReceiver::recv) +/// * [`tokio::sync::broadcast::Receiver::recv`](crate::sync::broadcast::Receiver::recv) +/// * [`tokio::sync::watch::Receiver::changed`](crate::sync::watch::Receiver::changed) +/// * [`tokio::net::TcpListener::accept`](crate::net::TcpListener::accept) +/// * [`tokio::net::UnixListener::accept`](crate::net::UnixListener::accept) +/// * [`tokio::io::AsyncReadExt::read`](crate::io::AsyncReadExt::read) on any `AsyncRead` +/// * [`tokio::io::AsyncReadExt::read_buf`](crate::io::AsyncReadExt::read_buf) on any `AsyncRead` +/// * [`tokio::io::AsyncWriteExt::write`](crate::io::AsyncWriteExt::write) on any `AsyncWrite` +/// * [`tokio::io::AsyncWriteExt::write_buf`](crate::io::AsyncWriteExt::write_buf) on any `AsyncWrite` +/// * [`tokio_stream::StreamExt::next`](https://docs.rs/tokio-stream/0.1/tokio_stream/trait.StreamExt.html#method.next) on any `Stream` +/// * [`futures::stream::StreamExt::next`](https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.next) on any `Stream` +/// +/// The following methods are not cancellation safe and can lead to loss of data: +/// +/// * [`tokio::io::AsyncReadExt::read_exact`](crate::io::AsyncReadExt::read_exact) +/// * [`tokio::io::AsyncReadExt::read_to_end`](crate::io::AsyncReadExt::read_to_end) +/// * [`tokio::io::AsyncReadExt::read_to_string`](crate::io::AsyncReadExt::read_to_string) +/// * [`tokio::io::AsyncWriteExt::write_all`](crate::io::AsyncWriteExt::write_all) +/// +/// The following methods are not cancellation safe because they use a queue for +/// fairness and cancellation makes you lose your place in the queue: +/// +/// * [`tokio::sync::Mutex::lock`](crate::sync::Mutex::lock) +/// * [`tokio::sync::RwLock::read`](crate::sync::RwLock::read) +/// * [`tokio::sync::RwLock::write`](crate::sync::RwLock::write) +/// * [`tokio::sync::Semaphore::acquire`](crate::sync::Semaphore::acquire) +/// * [`tokio::sync::Notify::notified`](crate::sync::Notify::notified) +/// +/// To determine whether your own methods are cancellation safe, look for the +/// location of uses of `.await`. This is because when an asynchronous method is +/// cancelled, that always happens at an `.await`. If your function behaves +/// correctly even if it is restarted while waiting at an `.await`, then it is +/// cancellation safe. +/// +/// Be aware that cancelling something that is not cancellation safe is not +/// necessarily wrong. For example, if you are cancelling a task because the +/// application is shutting down, then you probably don't care that partially +/// read data is lost. /// /// # Examples /// @@ -310,7 +289,7 @@ /// loop { /// tokio::select! { /// // If you run this example without `biased;`, the polling order is -/// // psuedo-random, and the assertions on the value of count will +/// // pseudo-random, and the assertions on the value of count will /// // (probably) fail. /// biased; /// @@ -338,6 +317,77 @@ /// } /// } /// ``` +/// +/// ## Avoid racy `if` preconditions +/// +/// Given that `if` preconditions are used to disable `select!` branches, some +/// caution must be used to avoid missing values. +/// +/// For example, here is **incorrect** usage of `sleep` with `if`. The objective +/// is to repeatedly run an asynchronous task for up to 50 milliseconds. +/// However, there is a potential for the `sleep` completion to be missed. +/// +/// ```no_run,should_panic +/// use tokio::time::{self, Duration}; +/// +/// async fn some_async_work() { +/// // do work +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); +/// +/// while !sleep.is_elapsed() { +/// tokio::select! { +/// _ = &mut sleep, if !sleep.is_elapsed() => { +/// println!("operation timed out"); +/// } +/// _ = some_async_work() => { +/// println!("operation completed"); +/// } +/// } +/// } +/// +/// panic!("This example shows how not to do it!"); +/// } +/// ``` +/// +/// In the above example, `sleep.is_elapsed()` may return `true` even if +/// `sleep.poll()` never returned `Ready`. This opens up a potential race +/// condition where `sleep` expires between the `while !sleep.is_elapsed()` +/// check and the call to `select!` resulting in the `some_async_work()` call to +/// run uninterrupted despite the sleep having elapsed. +/// +/// One way to write the above example without the race would be: +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// +/// async fn some_async_work() { +/// # time::sleep(Duration::from_millis(10)).await; +/// // do work +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); +/// +/// loop { +/// tokio::select! { +/// _ = &mut sleep => { +/// println!("operation timed out"); +/// break; +/// } +/// _ = some_async_work() => { +/// println!("operation completed"); +/// } +/// } +/// } +/// } +/// ``` #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! select { @@ -417,7 +467,7 @@ macro_rules! select { let mut is_pending = false; // Choose a starting index to begin polling the futures at. In - // practice, this will either be a psuedo-randomly generrated + // practice, this will either be a pseudo-randomly generated // number by default, or the constant 0 if `biased;` is // supplied. let start = $start; diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index 5c093bb..86f0ec1 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -125,6 +125,13 @@ impl TcpListener { /// established, the corresponding [`TcpStream`] and the remote peer's /// address will be returned. /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. + /// /// [`TcpStream`]: struct@crate::net::TcpStream /// /// # Examples diff --git a/src/net/tcp/split.rs b/src/net/tcp/split.rs index 78bd688..8ae70ce 100644 --- a/src/net/tcp/split.rs +++ b/src/net/tcp/split.rs @@ -30,7 +30,7 @@ pub struct ReadHalf<'a>(&'a TcpStream); /// Borrowed write half of a [`TcpStream`], created by [`split`]. /// -/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will /// shut down the TCP stream in the write direction. /// /// Writing to an `WriteHalf` is usually done using the convenience methods found @@ -57,7 +57,7 @@ impl ReadHalf<'_> { /// `Waker` from the `Context` passed to the most recent call is scheduled /// to receive a wakeup. /// - /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// See the [`TcpStream::poll_peek`] level documentation for more details. /// /// # Examples /// @@ -95,7 +95,7 @@ impl ReadHalf<'_> { /// connected, without removing that data from the queue. On success, /// returns the number of bytes peeked. /// - /// See the [`TcpStream::peek`] level documenation for more details. + /// See the [`TcpStream::peek`] level documentation for more details. /// /// [`TcpStream::peek`]: TcpStream::peek /// diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs index d52c2f6..1bcb4f2 100644 --- a/src/net/tcp/split_owned.rs +++ b/src/net/tcp/split_owned.rs @@ -112,7 +112,7 @@ impl OwnedReadHalf { /// `Waker` from the `Context` passed to the most recent call is scheduled /// to receive a wakeup. /// - /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// See the [`TcpStream::poll_peek`] level documentation for more details. /// /// # Examples /// @@ -150,7 +150,7 @@ impl OwnedReadHalf { /// connected, without removing that data from the queue. On success, /// returns the number of bytes peeked. /// - /// See the [`TcpStream::peek`] level documenation for more details. + /// See the [`TcpStream::peek`] level documentation for more details. /// /// [`TcpStream::peek`]: TcpStream::peek /// diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 2a367ef..34ac6ee 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -356,6 +356,13 @@ impl TcpStream { /// can be used to concurrently read / write to the same socket on a single /// task without splitting the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently read and write to the stream on the same task without @@ -420,6 +427,13 @@ impl TcpStream { /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_read()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -725,6 +739,13 @@ impl TcpStream { /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually /// paired with `try_write()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -915,6 +936,41 @@ impl TcpStream { .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(bufs)) } + /// Try to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `TcpStream` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: TcpStream::readable() + /// [`writable()`]: TcpStream::writable() + /// [`ready()`]: TcpStream::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Receives data on the socket from the remote address to which it is /// connected, without removing that data from the queue. On success, /// returns the number of bytes peeked. @@ -1152,6 +1208,12 @@ impl TcpStream { split_owned(self) } + // == Poll IO functions that takes `&self` == + // + // To read or write without mutable access to the `UnixStream`, combine the + // `poll_read_ready` or `poll_write_ready` methods with the `try_read` or + // `try_write` methods. + pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, diff --git a/src/net/udp.rs b/src/net/udp.rs index 6e63355..75cc6f3 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -327,6 +327,13 @@ impl UdpSocket { /// false-positive and attempting an operation will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently receive from and send to the socket on the same task @@ -390,6 +397,13 @@ impl UdpSocket { /// false-positive and attempting a `try_send()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -442,6 +456,12 @@ impl UdpSocket { /// On success, the number of bytes sent is returned, otherwise, the /// encountered error is returned. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Examples /// /// ```no_run @@ -559,6 +579,13 @@ impl UdpSocket { /// false-positive and attempting a `try_recv()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -613,6 +640,13 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. /// This method will fail if the socket is not connected. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// /// [`connect`]: method@Self::connect /// /// ```no_run @@ -665,7 +699,7 @@ impl UdpSocket { /// [`connect`]: method@Self::connect pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { let n = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -882,6 +916,12 @@ impl UdpSocket { /// /// [`ToSocketAddrs`]: crate::net::ToSocketAddrs /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send_to` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Example /// /// ```no_run @@ -945,7 +985,7 @@ impl UdpSocket { /// /// # Returns /// - /// If successfull, returns the number of bytes sent + /// If successful, returns the number of bytes sent /// /// Users should ensure that when the remote cannot receive, the /// [`ErrorKind::WouldBlock`] is properly handled. An error can also occur @@ -1005,6 +1045,13 @@ impl UdpSocket { /// size to hold the message bytes. If a message is too long to fit in the /// supplied buffer, excess bytes may be discarded. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// /// # Example /// /// ```no_run @@ -1053,7 +1100,7 @@ impl UdpSocket { buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<SocketAddr>> { let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -1123,6 +1170,41 @@ impl UdpSocket { .try_io(Interest::READABLE, || self.io.recv_from(buf)) } + /// Try to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UdpSocket` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UdpSocket::readable() + /// [`writable()`]: UdpSocket::writable() + /// [`ready()`]: UdpSocket::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Receives data from the socket, without removing it from the input queue. /// On success, returns the number of bytes read and the address from whence /// the data came. @@ -1192,7 +1274,7 @@ impl UdpSocket { buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<SocketAddr>> { let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; diff --git a/src/net/unix/datagram/socket.rs b/src/net/unix/datagram/socket.rs index 6bc5615..7874b8a 100644 --- a/src/net/unix/datagram/socket.rs +++ b/src/net/unix/datagram/socket.rs @@ -106,6 +106,13 @@ impl UnixDatagram { /// false-positive and attempting an operation will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently receive from and send to the socket on the same task @@ -171,6 +178,13 @@ impl UnixDatagram { /// false-positive and attempting a `try_send()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -221,6 +235,13 @@ impl UnixDatagram { /// false-positive and attempting a `try_recv()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -490,6 +511,12 @@ impl UnixDatagram { /// Sends data on the socket to the socket's peer. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -613,6 +640,13 @@ impl UnixDatagram { /// Receives data from the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -820,6 +854,12 @@ impl UnixDatagram { /// Sends data on the socket to the specified address. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send_to` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -863,6 +903,13 @@ impl UnixDatagram { /// Receives data from the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -927,7 +974,7 @@ impl UnixDatagram { buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<SocketAddr>> { let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -1028,7 +1075,7 @@ impl UnixDatagram { /// [`connect`]: method@Self::connect pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { let n = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -1096,6 +1143,41 @@ impl UnixDatagram { Ok((n, SocketAddr(addr))) } + /// Try to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UnixDatagram` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixDatagram::readable() + /// [`writable()`]: UnixDatagram::writable() + /// [`ready()`]: UnixDatagram::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Returns the local address that this socket is bound to. /// /// # Examples diff --git a/src/net/unix/listener.rs b/src/net/unix/listener.rs index b5b05a6..efb9503 100644 --- a/src/net/unix/listener.rs +++ b/src/net/unix/listener.rs @@ -128,6 +128,13 @@ impl UnixListener { } /// Accepts a new incoming connection to this listener. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { let (mio, addr) = self .io diff --git a/src/net/unix/split.rs b/src/net/unix/split.rs index 24a711b..97214f7 100644 --- a/src/net/unix/split.rs +++ b/src/net/unix/split.rs @@ -29,7 +29,7 @@ pub struct ReadHalf<'a>(&'a UnixStream); /// Borrowed write half of a [`UnixStream`], created by [`split`]. /// -/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will /// shut down the UnixStream stream in the write direction. /// /// Writing to an `WriteHalf` is usually done using the convenience methods found diff --git a/src/net/unix/stream.rs b/src/net/unix/stream.rs index 917844b..5837f36 100644 --- a/src/net/unix/stream.rs +++ b/src/net/unix/stream.rs @@ -51,6 +51,11 @@ impl UnixStream { let stream = UnixStream::new(stream)?; poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; + + if let Some(e) = stream.io.take_error()? { + return Err(e); + } + Ok(stream) } @@ -60,6 +65,13 @@ impl UnixStream { /// can be used to concurrently read / write to the same socket on a single /// task without splitting the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently read and write to the stream on the same task without @@ -126,6 +138,13 @@ impl UnixStream { /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_read()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -435,6 +454,13 @@ impl UnixStream { /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually /// paired with `try_write()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -627,6 +653,41 @@ impl UnixStream { .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) } + /// Try to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UnixStream` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`writable()`]: UnixStream::writable() + /// [`ready()`]: UnixStream::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`. /// /// This function is intended to be used to wrap a UnixStream from the @@ -826,14 +887,9 @@ impl AsyncWrite for UnixStream { impl UnixStream { // == Poll IO functions that takes `&self` == // - // They are not public because (taken from the doc of `PollEvented`): - // - // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the - // caller must ensure that there are at most two tasks that use a - // `PollEvented` instance concurrently. One for reading and one for writing. - // While violating this requirement is "safe" from a Rust memory model point - // of view, it will result in unexpected behavior in the form of lost - // notifications and tasks hanging. + // To read or write without mutable access to the `UnixStream`, combine the + // `poll_read_ready` or `poll_write_ready` methods with the `try_read` or + // `try_write` methods. pub(crate) fn poll_read_priv( &self, diff --git a/src/net/unix/ucred.rs b/src/net/unix/ucred.rs index b95a8f6..49c7142 100644 --- a/src/net/unix/ucred.rs +++ b/src/net/unix/ucred.rs @@ -25,21 +25,19 @@ impl UCred { /// Gets PID (process ID) of the process. /// /// This is only implemented under Linux, Android, iOS, macOS, Solaris and - /// Illumos. On other plaforms this will always return `None`. + /// Illumos. On other platforms this will always return `None`. pub fn pid(&self) -> Option<pid_t> { self.pid } } -#[cfg(any(target_os = "linux", target_os = "android"))] +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] pub(crate) use self::impl_linux::get_peer_cred; -#[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "netbsd", - target_os = "openbsd" -))] +#[cfg(any(target_os = "netbsd"))] +pub(crate) use self::impl_netbsd::get_peer_cred; + +#[cfg(any(target_os = "dragonfly", target_os = "freebsd"))] pub(crate) use self::impl_bsd::get_peer_cred; #[cfg(any(target_os = "macos", target_os = "ios"))] @@ -48,13 +46,16 @@ pub(crate) use self::impl_macos::get_peer_cred; #[cfg(any(target_os = "solaris", target_os = "illumos"))] pub(crate) use self::impl_solaris::get_peer_cred; -#[cfg(any(target_os = "linux", target_os = "android"))] +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] pub(crate) mod impl_linux { use crate::net::unix::UnixStream; use libc::{c_void, getsockopt, socklen_t, SOL_SOCKET, SO_PEERCRED}; use std::{io, mem}; + #[cfg(target_os = "openbsd")] + use libc::sockpeercred as ucred; + #[cfg(any(target_os = "linux", target_os = "android"))] use libc::ucred; pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { @@ -97,12 +98,49 @@ pub(crate) mod impl_linux { } } -#[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "netbsd", - target_os = "openbsd" -))] +#[cfg(any(target_os = "netbsd"))] +pub(crate) mod impl_netbsd { + use crate::net::unix::UnixStream; + + use libc::{c_void, getsockopt, socklen_t, unpcbid, LOCAL_PEEREID, SOL_SOCKET}; + use std::io; + use std::mem::size_of; + use std::os::unix::io::AsRawFd; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut unpcbid = unpcbid { + unp_pid: 0, + unp_euid: 0, + unp_egid: 0, + }; + + let unpcbid_size = size_of::<unpcbid>(); + let mut unpcbid_size = unpcbid_size as socklen_t; + + let ret = getsockopt( + raw_fd, + SOL_SOCKET, + LOCAL_PEEREID, + &mut unpcbid as *mut unpcbid as *mut c_void, + &mut unpcbid_size, + ); + if ret == 0 && unpcbid_size as usize == size_of::<unpcbid>() { + Ok(super::UCred { + uid: unpcbid.unp_euid, + gid: unpcbid.unp_egid, + pid: Some(unpcbid.unp_pid), + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +#[cfg(any(target_os = "dragonfly", target_os = "freebsd"))] pub(crate) mod impl_bsd { use crate::net::unix::UnixStream; diff --git a/src/net/windows/named_pipe.rs b/src/net/windows/named_pipe.rs index 8013d6f..de6ab58 100644 --- a/src/net/windows/named_pipe.rs +++ b/src/net/windows/named_pipe.rs @@ -4,12 +4,12 @@ use std::ffi::c_void; use std::ffi::OsStr; -use std::io; +use std::io::{self, Read, Write}; use std::pin::Pin; use std::ptr; use std::task::{Context, Poll}; -use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf}; +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; use crate::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle}; // Hide imports which are not used when generating documentation. @@ -163,8 +163,16 @@ impl NamedPipeServer { /// /// This corresponds to the [`ConnectNamedPipe`] system call. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as the + /// event in a [`select!`](crate::select) statement and some other branch + /// completes first, then no connection events have been lost. + /// /// [`ConnectNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-connectnamedpipe /// + /// # Example + /// /// ```no_run /// use tokio::net::windows::named_pipe::ServerOptions; /// @@ -225,6 +233,531 @@ impl NamedPipeServer { pub fn disconnect(&self) -> io::Result<()> { self.io.disconnect() } + + /// Wait for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same pipe on a single + /// task without splitting the pipe. + /// + /// # Examples + /// + /// Concurrently read and write to the pipe on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-ready"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// let ready = server.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Wait for the pipe to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the pipe is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for reading, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Try to read data from the pipe into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-read"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Try to read data from the pipe into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: NamedPipeServer::try_read() + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-read-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + /// Wait for the pipe to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-writable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the pipe is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for writing, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Try to write a buffer to the pipe, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-write"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Try to write several buffers to the pipe, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: NamedPipeServer::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-write-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Try to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the + /// methods defined on the Tokio `NamedPipeServer` type, as this will mess with + /// the readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`writable()`]: NamedPipeServer::writable() + /// [`ready()`]: NamedPipeServer::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } } impl AsyncRead for NamedPipeServer { @@ -362,6 +895,524 @@ impl NamedPipeClient { // Safety: we're ensuring the lifetime of the named pipe. unsafe { named_pipe_info(self.io.as_raw_handle()) } } + + /// Wait for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same pipe on a single + /// task without splitting the pipe. + /// + /// # Examples + /// + /// Concurrently read and write to the pipe on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-ready"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// let ready = client.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Wait for the pipe to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the pipe is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for reading, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Try to read data from the pipe into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-read"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Try to read data from the pipe into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: NamedPipeClient::try_read() + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-read-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + /// Wait for the pipe to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-writable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the pipe is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for writing, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Try to write a buffer to the pipe, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-write"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Try to write several buffers to the pipe, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: NamedPipeClient::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-write-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Try to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `NamedPipeClient` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`writable()`]: NamedPipeClient::writable() + /// [`ready()`]: NamedPipeClient::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } } impl AsyncRead for NamedPipeClient { @@ -1017,7 +2068,7 @@ impl ClientOptions { /// [enabled I/O]: crate::runtime::Builder::enable_io /// [Tokio Runtime]: crate::runtime::Runtime /// - /// A connect loop that waits until a socket becomes available looks like + /// A connect loop that waits until a pipe becomes available looks like /// this: /// /// ```no_run diff --git a/src/process/mod.rs b/src/process/mod.rs index 96ceb6d..42654b1 100644 --- a/src/process/mod.rs +++ b/src/process/mod.rs @@ -199,6 +199,8 @@ use std::io; #[cfg(unix)] use std::os::unix::process::CommandExt; #[cfg(windows)] +use std::os::windows::io::{AsRawHandle, RawHandle}; +#[cfg(windows)] use std::os::windows::process::CommandExt; use std::path::Path; use std::pin::Pin; @@ -551,6 +553,7 @@ impl Command { /// /// [1]: https://msdn.microsoft.com/en-us/library/windows/desktop/ms684863(v=vs.85).aspx #[cfg(windows)] + #[cfg_attr(docsrs, doc(cfg(windows)))] pub fn creation_flags(&mut self, flags: u32) -> &mut Command { self.std.creation_flags(flags); self @@ -560,6 +563,7 @@ impl Command { /// `setuid` call in the child process. Failure in the `setuid` /// call will cause the spawn to fail. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn uid(&mut self, id: u32) -> &mut Command { self.std.uid(id); self @@ -568,11 +572,26 @@ impl Command { /// Similar to `uid` but sets the group ID of the child process. This has /// the same semantics as the `uid` field. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn gid(&mut self, id: u32) -> &mut Command { self.std.gid(id); self } + /// Set executable argument + /// + /// Set the first process argument, `argv[0]`, to something other than the + /// default executable path. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] + pub fn arg0<S>(&mut self, arg: S) -> &mut Command + where + S: AsRef<OsStr>, + { + self.std.arg0(arg); + self + } + /// Schedules a closure to be run just before the `exec` function is /// invoked. /// @@ -603,6 +622,7 @@ impl Command { /// working directory have successfully been changed, so output to these /// locations may not appear where intended. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub unsafe fn pre_exec<F>(&mut self, f: F) -> &mut Command where F: FnMut() -> io::Result<()> + Send + Sync + 'static, @@ -934,6 +954,16 @@ impl Child { } } + /// Extracts the raw handle of the process associated with this child while + /// it is still running. Returns `None` if the child has exited. + #[cfg(windows)] + pub fn raw_handle(&self) -> Option<RawHandle> { + match &self.child { + FusedChild::Child(c) => Some(c.inner.as_raw_handle()), + FusedChild::Done(_) => None, + } + } + /// Attempts to force the child to exit, but does not wait for the request /// to take effect. /// diff --git a/src/process/unix/orphan.rs b/src/process/unix/orphan.rs index 07f0dcf..1b0022c 100644 --- a/src/process/unix/orphan.rs +++ b/src/process/unix/orphan.rs @@ -87,7 +87,7 @@ impl<T> OrphanQueueImpl<T> { // means that the signal driver isn't running, in // which case there isn't anything we can // register/initialize here, so we can try again later - if let Ok(sigchild) = signal_with_handle(SignalKind::child(), &handle) { + if let Ok(sigchild) = signal_with_handle(SignalKind::child(), handle) { *sigchild_guard = Some(sigchild); drain_orphan_queue(queue); } diff --git a/src/process/windows.rs b/src/process/windows.rs index 7237525..06fc1b6 100644 --- a/src/process/windows.rs +++ b/src/process/windows.rs @@ -24,7 +24,7 @@ use mio::windows::NamedPipe; use std::fmt; use std::future::Future; use std::io; -use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle}; +use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; use std::pin::Pin; use std::process::Stdio; use std::process::{Child as StdChild, Command as StdCommand, ExitStatus}; @@ -144,6 +144,12 @@ impl Future for Child { } } +impl AsRawHandle for Child { + fn as_raw_handle(&self) -> RawHandle { + self.child.as_raw_handle() + } +} + impl Drop for Waiting { fn drop(&mut self) { unsafe { diff --git a/src/runtime/basic_scheduler.rs b/src/runtime/basic_scheduler.rs index 13dfb69..fe2e4a8 100644 --- a/src/runtime/basic_scheduler.rs +++ b/src/runtime/basic_scheduler.rs @@ -2,16 +2,14 @@ use crate::future::poll_fn; use crate::loom::sync::atomic::AtomicBool; use crate::loom::sync::Mutex; use crate::park::{Park, Unpark}; -use crate::runtime::task::{self, JoinHandle, Schedule, Task}; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; use crate::sync::notify::Notify; -use crate::util::linked_list::{Link, LinkedList}; use crate::util::{waker_ref, Wake, WakerRef}; use std::cell::RefCell; use std::collections::VecDeque; use std::fmt; use std::future::Future; -use std::ptr::NonNull; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; use std::sync::Arc; use std::task::Poll::{Pending, Ready}; @@ -57,9 +55,6 @@ pub(crate) struct Spawner { } struct Tasks { - /// Collection of all active tasks spawned onto this executor. - owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>, - /// Local run queue. /// /// Tasks notified from the current thread are pushed into this queue. @@ -69,23 +64,23 @@ struct Tasks { /// A remote scheduler entry. /// /// These are filled in by remote threads sending instructions to the scheduler. -enum Entry { +enum RemoteMsg { /// A remote thread wants to spawn a task. Schedule(task::Notified<Arc<Shared>>), - /// A remote thread wants a task to be released by the scheduler. We only - /// have access to its header. - Release(NonNull<task::Header>), } // Safety: Used correctly, the task header is "thread safe". Ultimately the task // is owned by the current thread executor, for which this instruction is being // sent. -unsafe impl Send for Entry {} +unsafe impl Send for RemoteMsg {} /// Scheduler state shared between threads. struct Shared { /// Remote run queue. None if the `Runtime` has been dropped. - queue: Mutex<Option<VecDeque<Entry>>>, + queue: Mutex<Option<VecDeque<RemoteMsg>>>, + + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, /// Unpark the blocked thread. unpark: Box<dyn Unpark>, @@ -125,6 +120,7 @@ impl<P: Park> BasicScheduler<P> { let spawner = Spawner { shared: Arc::new(Shared { queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), + owned: OwnedTasks::new(), unpark: unpark as Box<dyn Unpark>, woken: AtomicBool::new(false), }), @@ -132,7 +128,6 @@ impl<P: Park> BasicScheduler<P> { let inner = Mutex::new(Some(Inner { tasks: Some(Tasks { - owned: LinkedList::new(), queue: VecDeque::with_capacity(INITIAL_CAPACITY), }), spawner: spawner.clone(), @@ -191,7 +186,7 @@ impl<P: Park> BasicScheduler<P> { Some(InnerGuard { inner: Some(inner), - basic_scheduler: &self, + basic_scheduler: self, }) } } @@ -227,7 +222,7 @@ impl<P: Park> Inner<P> { .borrow_mut() .queue .pop_front() - .map(Entry::Schedule) + .map(RemoteMsg::Schedule) }) } else { context @@ -235,7 +230,7 @@ impl<P: Park> Inner<P> { .borrow_mut() .queue .pop_front() - .map(Entry::Schedule) + .map(RemoteMsg::Schedule) .or_else(|| scheduler.spawner.pop()) }; @@ -251,25 +246,9 @@ impl<P: Park> Inner<P> { }; match entry { - Entry::Schedule(task) => crate::coop::budget(|| task.run()), - Entry::Release(ptr) => { - // Safety: the task header is only legally provided - // internally in the header, so we know that it is a - // valid (or in particular *allocated*) header that - // is part of the linked list. - unsafe { - let removed = context.tasks.borrow_mut().owned.remove(ptr); - - // TODO: This seems like it should hold, because - // there doesn't seem to be an avenue for anyone - // else to fiddle with the owned tasks - // collection *after* a remote thread has marked - // it as released, and at that point, the only - // location at which it can be removed is here - // or in the Drop implementation of the - // scheduler. - debug_assert!(removed.is_some()); - } + RemoteMsg::Schedule(task) => { + let task = context.shared.owned.assert_owner(task); + crate::coop::budget(|| task.run()) } } } @@ -335,47 +314,33 @@ impl<P: Park> Drop for BasicScheduler<P> { }; enter(&mut inner, |scheduler, context| { - // Loop required here to ensure borrow is dropped between iterations - #[allow(clippy::while_let_loop)] - loop { - let task = match context.tasks.borrow_mut().owned.pop_back() { - Some(task) => task, - None => break, - }; - - task.shutdown(); - } + // Drain the OwnedTasks collection. This call also closes the + // collection, ensuring that no tasks are ever pushed after this + // call returns. + context.shared.owned.close_and_shutdown_all(); // Drain local queue + // We already shut down every task, so we just need to drop the task. for task in context.tasks.borrow_mut().queue.drain(..) { - task.shutdown(); + drop(task); } // Drain remote queue and set it to None - let mut remote_queue = scheduler.spawner.shared.queue.lock(); + let remote_queue = scheduler.spawner.shared.queue.lock().take(); // Using `Option::take` to replace the shared queue with `None`. - if let Some(remote_queue) = remote_queue.take() { + // We already shut down every task, so we just need to drop the task. + if let Some(remote_queue) = remote_queue { for entry in remote_queue { match entry { - Entry::Schedule(task) => { - task.shutdown(); - } - Entry::Release(..) => { - // Do nothing, each entry in the linked list was *just* - // dropped by the scheduler above. + RemoteMsg::Schedule(task) => { + drop(task); } } } } - // By dropping the mutex lock after the full duration of the above loop, - // any thread that sees the queue in the `None` state is guaranteed that - // the runtime has fully shut down. - // - // The assert below is unrelated to this mutex. - drop(remote_queue); - - assert!(context.tasks.borrow().owned.is_empty()); + + assert!(context.shared.owned.is_empty()); }); } } @@ -389,18 +354,22 @@ impl<P: Park> fmt::Debug for BasicScheduler<P> { // ===== impl Spawner ===== impl Spawner { - /// Spawns a future onto the thread pool + /// Spawns a future onto the basic scheduler pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (task, handle) = task::joinable(future); - self.shared.schedule(task); + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + + if let Some(notified) = notified { + self.shared.schedule(notified); + } + handle } - fn pop(&self) -> Option<Entry> { + fn pop(&self) -> Option<RemoteMsg> { match self.shared.queue.lock().as_mut() { Some(queue) => queue.pop_front(), None => None, @@ -427,42 +396,8 @@ impl fmt::Debug for Spawner { // ===== impl Shared ===== impl Schedule for Arc<Shared> { - fn bind(task: Task<Self>) -> Arc<Shared> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - cx.tasks.borrow_mut().owned.push_front(task); - cx.shared.clone() - }) - } - fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { - CURRENT.with(|maybe_cx| { - let ptr = NonNull::from(task.header()); - - if let Some(cx) = maybe_cx { - // safety: the task is inserted in the list in `bind`. - unsafe { cx.tasks.borrow_mut().owned.remove(ptr) } - } else { - // By sending an `Entry::Release` to the runtime, we ask the - // runtime to remove this task from the linked list in - // `Tasks::owned`. - // - // If the queue is `None`, then the task was already removed - // from that list in the destructor of `BasicScheduler`. We do - // not do anything in this case for the same reason that - // `Entry::Release` messages are ignored in the remote queue - // drain loop of `BasicScheduler`'s destructor. - if let Some(queue) = self.queue.lock().as_mut() { - queue.push_back(Entry::Release(ptr)); - } - - self.unpark.unpark(); - // Returning `None` here prevents the task plumbing from being - // freed. It is then up to the scheduler through the queue we - // just added to, or its Drop impl to free the task. - None - } - }) + self.owned.remove(task) } fn schedule(&self, task: task::Notified<Self>) { @@ -471,16 +406,13 @@ impl Schedule for Arc<Shared> { cx.tasks.borrow_mut().queue.push_back(task); } _ => { + // If the queue is None, then the runtime has shut down. We + // don't need to do anything with the notification in that case. let mut guard = self.queue.lock(); if let Some(queue) = guard.as_mut() { - queue.push_back(Entry::Schedule(task)); + queue.push_back(RemoteMsg::Schedule(task)); drop(guard); self.unpark.unpark(); - } else { - // The runtime has shut down. We drop the new task - // immediately. - drop(guard); - task.shutdown(); } } }); diff --git a/src/runtime/blocking/mod.rs b/src/runtime/blocking/mod.rs index fece3c2..670ec3a 100644 --- a/src/runtime/blocking/mod.rs +++ b/src/runtime/blocking/mod.rs @@ -8,7 +8,9 @@ pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner}; mod schedule; mod shutdown; -pub(crate) mod task; +mod task; +pub(crate) use schedule::NoopSchedule; +pub(crate) use task::BlockingTask; use crate::runtime::Builder; diff --git a/src/runtime/blocking/pool.rs b/src/runtime/blocking/pool.rs index b7d7251..0c23bb0 100644 --- a/src/runtime/blocking/pool.rs +++ b/src/runtime/blocking/pool.rs @@ -71,7 +71,7 @@ struct Shared { worker_thread_index: usize, } -type Task = task::Notified<NoopSchedule>; +type Task = task::UnownedTask<NoopSchedule>; const KEEP_ALIVE: Duration = Duration::from_secs(10); diff --git a/src/runtime/blocking/schedule.rs b/src/runtime/blocking/schedule.rs index 4e044ab..5425224 100644 --- a/src/runtime/blocking/schedule.rs +++ b/src/runtime/blocking/schedule.rs @@ -9,11 +9,6 @@ use crate::runtime::task::{self, Task}; pub(crate) struct NoopSchedule; impl task::Schedule for NoopSchedule { - fn bind(_task: Task<Self>) -> NoopSchedule { - // Do nothing w/ the task - NoopSchedule - } - fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { None } diff --git a/src/runtime/builder.rs b/src/runtime/builder.rs index 0249266..51bf8c8 100644 --- a/src/runtime/builder.rs +++ b/src/runtime/builder.rs @@ -413,7 +413,7 @@ impl Builder { /// Sets a custom timeout for a thread in the blocking pool. /// /// By default, the timeout for a thread is set to 10 seconds. This can - /// be overriden using .thread_keep_alive(). + /// be overridden using .thread_keep_alive(). /// /// # Example /// diff --git a/src/runtime/enter.rs b/src/runtime/enter.rs index 4dd8dd0..e91408f 100644 --- a/src/runtime/enter.rs +++ b/src/runtime/enter.rs @@ -64,7 +64,7 @@ cfg_rt! { // # Warning // // This is hidden for a reason. Do not use without fully understanding -// executors. Misuing can easily cause your program to deadlock. +// executors. Misusing can easily cause your program to deadlock. cfg_rt_multi_thread! { pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R { // Reset in case the closure panics diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs index 173f0ca..ddc170a 100644 --- a/src/runtime/handle.rs +++ b/src/runtime/handle.rs @@ -1,4 +1,4 @@ -use crate::runtime::blocking::task::BlockingTask; +use crate::runtime::blocking::{BlockingTask, NoopSchedule}; use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{blocking, context, driver, Spawner}; use crate::util::error::CONTEXT_MISSING_ERROR; @@ -145,7 +145,7 @@ impl Handle { F::Output: Send + 'static, { #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "task"); + let future = crate::util::trace::task(future, "task", None); self.spawner.spawn(future) } @@ -174,6 +174,15 @@ impl Handle { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + self.spawn_blocking_inner(func, None) + } + + #[cfg_attr(tokio_track_caller, track_caller)] + pub(crate) fn spawn_blocking_inner<F, R>(&self, func: F, name: Option<&str>) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { let fut = BlockingTask::new(func); #[cfg(all(tokio_unstable, feature = "tracing"))] @@ -187,6 +196,7 @@ impl Handle { "task", kind = %"blocking", function = %std::any::type_name::<F>(), + task.name = %name.unwrap_or_default(), spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), ); #[cfg(not(tokio_track_caller))] @@ -194,12 +204,17 @@ impl Handle { target: "tokio::task", "task", kind = %"blocking", + task.name = %name.unwrap_or_default(), function = %std::any::type_name::<F>(), ); fut.instrument(span) }; - let (task, handle) = task::joinable(fut); - let _ = self.blocking_spawner.spawn(task, &self); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let _ = name; + + let (task, handle) = task::unowned(fut, NoopSchedule); + let _ = self.blocking_spawner.spawn(task, self); handle } diff --git a/src/runtime/queue.rs b/src/runtime/queue.rs index 3df7bba..c45cb6a 100644 --- a/src/runtime/queue.rs +++ b/src/runtime/queue.rs @@ -1,13 +1,12 @@ //! Run-queue structures to support a work-stealing scheduler use crate::loom::cell::UnsafeCell; -use crate::loom::sync::atomic::{AtomicU16, AtomicU32, AtomicUsize}; -use crate::loom::sync::{Arc, Mutex}; -use crate::runtime::task; +use crate::loom::sync::atomic::{AtomicU16, AtomicU32}; +use crate::loom::sync::Arc; +use crate::runtime::task::{self, Inject}; -use std::marker::PhantomData; use std::mem::MaybeUninit; -use std::ptr::{self, NonNull}; +use std::ptr; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; /// Producer handle. May only be used from a single thread. @@ -18,19 +17,6 @@ pub(super) struct Local<T: 'static> { /// Consumer handle. May be used from many threads. pub(super) struct Steal<T: 'static>(Arc<Inner<T>>); -/// Growable, MPMC queue used to inject new tasks into the scheduler and as an -/// overflow queue when the local, fixed-size, array queue overflows. -pub(super) struct Inject<T: 'static> { - /// Pointers to the head and tail of the queue - pointers: Mutex<Pointers>, - - /// Number of pending tasks in the queue. This helps prevent unnecessary - /// locking in the hot path. - len: AtomicUsize, - - _p: PhantomData<T>, -} - pub(super) struct Inner<T: 'static> { /// Concurrently updated by many threads. /// @@ -49,24 +35,11 @@ pub(super) struct Inner<T: 'static> { tail: AtomicU16, /// Elements - buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>]>, -} - -struct Pointers { - /// True if the queue is closed - is_closed: bool, - - /// Linked-list head - head: Option<NonNull<task::Header>>, - - /// Linked-list tail - tail: Option<NonNull<task::Header>>, + buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>, } unsafe impl<T> Send for Inner<T> {} unsafe impl<T> Sync for Inner<T> {} -unsafe impl<T> Send for Inject<T> {} -unsafe impl<T> Sync for Inject<T> {} #[cfg(not(loom))] const LOCAL_QUEUE_CAPACITY: usize = 256; @@ -79,6 +52,17 @@ const LOCAL_QUEUE_CAPACITY: usize = 4; const MASK: usize = LOCAL_QUEUE_CAPACITY - 1; +// Constructing the fixed size array directly is very awkward. The only way to +// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as +// the contents are not Copy. The trick with defining a const doesn't work for +// generic types. +fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> { + assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY); + + // safety: We check that the length is correct. + unsafe { Box::from_raw(Box::into_raw(buffer).cast()) } +} + /// Create a new local run-queue pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) { let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY); @@ -90,7 +74,7 @@ pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) { let inner = Arc::new(Inner { head: AtomicU32::new(0), tail: AtomicU16::new(0), - buffer: buffer.into(), + buffer: make_fixed_size(buffer.into_boxed_slice()), }); let local = Local { @@ -109,10 +93,7 @@ impl<T> Local<T> { } /// Pushes a task to the back of the local queue, skipping the LIFO slot. - pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) - where - T: crate::runtime::task::Schedule, - { + pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) { let tail = loop { let head = self.inner.head.load(Acquire); let (steal, real) = unpack(head); @@ -125,13 +106,8 @@ impl<T> Local<T> { break tail; } else if steal != real { // Concurrently stealing, this will free up capacity, so only - // push the new task onto the inject queue - // - // If the task failes to be pushed on the injection queue, there - // is nothing to be done at this point as the task cannot be a - // newly spawned task. Shutting down this task is handled by the - // worker shutdown process. - let _ = inject.push(task); + // push the task onto the inject queue + inject.push(task); return; } else { // Push the current task and half of the queue into the @@ -179,9 +155,12 @@ impl<T> Local<T> { tail: u16, inject: &Inject<T>, ) -> Result<(), task::Notified<T>> { - const BATCH_LEN: usize = LOCAL_QUEUE_CAPACITY / 2 + 1; + /// How many elements are we taking from the local queue. + /// + /// This is one less than the number of tasks pushed to the inject + /// queue as we are also inserting the `task` argument. + const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16; - let n = (LOCAL_QUEUE_CAPACITY / 2) as u16; assert_eq!( tail.wrapping_sub(head) as usize, LOCAL_QUEUE_CAPACITY, @@ -207,7 +186,10 @@ impl<T> Local<T> { .head .compare_exchange( prev, - pack(head.wrapping_add(n), head.wrapping_add(n)), + pack( + head.wrapping_add(NUM_TASKS_TAKEN), + head.wrapping_add(NUM_TASKS_TAKEN), + ), Release, Relaxed, ) @@ -219,41 +201,41 @@ impl<T> Local<T> { return Err(task); } - // link the tasks - for i in 0..n { - let j = i + 1; - - let i_idx = i.wrapping_add(head) as usize & MASK; - let j_idx = j.wrapping_add(head) as usize & MASK; - - // Get the next pointer - let next = if j == n { - // The last task in the local queue being moved - task.header().into() - } else { - // safety: The above CAS prevents a stealer from accessing these - // tasks and we are the only producer. - self.inner.buffer[j_idx].with(|ptr| unsafe { - let value = (*ptr).as_ptr(); - (*value).header().into() - }) - }; - - // safety: the above CAS prevents a stealer from accessing these - // tasks and we are the only producer. - self.inner.buffer[i_idx].with_mut(|ptr| unsafe { - let ptr = (*ptr).as_ptr(); - (*ptr).header().set_next(Some(next)) - }); + /// An iterator that takes elements out of the run queue. + struct BatchTaskIter<'a, T: 'static> { + buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY], + head: u32, + i: u32, + } + impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> { + type Item = task::Notified<T>; + + #[inline] + fn next(&mut self) -> Option<task::Notified<T>> { + if self.i == u32::from(NUM_TASKS_TAKEN) { + None + } else { + let i_idx = self.i.wrapping_add(self.head) as usize & MASK; + let slot = &self.buffer[i_idx]; + + // safety: Our CAS from before has assumed exclusive ownership + // of the task pointers in this range. + let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + self.i += 1; + Some(task) + } + } } - // safety: the above CAS prevents a stealer from accessing these tasks - // and we are the only producer. - let head = self.inner.buffer[head as usize & MASK] - .with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); - - // Push the tasks onto the inject queue - inject.push_batch(head, task, BATCH_LEN); + // safety: The CAS above ensures that no consumer will look at these + // values again, and we are the only producer. + let batch_iter = BatchTaskIter { + buffer: &*self.inner.buffer, + head: head as u32, + i: 0, + }; + inject.push_batch(batch_iter.chain(std::iter::once(task))); Ok(()) } @@ -473,159 +455,6 @@ impl<T> Inner<T> { } } -impl<T: 'static> Inject<T> { - pub(super) fn new() -> Inject<T> { - Inject { - pointers: Mutex::new(Pointers { - is_closed: false, - head: None, - tail: None, - }), - len: AtomicUsize::new(0), - _p: PhantomData, - } - } - - pub(super) fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Close the injection queue, returns `true` if the queue is open when the - /// transition is made. - pub(super) fn close(&self) -> bool { - let mut p = self.pointers.lock(); - - if p.is_closed { - return false; - } - - p.is_closed = true; - true - } - - pub(super) fn is_closed(&self) -> bool { - self.pointers.lock().is_closed - } - - pub(super) fn len(&self) -> usize { - self.len.load(Acquire) - } - - /// Pushes a value into the queue. - /// - /// Returns `Err(task)` if pushing fails due to the queue being shutdown. - /// The caller is expected to call `shutdown()` on the task **if and only - /// if** it is a newly spawned task. - pub(super) fn push(&self, task: task::Notified<T>) -> Result<(), task::Notified<T>> - where - T: crate::runtime::task::Schedule, - { - // Acquire queue lock - let mut p = self.pointers.lock(); - - if p.is_closed { - return Err(task); - } - - // safety: only mutated with the lock held - let len = unsafe { self.len.unsync_load() }; - let task = task.into_raw(); - - // The next pointer should already be null - debug_assert!(get_next(task).is_none()); - - if let Some(tail) = p.tail { - set_next(tail, Some(task)); - } else { - p.head = Some(task); - } - - p.tail = Some(task); - - self.len.store(len + 1, Release); - Ok(()) - } - - pub(super) fn push_batch( - &self, - batch_head: task::Notified<T>, - batch_tail: task::Notified<T>, - num: usize, - ) { - let batch_head = batch_head.into_raw(); - let batch_tail = batch_tail.into_raw(); - - debug_assert!(get_next(batch_tail).is_none()); - - let mut p = self.pointers.lock(); - - if let Some(tail) = p.tail { - set_next(tail, Some(batch_head)); - } else { - p.head = Some(batch_head); - } - - p.tail = Some(batch_tail); - - // Increment the count. - // - // safety: All updates to the len atomic are guarded by the mutex. As - // such, a non-atomic load followed by a store is safe. - let len = unsafe { self.len.unsync_load() }; - - self.len.store(len + num, Release); - } - - pub(super) fn pop(&self) -> Option<task::Notified<T>> { - // Fast path, if len == 0, then there are no values - if self.is_empty() { - return None; - } - - let mut p = self.pointers.lock(); - - // It is possible to hit null here if another thread popped the last - // task between us checking `len` and acquiring the lock. - let task = p.head?; - - p.head = get_next(task); - - if p.head.is_none() { - p.tail = None; - } - - set_next(task, None); - - // Decrement the count. - // - // safety: All updates to the len atomic are guarded by the mutex. As - // such, a non-atomic load followed by a store is safe. - self.len - .store(unsafe { self.len.unsync_load() } - 1, Release); - - // safety: a `Notified` is pushed into the queue and now it is popped! - Some(unsafe { task::Notified::from_raw(task) }) - } -} - -impl<T: 'static> Drop for Inject<T> { - fn drop(&mut self) { - if !std::thread::panicking() { - assert!(self.pop().is_none(), "queue not empty"); - } - } -} - -fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { - unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } -} - -fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { - unsafe { - header.as_ref().set_next(val); - } -} - /// Split the head value into the real head and the index a stealer is working /// on. fn unpack(n: u32) -> (u16, u16) { diff --git a/src/runtime/shell.rs b/src/runtime/shell.rs deleted file mode 100644 index 486d4fa..0000000 --- a/src/runtime/shell.rs +++ /dev/null @@ -1,132 +0,0 @@ -#![allow(clippy::redundant_clone)] - -use crate::future::poll_fn; -use crate::park::{Park, Unpark}; -use crate::runtime::driver::Driver; -use crate::sync::Notify; -use crate::util::{waker_ref, Wake}; - -use std::sync::{Arc, Mutex}; -use std::task::Context; -use std::task::Poll::{Pending, Ready}; -use std::{future::Future, sync::PoisonError}; - -#[derive(Debug)] -pub(super) struct Shell { - driver: Mutex<Option<Driver>>, - - notify: Notify, - - /// TODO: don't store this - unpark: Arc<Handle>, -} - -#[derive(Debug)] -struct Handle(<Driver as Park>::Unpark); - -impl Shell { - pub(super) fn new(driver: Driver) -> Shell { - let unpark = Arc::new(Handle(driver.unpark())); - - Shell { - driver: Mutex::new(Some(driver)), - notify: Notify::new(), - unpark, - } - } - - pub(super) fn block_on<F>(&self, f: F) -> F::Output - where - F: Future, - { - let mut enter = crate::runtime::enter(true); - - pin!(f); - - loop { - if let Some(driver) = &mut self.take_driver() { - return driver.block_on(f); - } else { - let notified = self.notify.notified(); - pin!(notified); - - if let Some(out) = enter - .block_on(poll_fn(|cx| { - if notified.as_mut().poll(cx).is_ready() { - return Ready(None); - } - - if let Ready(out) = f.as_mut().poll(cx) { - return Ready(Some(out)); - } - - Pending - })) - .expect("Failed to `Enter::block_on`") - { - return out; - } - } - } - } - - fn take_driver(&self) -> Option<DriverGuard<'_>> { - let mut lock = self.driver.lock().unwrap(); - let driver = lock.take()?; - - Some(DriverGuard { - inner: Some(driver), - shell: &self, - }) - } -} - -impl Wake for Handle { - /// Wake by value - fn wake(self: Arc<Self>) { - Wake::wake_by_ref(&self); - } - - /// Wake by reference - fn wake_by_ref(arc_self: &Arc<Self>) { - arc_self.0.unpark(); - } -} - -struct DriverGuard<'a> { - inner: Option<Driver>, - shell: &'a Shell, -} - -impl DriverGuard<'_> { - fn block_on<F: Future>(&mut self, f: F) -> F::Output { - let driver = self.inner.as_mut().unwrap(); - - pin!(f); - - let waker = waker_ref(&self.shell.unpark); - let mut cx = Context::from_waker(&waker); - - loop { - if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { - return v; - } - - driver.park().unwrap(); - } - } -} - -impl Drop for DriverGuard<'_> { - fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - self.shell - .driver - .lock() - .unwrap_or_else(PoisonError::into_inner) - .replace(inner); - - self.shell.notify.notify_one(); - } - } -} diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs index 026a6dc..51b6496 100644 --- a/src/runtime/task/core.rs +++ b/src/runtime/task/core.rs @@ -13,7 +13,7 @@ use crate::future::Future; use crate::loom::cell::UnsafeCell; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::{Notified, Schedule, Task}; +use crate::runtime::task::Schedule; use crate::util::linked_list; use std::pin::Pin; @@ -36,10 +36,6 @@ pub(super) struct Cell<T: Future, S> { pub(super) trailer: Trailer, } -pub(super) struct Scheduler<S> { - scheduler: UnsafeCell<Option<S>>, -} - pub(super) struct CoreStage<T: Future> { stage: UnsafeCell<Stage<T>>, } @@ -49,7 +45,7 @@ pub(super) struct CoreStage<T: Future> { /// Holds the future or output, depending on the stage of execution. pub(super) struct Core<T: Future, S> { /// Scheduler used to drive this future - pub(super) scheduler: Scheduler<S>, + pub(super) scheduler: S, /// Either the future or the output pub(super) stage: CoreStage<T>, @@ -61,17 +57,27 @@ pub(crate) struct Header { /// Task state pub(super) state: State, - pub(crate) owned: UnsafeCell<linked_list::Pointers<Header>>, + pub(super) owned: UnsafeCell<linked_list::Pointers<Header>>, /// Pointer to next task, used with the injection queue - pub(crate) queue_next: UnsafeCell<Option<NonNull<Header>>>, - - /// Pointer to the next task in the transfer stack - pub(super) stack_next: UnsafeCell<Option<NonNull<Header>>>, + pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>, /// Table of function pointers for executing actions on the task. pub(super) vtable: &'static Vtable, + /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that + /// this task is stored in. If the task is not in any list, should be the + /// id of the list that it was previously in, or zero if it has never been + /// in any list. + /// + /// Once a task has been bound to a list, it can never be bound to another + /// list, even if removed from the first list. + /// + /// The id is not unset when removed from a list because we want to be able + /// to read the id without synchronization, even if it is concurrently being + /// removed from the list. + pub(super) owner_id: UnsafeCell<u64>, + /// The tracing ID for this instrumented task. #[cfg(all(tokio_unstable, feature = "tracing"))] pub(super) id: Option<tracing::Id>, @@ -96,7 +102,7 @@ pub(super) enum Stage<T: Future> { impl<T: Future, S: Schedule> Cell<T, S> { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, state: State) -> Box<Cell<T, S>> { + pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> { #[cfg(all(tokio_unstable, feature = "tracing"))] let id = future.id(); Box::new(Cell { @@ -104,15 +110,13 @@ impl<T: Future, S: Schedule> Cell<T, S> { state, owned: UnsafeCell::new(linked_list::Pointers::new()), queue_next: UnsafeCell::new(None), - stack_next: UnsafeCell::new(None), vtable: raw::vtable::<T, S>(), + owner_id: UnsafeCell::new(0), #[cfg(all(tokio_unstable, feature = "tracing"))] id, }, core: Core { - scheduler: Scheduler { - scheduler: UnsafeCell::new(None), - }, + scheduler, stage: CoreStage { stage: UnsafeCell::new(Stage::Running(future)), }, @@ -124,92 +128,6 @@ impl<T: Future, S: Schedule> Cell<T, S> { } } -impl<S: Schedule> Scheduler<S> { - pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Option<S>) -> R) -> R { - self.scheduler.with_mut(f) - } - - /// Bind a scheduler to the task. - /// - /// This only happens on the first poll and must be preceeded by a call to - /// `is_bound` to determine if binding is appropriate or not. - /// - /// # Safety - /// - /// Binding must not be done concurrently since it will mutate the task - /// core through a shared reference. - pub(super) fn bind_scheduler(&self, task: Task<S>) { - // This function may be called concurrently, but the __first__ time it - // is called, the caller has unique access to this field. All subsequent - // concurrent calls will be via the `Waker`, which will "happens after" - // the first poll. - // - // In other words, it is always safe to read the field and it is safe to - // write to the field when it is `None`. - debug_assert!(!self.is_bound()); - - // Bind the task to the scheduler - let scheduler = S::bind(task); - - // Safety: As `scheduler` is not set, this is the first poll - self.scheduler.with_mut(|ptr| unsafe { - *ptr = Some(scheduler); - }); - } - - /// Returns true if the task is bound to a scheduler. - pub(super) fn is_bound(&self) -> bool { - // Safety: never called concurrently w/ a mutation. - self.scheduler.with(|ptr| unsafe { (*ptr).is_some() }) - } - - /// Schedule the future for execution - pub(super) fn schedule(&self, task: Notified<S>) { - self.scheduler.with(|ptr| { - // Safety: Can only be called after initial `poll`, which is the - // only time the field is mutated. - match unsafe { &*ptr } { - Some(scheduler) => scheduler.schedule(task), - None => panic!("no scheduler set"), - } - }); - } - - /// Schedule the future for execution in the near future, yielding the - /// thread to other tasks. - pub(super) fn yield_now(&self, task: Notified<S>) { - self.scheduler.with(|ptr| { - // Safety: Can only be called after initial `poll`, which is the - // only time the field is mutated. - match unsafe { &*ptr } { - Some(scheduler) => scheduler.yield_now(task), - None => panic!("no scheduler set"), - } - }); - } - - /// Release the task - /// - /// If the `Scheduler` implementation is able to, it returns the `Task` - /// handle immediately. The caller of this function will batch a ref-dec - /// with a state change. - pub(super) fn release(&self, task: Task<S>) -> Option<Task<S>> { - use std::mem::ManuallyDrop; - - let task = ManuallyDrop::new(task); - - self.scheduler.with(|ptr| { - // Safety: Can only be called after initial `poll`, which is the - // only time the field is mutated. - match unsafe { &*ptr } { - Some(scheduler) => scheduler.release(&*task), - // Task was never polled - None => None, - } - }) - } -} - impl<T: Future> CoreStage<T> { pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { self.stage.with_mut(f) @@ -220,7 +138,7 @@ impl<T: Future> CoreStage<T> { /// # Safety /// /// The caller must ensure it is safe to mutate the `state` field. This - /// requires ensuring mutal exclusion between any concurrent thread that + /// requires ensuring mutual exclusion between any concurrent thread that /// might modify the future or output field. /// /// The mutual exclusion is implemented by `Harness` and the `Lifecycle` @@ -284,7 +202,7 @@ impl<T: Future> CoreStage<T> { use std::mem; self.stage.with_mut(|ptr| { - // Safety:: the caller ensures mutal exclusion to the field. + // Safety:: the caller ensures mutual exclusion to the field. match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { Stage::Finished(output) => output, _ => panic!("JoinHandle polled after completion"), @@ -299,32 +217,40 @@ impl<T: Future> CoreStage<T> { cfg_rt_multi_thread! { impl Header { - pub(crate) fn shutdown(&self) { - use crate::runtime::task::RawTask; - - let task = unsafe { RawTask::from_raw(self.into()) }; - task.shutdown(); - } - - pub(crate) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { + pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { self.queue_next.with_mut(|ptr| *ptr = next); } } } +impl Header { + // safety: The caller must guarantee exclusive access to this field, and + // must ensure that the id is either 0 or the id of the OwnedTasks + // containing this task. + pub(super) unsafe fn set_owner_id(&self, owner: u64) { + self.owner_id.with_mut(|ptr| *ptr = owner); + } + + pub(super) fn get_owner_id(&self) -> u64 { + // safety: If there are concurrent writes, then that write has violated + // the safety requirements on `set_owner_id`. + unsafe { self.owner_id.with(|ptr| *ptr) } + } +} + impl Trailer { - pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) { + pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) { self.waker.with_mut(|ptr| { *ptr = waker; }); } - pub(crate) unsafe fn will_wake(&self, waker: &Waker) -> bool { + pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool { self.waker .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) } - pub(crate) fn wake_join(&self) { + pub(super) fn wake_join(&self) { self.waker.with(|ptr| match unsafe { &*ptr } { Some(waker) => waker.wake_by_ref(), None => panic!("waker missing"), diff --git a/src/runtime/task/error.rs b/src/runtime/task/error.rs index 177fe65..17fb093 100644 --- a/src/runtime/task/error.rs +++ b/src/runtime/task/error.rs @@ -1,7 +1,8 @@ use std::any::Any; use std::fmt; use std::io; -use std::sync::Mutex; + +use crate::util::SyncWrapper; cfg_rt! { /// Task failed to execute to completion. @@ -12,7 +13,7 @@ cfg_rt! { enum Repr { Cancelled, - Panic(Mutex<Box<dyn Any + Send + 'static>>), + Panic(SyncWrapper<Box<dyn Any + Send + 'static>>), } impl JoinError { @@ -24,7 +25,7 @@ impl JoinError { pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { JoinError { - repr: Repr::Panic(Mutex::new(err)), + repr: Repr::Panic(SyncWrapper::new(err)), } } @@ -106,7 +107,7 @@ impl JoinError { /// ``` pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> { match self.repr { - Repr::Panic(p) => Ok(p.into_inner().expect("Extracting panic from mutex")), + Repr::Panic(p) => Ok(p.into_inner()), _ => Err(self), } } diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs index 47bbcc1..41b4193 100644 --- a/src/runtime/task/harness.rs +++ b/src/runtime/task/harness.rs @@ -1,10 +1,11 @@ use crate::future::Future; -use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Scheduler, Trailer}; +use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; use crate::runtime::task::state::Snapshot; use crate::runtime::task::waker::waker_ref; use crate::runtime::task::{JoinError, Notified, Schedule, Task}; use std::mem; +use std::mem::ManuallyDrop; use std::panic; use std::ptr::NonNull; use std::task::{Context, Poll, Waker}; @@ -36,13 +37,6 @@ where fn core(&self) -> &Core<T, S> { unsafe { &self.cell.as_ref().core } } - - fn scheduler_view(&self) -> SchedulerView<'_, S> { - SchedulerView { - header: self.header(), - scheduler: &self.core().scheduler, - } - } } impl<T, S> Harness<T, S> @@ -50,43 +44,103 @@ where T: Future, S: Schedule, { - /// Polls the inner future. + /// Polls the inner future. A ref-count is consumed. /// /// All necessary state checks and transitions are performed. - /// /// Panics raised while polling the future are handled. pub(super) fn poll(self) { + // We pass our ref-count to `poll_inner`. match self.poll_inner() { PollFuture::Notified => { - // Signal yield - self.core().scheduler.yield_now(Notified(self.to_task())); - // The ref-count was incremented as part of - // `transition_to_idle`. + // The `poll_inner` call has given us two ref-counts back. + // We give one of them to a new task and call `yield_now`. + self.core() + .scheduler + .yield_now(Notified(self.get_new_task())); + + // The remaining ref-count is now dropped. We kept the extra + // ref-count until now to ensure that even if the `yield_now` + // call drops the provided task, the task isn't deallocated + // before after `yield_now` returns. self.drop_reference(); } - PollFuture::DropReference => { - self.drop_reference(); + PollFuture::Complete => { + self.complete(); } - PollFuture::Complete(out, is_join_interested) => { - self.complete(out, is_join_interested); + PollFuture::Dealloc => { + self.dealloc(); } - PollFuture::None => (), + PollFuture::Done => (), } } - fn poll_inner(&self) -> PollFuture<T::Output> { - let snapshot = match self.scheduler_view().transition_to_running() { - TransitionToRunning::Ok(snapshot) => snapshot, - TransitionToRunning::DropReference => return PollFuture::DropReference, - }; + /// Poll the task and cancel it if necessary. This takes ownership of a + /// ref-count. + /// + /// If the return value is Notified, the caller is given ownership of two + /// ref-counts. + /// + /// If the return value is Complete, the caller is given ownership of a + /// single ref-count, which should be passed on to `complete`. + /// + /// If the return value is Dealloc, then this call consumed the last + /// ref-count and the caller should call `dealloc`. + /// + /// Otherwise the ref-count is consumed and the caller should not access + /// `self` again. + fn poll_inner(&self) -> PollFuture { + use super::state::{TransitionToIdle, TransitionToRunning}; + + match self.header().state.transition_to_running() { + TransitionToRunning::Success => { + let waker_ref = waker_ref::<T, S>(self.header()); + let cx = Context::from_waker(&*waker_ref); + let res = poll_future(&self.core().stage, cx); + + if res == Poll::Ready(()) { + // The future completed. Move on to complete the task. + return PollFuture::Complete; + } - // The transition to `Running` done above ensures that a lock on the - // future has been obtained. This also ensures the `*mut T` pointer - // contains the future (as opposed to the output) and is initialized. + match self.header().state.transition_to_idle() { + TransitionToIdle::Ok => PollFuture::Done, + TransitionToIdle::OkNotified => PollFuture::Notified, + TransitionToIdle::OkDealloc => PollFuture::Dealloc, + TransitionToIdle::Cancelled => { + // The transition to idle failed because the task was + // cancelled during the poll. - let waker_ref = waker_ref::<T, S>(self.header()); - let cx = Context::from_waker(&*waker_ref); - poll_future(self.header(), &self.core().stage, snapshot, cx) + cancel_task(&self.core().stage); + PollFuture::Complete + } + } + } + TransitionToRunning::Cancelled => { + cancel_task(&self.core().stage); + PollFuture::Complete + } + TransitionToRunning::Failed => PollFuture::Done, + TransitionToRunning::Dealloc => PollFuture::Dealloc, + } + } + + /// Forcibly shutdown the task + /// + /// Attempt to transition to `Running` in order to forcibly shutdown the + /// task. If the task is currently running or in a state of completion, then + /// there is nothing further to do. When the task completes running, it will + /// notice the `CANCELLED` bit and finalize the task. + pub(super) fn shutdown(self) { + if !self.header().state.transition_to_shutdown() { + // The task is concurrently running. No further work needed. + self.drop_reference(); + return; + } + + // By transitioning the lifecycle to `Running`, we have permission to + // drop the future. + cancel_task(&self.core().stage); + self.complete(); } pub(super) fn dealloc(self) { @@ -95,7 +149,6 @@ where // Check causality self.core().stage.with_mut(drop); - self.core().scheduler.with_mut(drop); unsafe { drop(Box::from_raw(self.cell.as_ptr())); @@ -112,6 +165,8 @@ where } pub(super) fn drop_join_handle_slow(self) { + let mut maybe_panic = None; + // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. if self.header().state.unset_join_interested().is_err() { @@ -120,23 +175,95 @@ where // the scheduler or `JoinHandle`. i.e. if the output remains in the // task structure until the task is deallocated, it may be dropped // by a Waker on any arbitrary thread. - self.core().stage.drop_future_or_output(); + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().stage.drop_future_or_output(); + })); + + if let Err(panic) = panic { + maybe_panic = Some(panic); + } } // Drop the `JoinHandle` reference, possibly deallocating the task self.drop_reference(); + + if let Some(panic) = maybe_panic { + panic::resume_unwind(panic); + } + } + + /// Remotely abort the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(self) { + if self.header().state.transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } } // ===== waker behavior ===== + /// This call consumes a ref-count and notifies the task. This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. pub(super) fn wake_by_val(self) { - self.wake_by_ref(); - self.drop_reference(); + use super::state::TransitionToNotifiedByVal; + + match self.header().state.transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + // The caller has given us a ref-count, and the transition has + // created a new ref-count, so we now hold two. We turn the new + // ref-count Notified and pass it to the call to `schedule`. + // + // The old ref-count is retained for now to ensure that the task + // is not dropped during the call to `schedule` if the call + // drops the task it was given. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + + // Now that we have completed the call to schedule, we can + // release our ref-count. + self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } } + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. pub(super) fn wake_by_ref(&self) { - if self.header().state.transition_to_notified() { - self.core().scheduler.schedule(Notified(self.to_task())); + use super::state::TransitionToNotifiedByRef; + + match self.header().state.transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + TransitionToNotifiedByRef::DoNothing => {} } } @@ -151,153 +278,65 @@ where self.header().id.as_ref() } - /// Forcibly shutdown the task - /// - /// Attempt to transition to `Running` in order to forcibly shutdown the - /// task. If the task is currently running or in a state of completion, then - /// there is nothing further to do. When the task completes running, it will - /// notice the `CANCELLED` bit and finalize the task. - pub(super) fn shutdown(self) { - if !self.header().state.transition_to_shutdown() { - // The task is concurrently running. No further work needed. - return; - } - - // By transitioning the lifcycle to `Running`, we have permission to - // drop the future. - let err = cancel_task(&self.core().stage); - self.complete(Err(err), true) - } - // ====== internal ====== - fn complete(self, output: super::Result<T::Output>, is_join_interested: bool) { - if is_join_interested { - // Store the output. The future has already been dropped - // - // Safety: Mutual exclusion is obtained by having transitioned the task - // state -> Running - let stage = &self.core().stage; - stage.store_output(output); - - // Transition to `Complete`, notifying the `JoinHandle` if necessary. - transition_to_complete(self.header(), stage, &self.trailer()); - } + /// Complete the task. This method assumes that the state is RUNNING. + fn complete(self) { + // The future has completed and its output has been written to the task + // stage. We transition from running to complete. + + let snapshot = self.header().state.transition_to_complete(); + + // We catch panics here in case dropping the future or waking the + // JoinHandle panics. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if !snapshot.is_join_interested() { + // The `JoinHandle` is not interested in the output of + // this task. It is our responsibility to drop the + // output. + self.core().stage.drop_future_or_output(); + } else if snapshot.has_join_waker() { + // Notify the join handle. The previous transition obtains the + // lock on the waker cell. + self.trailer().wake_join(); + } + })); // The task has completed execution and will no longer be scheduled. - // - // Attempts to batch a ref-dec with the state transition below. - - if self - .scheduler_view() - .transition_to_terminal(is_join_interested) - { - self.dealloc() - } - } + let num_release = self.release(); - fn to_task(&self) -> Task<S> { - self.scheduler_view().to_task() + if self.header().state.transition_to_terminal(num_release) { + self.dealloc(); + } } -} - -enum TransitionToRunning { - Ok(Snapshot), - DropReference, -} -struct SchedulerView<'a, S> { - header: &'a Header, - scheduler: &'a Scheduler<S>, -} - -impl<'a, S> SchedulerView<'a, S> -where - S: Schedule, -{ - fn to_task(&self) -> Task<S> { - // SAFETY The header is from the same struct containing the scheduler `S` so the cast is safe - unsafe { Task::from_raw(self.header.into()) } - } + /// Release the task from the scheduler. Returns the number of ref-counts + /// that should be decremented. + fn release(&self) -> usize { + // We don't actually increment the ref-count here, but the new task is + // never destroyed, so that's ok. + let me = ManuallyDrop::new(self.get_new_task()); - /// Returns true if the task should be deallocated. - fn transition_to_terminal(&self, is_join_interested: bool) -> bool { - let ref_dec = if self.scheduler.is_bound() { - if let Some(task) = self.scheduler.release(self.to_task()) { - mem::forget(task); - true - } else { - false - } + if let Some(task) = self.core().scheduler.release(&me) { + mem::forget(task); + 2 } else { - false - }; - - // This might deallocate - let snapshot = self - .header - .state - .transition_to_terminal(!is_join_interested, ref_dec); - - snapshot.ref_count() == 0 - } - - fn transition_to_running(&self) -> TransitionToRunning { - // If this is the first time the task is polled, the task will be bound - // to the scheduler, in which case the task ref count must be - // incremented. - let is_not_bound = !self.scheduler.is_bound(); - - // Transition the task to the running state. - // - // A failure to transition here indicates the task has been cancelled - // while in the run queue pending execution. - let snapshot = match self.header.state.transition_to_running(is_not_bound) { - Ok(snapshot) => snapshot, - Err(_) => { - // The task was shutdown while in the run queue. At this point, - // we just hold a ref counted reference. Since we do not have access to it here - // return `DropReference` so the caller drops it. - return TransitionToRunning::DropReference; - } - }; - - if is_not_bound { - // Ensure the task is bound to a scheduler instance. Since this is - // the first time polling the task, a scheduler instance is pulled - // from the local context and assigned to the task. - // - // The scheduler maintains ownership of the task and responds to - // `wake` calls. - // - // The task reference count has been incremented. - // - // Safety: Since we have unique access to the task so that we can - // safely call `bind_scheduler`. - self.scheduler.bind_scheduler(self.to_task()); + 1 } - TransitionToRunning::Ok(snapshot) } -} - -/// Transitions the task's lifecycle to `Complete`. Notifies the -/// `JoinHandle` if it still has interest in the completion. -fn transition_to_complete<T>(header: &Header, stage: &CoreStage<T>, trailer: &Trailer) -where - T: Future, -{ - // Transition the task's lifecycle to `Complete` and get a snapshot of - // the task's sate. - let snapshot = header.state.transition_to_complete(); - if !snapshot.is_join_interested() { - // The `JoinHandle` is not interested in the output of this task. It - // is our responsibility to drop the output. - stage.drop_future_or_output(); - } else if snapshot.has_join_waker() { - // Notify the join handle. The previous transition obtains the - // lock on the waker cell. - trailer.wake_join(); + /// Create a new task that holds its own ref-count. + /// + /// # Safety + /// + /// Any use of `self` after this call must ensure that a ref-count to the + /// task holds the task alive until after the use of `self`. Passing the + /// returned Task to any method on `self` is unsound if dropping the Task + /// could drop `self` before the call on `self` returned. + fn get_new_task(&self) -> Task<S> { + // safety: The header is at the beginning of the cell, so this cast is + // safe. + unsafe { Task::from_raw(self.cell.cast()) } } } @@ -379,73 +418,62 @@ fn set_join_waker( res } -enum PollFuture<T> { - Complete(Result<T, JoinError>, bool), - DropReference, +enum PollFuture { + Complete, Notified, - None, + Done, + Dealloc, } -fn cancel_task<T: Future>(stage: &CoreStage<T>) -> JoinError { +/// Cancel the task and store the appropriate error in the stage field. +fn cancel_task<T: Future>(stage: &CoreStage<T>) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { stage.drop_future_or_output(); })); - if let Err(err) = res { - // Dropping the future panicked, complete the join - // handle with the panic to avoid dropping the panic - // on the ground. - JoinError::panic(err) - } else { - JoinError::cancelled() + match res { + Ok(()) => { + stage.store_output(Err(JoinError::cancelled())); + } + Err(panic) => { + stage.store_output(Err(JoinError::panic(panic))); + } } } -fn poll_future<T: Future>( - header: &Header, - core: &CoreStage<T>, - snapshot: Snapshot, - cx: Context<'_>, -) -> PollFuture<T::Output> { - if snapshot.is_cancelled() { - PollFuture::Complete(Err(JoinError::cancelled()), snapshot.is_join_interested()) - } else { - let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { - struct Guard<'a, T: Future> { - core: &'a CoreStage<T>, - } - - impl<T: Future> Drop for Guard<'_, T> { - fn drop(&mut self) { - self.core.drop_future_or_output(); - } +/// Poll the future. If the future completes, the output is written to the +/// stage field. +fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { + // Poll the future. + let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { + struct Guard<'a, T: Future> { + core: &'a CoreStage<T>, + } + impl<'a, T: Future> Drop for Guard<'a, T> { + fn drop(&mut self) { + // If the future panics on poll, we drop it inside the panic + // guard. + self.core.drop_future_or_output(); } + } + let guard = Guard { core }; + let res = guard.core.poll(cx); + mem::forget(guard); + res + })); - let guard = Guard { core }; - - let res = guard.core.poll(cx); + // Prepare output for being placed in the core stage. + let output = match output { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(output)) => Ok(output), + Err(panic) => Err(JoinError::panic(panic)), + }; - // prevent the guard from dropping the future - mem::forget(guard); + // Catch and ignore panics if the future panics on drop. + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + core.store_output(output); + })); - res - })); - match res { - Ok(Poll::Pending) => match header.state.transition_to_idle() { - Ok(snapshot) => { - if snapshot.is_notified() { - PollFuture::Notified - } else { - PollFuture::None - } - } - Err(_) => PollFuture::Complete(Err(cancel_task(core)), true), - }, - Ok(Poll::Ready(ok)) => PollFuture::Complete(Ok(ok), snapshot.is_join_interested()), - Err(err) => { - PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested()) - } - } - } + Poll::Ready(()) } diff --git a/src/runtime/task/inject.rs b/src/runtime/task/inject.rs new file mode 100644 index 0000000..d1f0aee --- /dev/null +++ b/src/runtime/task/inject.rs @@ -0,0 +1,220 @@ +//! Inject queue used to send wakeups to a work-stealing scheduler + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::runtime::task; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{Acquire, Release}; + +/// Growable, MPMC queue used to inject new tasks into the scheduler and as an +/// overflow queue when the local, fixed-size, array queue overflows. +pub(crate) struct Inject<T: 'static> { + /// Pointers to the head and tail of the queue + pointers: Mutex<Pointers>, + + /// Number of pending tasks in the queue. This helps prevent unnecessary + /// locking in the hot path. + len: AtomicUsize, + + _p: PhantomData<T>, +} + +struct Pointers { + /// True if the queue is closed + is_closed: bool, + + /// Linked-list head + head: Option<NonNull<task::Header>>, + + /// Linked-list tail + tail: Option<NonNull<task::Header>>, +} + +unsafe impl<T> Send for Inject<T> {} +unsafe impl<T> Sync for Inject<T> {} + +impl<T: 'static> Inject<T> { + pub(crate) fn new() -> Inject<T> { + Inject { + pointers: Mutex::new(Pointers { + is_closed: false, + head: None, + tail: None, + }), + len: AtomicUsize::new(0), + _p: PhantomData, + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Close the injection queue, returns `true` if the queue is open when the + /// transition is made. + pub(crate) fn close(&self) -> bool { + let mut p = self.pointers.lock(); + + if p.is_closed { + return false; + } + + p.is_closed = true; + true + } + + pub(crate) fn is_closed(&self) -> bool { + self.pointers.lock().is_closed + } + + pub(crate) fn len(&self) -> usize { + self.len.load(Acquire) + } + + /// Pushes a value into the queue. + /// + /// This does nothing if the queue is closed. + pub(crate) fn push(&self, task: task::Notified<T>) { + // Acquire queue lock + let mut p = self.pointers.lock(); + + if p.is_closed { + return; + } + + // safety: only mutated with the lock held + let len = unsafe { self.len.unsync_load() }; + let task = task.into_raw(); + + // The next pointer should already be null + debug_assert!(get_next(task).is_none()); + + if let Some(tail) = p.tail { + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(tail, Some(task)); + } else { + p.head = Some(task); + } + + p.tail = Some(task); + + self.len.store(len + 1, Release); + } + + /// Pushes several values into the queue. + #[inline] + pub(crate) fn push_batch<I>(&self, mut iter: I) + where + I: Iterator<Item = task::Notified<T>>, + { + let first = match iter.next() { + Some(first) => first.into_raw(), + None => return, + }; + + // Link up all the tasks. + let mut prev = first; + let mut counter = 1; + + // We are going to be called with an `std::iter::Chain`, and that + // iterator overrides `for_each` to something that is easier for the + // compiler to optimize than a loop. + iter.for_each(|next| { + let next = next.into_raw(); + + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(prev, Some(next)); + prev = next; + counter += 1; + }); + + // Now that the tasks are linked together, insert them into the + // linked list. + self.push_batch_inner(first, prev, counter); + } + + /// Insert several tasks that have been linked together into the queue. + /// + /// The provided head and tail may be be the same task. In this case, a + /// single task is inserted. + #[inline] + fn push_batch_inner( + &self, + batch_head: NonNull<task::Header>, + batch_tail: NonNull<task::Header>, + num: usize, + ) { + debug_assert!(get_next(batch_tail).is_none()); + + let mut p = self.pointers.lock(); + + if let Some(tail) = p.tail { + set_next(tail, Some(batch_head)); + } else { + p.head = Some(batch_head); + } + + p.tail = Some(batch_tail); + + // Increment the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + let len = unsafe { self.len.unsync_load() }; + + self.len.store(len + num, Release); + } + + pub(crate) fn pop(&self) -> Option<task::Notified<T>> { + // Fast path, if len == 0, then there are no values + if self.is_empty() { + return None; + } + + let mut p = self.pointers.lock(); + + // It is possible to hit null here if another thread popped the last + // task between us checking `len` and acquiring the lock. + let task = p.head?; + + p.head = get_next(task); + + if p.head.is_none() { + p.tail = None; + } + + set_next(task, None); + + // Decrement the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + self.len + .store(unsafe { self.len.unsync_load() } - 1, Release); + + // safety: a `Notified` is pushed into the queue and now it is popped! + Some(unsafe { task::Notified::from_raw(task) }) + } +} + +impl<T: 'static> Drop for Inject<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { + unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } +} + +fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { + unsafe { + header.as_ref().set_next(val); + } +} diff --git a/src/runtime/task/join.rs b/src/runtime/task/join.rs index dedfb38..2fe40a7 100644 --- a/src/runtime/task/join.rs +++ b/src/runtime/task/join.rs @@ -192,7 +192,7 @@ impl<T> JoinHandle<T> { /// ``` pub fn abort(&self) { if let Some(raw) = self.raw { - raw.shutdown(); + raw.remote_abort(); } } } diff --git a/src/runtime/task/list.rs b/src/runtime/task/list.rs new file mode 100644 index 0000000..edd3c4f --- /dev/null +++ b/src/runtime/task/list.rs @@ -0,0 +1,297 @@ +//! This module has containers for storing the tasks spawned on a scheduler. The +//! `OwnedTasks` container is thread-safe but can only store tasks that +//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can +//! store non-Send tasks. +//! +//! The collections can be closed to prevent adding new tasks during shutdown of +//! the scheduler with the collection. + +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::Mutex; +use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; +use crate::util::linked_list::{Link, LinkedList}; + +use std::marker::PhantomData; + +// The id from the module below is used to verify whether a given task is stored +// in this OwnedTasks, or some other task. The counter starts at one so we can +// use zero for tasks not owned by any list. +// +// The safety checks in this file can technically be violated if the counter is +// overflown, but the checks are not supposed to ever fail unless there is a +// bug in Tokio, so we accept that certain bugs would not be caught if the two +// mixed up runtimes happen to have the same id. + +cfg_has_atomic_u64! { + use std::sync::atomic::{AtomicU64, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return id; + } + } + } +} + +cfg_not_has_atomic_u64! { + use std::sync::atomic::{AtomicU32, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return u64::from(id); + } + } + } +} + +pub(crate) struct OwnedTasks<S: 'static> { + inner: Mutex<OwnedTasksInner<S>>, + id: u64, +} +pub(crate) struct LocalOwnedTasks<S: 'static> { + inner: UnsafeCell<OwnedTasksInner<S>>, + id: u64, + _not_send_or_sync: PhantomData<*const ()>, +} +struct OwnedTasksInner<S: 'static> { + list: LinkedList<Task<S>, <Task<S> as Link>::Target>, + closed: bool, +} + +impl<S: 'static> OwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + } + } + + /// Bind the provided task to this OwnedTasks instance. This fails if the + /// OwnedTasks has been closed. + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + let mut lock = self.inner.lock(); + if lock.closed { + drop(lock); + drop(notified); + task.shutdown(); + (join, None) + } else { + lock.list.push_front(task); + (join, Some(notified)) + } + } + + /// Assert that the given task is owned by this OwnedTasks and convert it to + /// a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: All tasks bound to this OwnedTasks are Send, so it is safe + // to poll it on this thread no matter what thread we are on. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + /// Shut down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + // The first iteration of the loop was unrolled so it can set the + // closed bool. + let first_task = { + let mut lock = self.inner.lock(); + lock.closed = true; + lock.list.pop_back() + }; + match first_task { + Some(task) => task.shutdown(), + None => return, + } + + loop { + let task = match self.inner.lock().list.pop_back() { + Some(task) => task, + None => return, + }; + + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + // safety: We just checked that the provided task is not in some other + // linked list. + unsafe { self.inner.lock().list.remove(task.header().into()) } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.lock().list.is_empty() + } +} + +impl<S: 'static> LocalOwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: UnsafeCell::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + _not_send_or_sync: PhantomData, + } + } + + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + if self.is_closed() { + drop(notified); + task.shutdown(); + (join, None) + } else { + self.with_inner(|inner| { + inner.list.push_front(task); + }); + (join, Some(notified)) + } + } + + /// Shut down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + self.with_inner(|inner| inner.closed = true); + + while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + self.with_inner(|inner| + // safety: We just checked that the provided task is not in some + // other linked list. + unsafe { inner.list.remove(task.header().into()) }) + } + + /// Assert that the given task is owned by this LocalOwnedTasks and convert + /// it to a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: The task was bound to this LocalOwnedTasks, and the + // LocalOwnedTasks is not Send or Sync, so we are on the right thread + // for polling this task. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + #[inline] + fn with_inner<F, T>(&self, f: F) -> T + where + F: FnOnce(&mut OwnedTasksInner<S>) -> T, + { + // safety: This type is not Sync, so concurrent calls of this method + // can't happen. Furthermore, all uses of this method in this file make + // sure that they don't call `with_inner` recursively. + self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) + } + + pub(crate) fn is_closed(&self) -> bool { + self.with_inner(|inner| inner.closed) + } + + pub(crate) fn is_empty(&self) -> bool { + self.with_inner(|inner| inner.list.is_empty()) + } +} + +#[cfg(all(test))] +mod tests { + use super::*; + + // This test may run in parallel with other tests, so we only test that ids + // come in increasing order. + #[test] + fn test_id_not_broken() { + let mut last_id = get_next_id(); + assert_ne!(last_id, 0); + + for _ in 0..1000 { + let next_id = get_next_id(); + assert_ne!(next_id, 0); + assert!(last_id < next_id); + last_id = next_id; + } + } +} diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs index 58b8c2a..cc3910d 100644 --- a/src/runtime/task/mod.rs +++ b/src/runtime/task/mod.rs @@ -1,6 +1,143 @@ +//! The task module. +//! +//! The task module contains the code that manages spawned tasks and provides a +//! safe API for the rest of the runtime to use. Each task in a runtime is +//! stored in an OwnedTasks or LocalOwnedTasks object. +//! +//! # Task reference types +//! +//! A task is usually referenced by multiple handles, and there are several +//! types of handles. +//! +//! * OwnedTask - tasks stored in an OwnedTasks or LocalOwnedTasks are of this +//! reference type. +//! +//! * JoinHandle - each task has a JoinHandle that allows access to the output +//! of the task. +//! +//! * Waker - every waker for a task has this reference type. There can be any +//! number of waker references. +//! +//! * Notified - tracks whether the task is notified. +//! +//! * Unowned - this task reference type is used for tasks not stored in any +//! runtime. Mainly used for blocking tasks, but also in tests. +//! +//! The task uses a reference count to keep track of how many active references +//! exist. The Unowned reference type takes up two ref-counts. All other +//! reference types take pu a single ref-count. +//! +//! Besides the waker type, each task has at most one of each reference type. +//! +//! # State +//! +//! The task stores its state in an atomic usize with various bitfields for the +//! necessary information. The state has the following bitfields: +//! +//! * RUNNING - Tracks whether the task is currently being polled or cancelled. +//! This bit functions as a lock around the task. +//! +//! * COMPLETE - Is one once the future has fully completed and has been +//! dropped. Never unset once set. Never set together with RUNNING. +//! +//! * NOTIFIED - Tracks whether a Notified object currently exists. +//! +//! * CANCELLED - Is set to one for tasks that should be cancelled as soon as +//! possible. May take any value for completed tasks. +//! +//! * JOIN_INTEREST - Is set to one if there exists a JoinHandle. +//! +//! * JOIN_WAKER - Is set to one if the JoinHandle has set a waker. +//! +//! The rest of the bits are used for the ref-count. +//! +//! # Fields in the task +//! +//! The task has various fields. This section describes how and when it is safe +//! to access a field. +//! +//! * The state field is accessed with atomic instructions. +//! +//! * The OwnedTask reference has exclusive access to the `owned` field. +//! +//! * The Notified reference has exclusive access to the `queue_next` field. +//! +//! * The `owner_id` field can be set as part of construction of the task, but +//! is otherwise immutable and anyone can access the field immutably without +//! synchronization. +//! +//! * If COMPLETE is one, then the JoinHandle has exclusive access to the +//! stage field. If COMPLETE is zero, then the RUNNING bitfield functions as +//! a lock for the stage field, and it can be accessed only by the thread +//! that set RUNNING to one. +//! +//! * If JOIN_WAKER is zero, then the JoinHandle has exclusive access to the +//! join handle waker. If JOIN_WAKER and COMPLETE are both one, then the +//! thread that set COMPLETE to one has exclusive access to the join handle +//! waker. +//! +//! All other fields are immutable and can be accessed immutably without +//! synchronization by anyone. +//! +//! # Safety +//! +//! This section goes through various situations and explains why the API is +//! safe in that situation. +//! +//! ## Polling or dropping the future +//! +//! Any mutable access to the future happens after obtaining a lock by modifying +//! the RUNNING field, so exclusive access is ensured. +//! +//! When the task completes, exclusive access to the output is transferred to +//! the JoinHandle. If the JoinHandle is already dropped when the transition to +//! complete happens, the thread performing that transition retains exclusive +//! access to the output and should immediately drop it. +//! +//! ## Non-Send futures +//! +//! If a future is not Send, then it is bound to a LocalOwnedTasks. The future +//! will only ever be polled or dropped given a LocalNotified or inside a call +//! to LocalOwnedTasks::shutdown_all. In either case, it is guaranteed that the +//! future is on the right thread. +//! +//! If the task is never removed from the LocalOwnedTasks, then it is leaked, so +//! there is no risk that the task is dropped on some other thread when the last +//! ref-count drops. +//! +//! ## Non-Send output +//! +//! When a task completes, the output is placed in the stage of the task. Then, +//! a transition that sets COMPLETE to true is performed, and the value of +//! JOIN_INTEREST when this transition happens is read. +//! +//! If JOIN_INTEREST is zero when the transition to COMPLETE happens, then the +//! output is immediately dropped. +//! +//! If JOIN_INTEREST is one when the transition to COMPLETE happens, then the +//! JoinHandle is responsible for cleaning up the output. If the output is not +//! Send, then this happens: +//! +//! 1. The output is created on the thread that the future was polled on. Since +//! only non-Send futures can have non-Send output, the future was polled on +//! the thread that the future was spawned from. +//! 2. Since JoinHandle<Output> is not Send if Output is not Send, the +//! JoinHandle is also on the thread that the future was spawned from. +//! 3. Thus, the JoinHandle will not move the output across threads when it +//! takes or drops the output. +//! +//! ## Recursive poll/shutdown +//! +//! Calling poll from inside a shutdown call or vice-versa is not prevented by +//! the API exposed by the task module, so this has to be safe. In either case, +//! the lock in the RUNNING bitfield makes the inner call return immediately. If +//! the inner call is a `shutdown` call, then the CANCELLED bit is set, and the +//! poll call will notice it when the poll finishes, and the task is cancelled +//! at that point. + mod core; use self::core::Cell; -pub(crate) use self::core::Header; +use self::core::Header; mod error; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 @@ -9,10 +146,18 @@ pub use self::error::JoinError; mod harness; use self::harness::Harness; +cfg_rt_multi_thread! { + mod inject; + pub(super) use self::inject::Inject; +} + mod join; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::join::JoinHandle; +mod list; +pub(crate) use self::list::{LocalOwnedTasks, OwnedTasks}; + mod raw; use self::raw::RawTask; @@ -21,11 +166,6 @@ use self::state::State; mod waker; -cfg_rt_multi_thread! { - mod stack; - pub(crate) use self::stack::TransferStack; -} - use crate::future::Future; use crate::util::linked_list; @@ -43,30 +183,43 @@ pub(crate) struct Task<S: 'static> { unsafe impl<S> Send for Task<S> {} unsafe impl<S> Sync for Task<S> {} -/// A task was notified +/// A task was notified. #[repr(transparent)] pub(crate) struct Notified<S: 'static>(Task<S>); +// safety: This type cannot be used to touch the task without first verifying +// that the value is on a thread where it is safe to poll the task. unsafe impl<S: Schedule> Send for Notified<S> {} unsafe impl<S: Schedule> Sync for Notified<S> {} +/// A non-Send variant of Notified with the invariant that it is on a thread +/// where it is safe to poll it. +#[repr(transparent)] +pub(crate) struct LocalNotified<S: 'static> { + task: Task<S>, + _not_send: PhantomData<*const ()>, +} + +/// A task that is not owned by any OwnedTasks. Used for blocking tasks. +/// This type holds two ref-counts. +pub(crate) struct UnownedTask<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +// safety: This type can only be created given a Send task. +unsafe impl<S> Send for UnownedTask<S> {} +unsafe impl<S> Sync for UnownedTask<S> {} + /// Task result sent back pub(crate) type Result<T> = std::result::Result<T, JoinError>; pub(crate) trait Schedule: Sync + Sized + 'static { - /// Bind a task to the executor. - /// - /// Guaranteed to be called from the thread that called `poll` on the task. - /// The returned `Schedule` instance is associated with the task and is used - /// as `&self` in the other methods on this trait. - fn bind(task: Task<Self>) -> Self; - /// The task has completed work and is ready to be released. The scheduler - /// is free to drop it whenever. + /// should release it immediately and return it. The task module will batch + /// the ref-dec with setting other options. /// - /// If the scheduler can immediately release the task, it should return - /// it as part of the function. This enables the task module to batch - /// the ref-dec with other options. + /// If the scheduler has already released the task, then None is returned. fn release(&self, task: &Task<Self>) -> Option<Task<Self>>; /// Schedule the task @@ -80,71 +233,86 @@ pub(crate) trait Schedule: Sync + Sized + 'static { } cfg_rt! { - /// Create a new task with an associated join handle - pub(crate) fn joinable<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) + /// This is the constructor for a new task. Three references to the task are + /// created. The first task reference is usually put into an OwnedTasks + /// immediately. The Notified is sent to the scheduler as an ordinary + /// notification. + fn new_task<T, S>( + task: T, + scheduler: S + ) -> (Task<S>, Notified<S>, JoinHandle<T::Output>) where - T: Future + Send + 'static, S: Schedule, + T: Future + 'static, + T::Output: 'static, { - let raw = RawTask::new::<_, S>(task); - + let raw = RawTask::new::<T, S>(task, scheduler); let task = Task { raw, _p: PhantomData, }; - + let notified = Notified(Task { + raw, + _p: PhantomData, + }); let join = JoinHandle::new(raw); - (Notified(task), join) + (task, notified, join) } -} -cfg_rt! { - /// Create a new `!Send` task with an associated join handle - pub(crate) unsafe fn joinable_local<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) + /// Create a new task with an associated join handle. This method is used + /// only when the task is not going to be stored in an `OwnedTasks` list. + /// + /// Currently only blocking tasks use this method. + pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>) where - T: Future + 'static, S: Schedule, + T: Send + Future + 'static, + T::Output: Send + 'static, { - let raw = RawTask::new::<_, S>(task); + let (task, notified, join) = new_task(task, scheduler); - let task = Task { - raw, + // This transfers the ref-count of task and notified into an UnownedTask. + // This is valid because an UnownedTask holds two ref-counts. + let unowned = UnownedTask { + raw: task.raw, _p: PhantomData, }; + std::mem::forget(task); + std::mem::forget(notified); - let join = JoinHandle::new(raw); - - (Notified(task), join) + (unowned, join) } } impl<S: 'static> Task<S> { - pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { Task { raw: RawTask::from_raw(ptr), _p: PhantomData, } } - pub(crate) fn header(&self) -> &Header { + fn header(&self) -> &Header { self.raw.header() } } +impl<S: 'static> Notified<S> { + fn header(&self) -> &Header { + self.0.header() + } +} + cfg_rt_multi_thread! { impl<S: 'static> Notified<S> { - pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { Notified(Task::from_raw(ptr)) } - - pub(crate) fn header(&self) -> &Header { - self.0.header() - } } impl<S: 'static> Task<S> { - pub(crate) fn into_raw(self) -> NonNull<Header> { + fn into_raw(self) -> NonNull<Header> { let ret = self.header().into(); mem::forget(self); ret @@ -152,7 +320,7 @@ cfg_rt_multi_thread! { } impl<S: 'static> Notified<S> { - pub(crate) fn into_raw(self) -> NonNull<Header> { + fn into_raw(self) -> NonNull<Header> { self.0.into_raw() } } @@ -160,21 +328,55 @@ cfg_rt_multi_thread! { impl<S: Schedule> Task<S> { /// Pre-emptively cancel the task as part of the shutdown process. - pub(crate) fn shutdown(&self) { - self.raw.shutdown(); + pub(crate) fn shutdown(self) { + let raw = self.raw; + mem::forget(self); + raw.shutdown(); } } -impl<S: Schedule> Notified<S> { +impl<S: Schedule> LocalNotified<S> { /// Run the task pub(crate) fn run(self) { - self.0.raw.poll(); + let raw = self.task.raw; mem::forget(self); + raw.poll(); + } +} + +impl<S: Schedule> UnownedTask<S> { + // Used in test of the inject queue. + #[cfg(test)] + pub(super) fn into_notified(self) -> Notified<S> { + Notified(self.into_task()) + } + + fn into_task(self) -> Task<S> { + // Convert into a task. + let task = Task { + raw: self.raw, + _p: PhantomData, + }; + mem::forget(self); + + // Drop a ref-count since an UnownedTask holds two. + task.header().state.ref_dec(); + + task + } + + pub(crate) fn run(self) { + let raw = self.raw; + mem::forget(self); + + // Poll the task + raw.poll(); + // Decrement our extra ref-count + raw.header().state.ref_dec(); } - /// Pre-emptively cancel the task as part of the shutdown process. pub(crate) fn shutdown(self) { - self.0.shutdown(); + self.into_task().shutdown() } } @@ -188,6 +390,16 @@ impl<S: 'static> Drop for Task<S> { } } +impl<S: 'static> Drop for UnownedTask<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.raw.header().state.ref_dec_twice() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + impl<S> fmt::Debug for Task<S> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "Task({:p})", self.header()) diff --git a/src/runtime/task/raw.rs b/src/runtime/task/raw.rs index a9cd4e6..8c2c3f7 100644 --- a/src/runtime/task/raw.rs +++ b/src/runtime/task/raw.rs @@ -22,6 +22,9 @@ pub(super) struct Vtable { /// The join handle has been dropped pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>), + /// The task is remotely aborted + pub(super) remote_abort: unsafe fn(NonNull<Header>), + /// Scheduler is being shutdown pub(super) shutdown: unsafe fn(NonNull<Header>), } @@ -33,17 +36,18 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable { dealloc: dealloc::<T, S>, try_read_output: try_read_output::<T, S>, drop_join_handle_slow: drop_join_handle_slow::<T, S>, + remote_abort: remote_abort::<T, S>, shutdown: shutdown::<T, S>, } } impl RawTask { - pub(super) fn new<T, S>(task: T) -> RawTask + pub(super) fn new<T, S>(task: T, scheduler: S) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, State::new())); + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; RawTask { ptr } @@ -89,6 +93,11 @@ impl RawTask { let vtable = self.header().vtable; unsafe { (vtable.shutdown)(self.ptr) } } + + pub(super) fn remote_abort(self) { + let vtable = self.header().vtable; + unsafe { (vtable.remote_abort)(self.ptr) } + } } impl Clone for RawTask { @@ -125,6 +134,11 @@ unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) { harness.drop_join_handle_slow() } +unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.remote_abort() +} + unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) { let harness = Harness::<T, S>::from_raw(ptr); harness.shutdown() diff --git a/src/runtime/task/stack.rs b/src/runtime/task/stack.rs deleted file mode 100644 index 9dd8d3f..0000000 --- a/src/runtime/task/stack.rs +++ /dev/null @@ -1,83 +0,0 @@ -use crate::loom::sync::atomic::AtomicPtr; -use crate::runtime::task::{Header, Task}; - -use std::marker::PhantomData; -use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; - -/// Concurrent stack of tasks, used to pass ownership of a task from one worker -/// to another. -pub(crate) struct TransferStack<T: 'static> { - head: AtomicPtr<Header>, - _p: PhantomData<T>, -} - -impl<T: 'static> TransferStack<T> { - pub(crate) fn new() -> TransferStack<T> { - TransferStack { - head: AtomicPtr::new(ptr::null_mut()), - _p: PhantomData, - } - } - - pub(crate) fn push(&self, task: Task<T>) { - let task = task.into_raw(); - - // We don't care about any memory associated w/ setting the `head` - // field, just the current value. - // - // The compare-exchange creates a release sequence. - let mut curr = self.head.load(Relaxed); - - loop { - unsafe { - task.as_ref() - .stack_next - .with_mut(|ptr| *ptr = NonNull::new(curr)) - }; - - let res = self - .head - .compare_exchange(curr, task.as_ptr() as *mut _, Release, Relaxed); - - match res { - Ok(_) => return, - Err(actual) => { - curr = actual; - } - } - } - } - - pub(crate) fn drain(&self) -> impl Iterator<Item = Task<T>> { - struct Iter<T: 'static>(Option<NonNull<Header>>, PhantomData<T>); - - impl<T: 'static> Iterator for Iter<T> { - type Item = Task<T>; - - fn next(&mut self) -> Option<Task<T>> { - let task = self.0?; - - // Move the cursor forward - self.0 = unsafe { task.as_ref().stack_next.with(|ptr| *ptr) }; - - // Return the task - unsafe { Some(Task::from_raw(task)) } - } - } - - impl<T: 'static> Drop for Iter<T> { - fn drop(&mut self) { - use std::process; - - if self.0.is_some() { - // we have bugs - process::abort(); - } - } - } - - let ptr = self.head.swap(ptr::null_mut(), Acquire); - Iter(NonNull::new(ptr), PhantomData) - } -} diff --git a/src/runtime/task/state.rs b/src/runtime/task/state.rs index 1f08d6d..059a7f9 100644 --- a/src/runtime/task/state.rs +++ b/src/runtime/task/state.rs @@ -54,22 +54,52 @@ const REF_ONE: usize = 1 << REF_COUNT_SHIFT; /// State a task is initialized with /// -/// A task is initialized with two references: one for the scheduler and one for -/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is -/// set. A new task is immediately pushed into the run queue for execution and -/// starts with the `NOTIFIED` flag set. -const INITIAL_STATE: usize = (REF_ONE * 2) | JOIN_INTEREST | NOTIFIED; +/// A task is initialized with three references: +/// +/// * A reference that will be stored in an OwnedTasks or LocalOwnedTasks. +/// * A reference that will be sent to the scheduler as an ordinary notification. +/// * A reference for the JoinHandle. +/// +/// As the task starts with a `JoinHandle`, `JOIN_INTEREST` is set. +/// As the task starts with a `Notified`, `NOTIFIED` is set. +const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; + +#[must_use] +pub(super) enum TransitionToRunning { + Success, + Cancelled, + Failed, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToIdle { + Ok, + OkNotified, + OkDealloc, + Cancelled, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByVal { + DoNothing, + Submit, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByRef { + DoNothing, + Submit, +} /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. impl State { /// Return a task's initial state pub(super) fn new() -> State { - // A task is initialized with three references: one for the scheduler, - // one for the `JoinHandle`, one for the task handle made available in - // release. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is - // set. A new task is immediately pushed into the run queue for - // execution and starts with the `NOTIFIED` flag set. + // The raw task returned by this method has a ref-count of three. See + // the comment on INITIAL_STATE for more. State { val: AtomicUsize::new(INITIAL_STATE), } @@ -80,57 +110,72 @@ impl State { Snapshot(self.val.load(Acquire)) } - /// Attempt to transition the lifecycle to `Running`. - /// - /// If `ref_inc` is set, the reference count is also incremented. - /// - /// The `NOTIFIED` bit is always unset. - pub(super) fn transition_to_running(&self, ref_inc: bool) -> UpdateResult { - self.fetch_update(|curr| { - assert!(curr.is_notified()); - - let mut next = curr; + /// Attempt to transition the lifecycle to `Running`. This sets the + /// notified bit to false so notifications during the poll can be detected. + pub(super) fn transition_to_running(&self) -> TransitionToRunning { + self.fetch_update_action(|mut next| { + let action; + assert!(next.is_notified()); if !next.is_idle() { - return None; - } - - if ref_inc { - next.ref_inc(); + // This happens if the task is either currently running or if it + // has already completed, e.g. if it was cancelled during + // shutdown. Consume the ref-count and return. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToRunning::Dealloc; + } else { + action = TransitionToRunning::Failed; + } + } else { + // We are able to lock the RUNNING bit. + next.set_running(); + next.unset_notified(); + + if next.is_cancelled() { + action = TransitionToRunning::Cancelled; + } else { + action = TransitionToRunning::Success; + } } - - next.set_running(); - next.unset_notified(); - Some(next) + (action, Some(next)) }) } /// Transitions the task from `Running` -> `Idle`. /// - /// Returns `Ok` if the transition to `Idle` is successful, `Err` otherwise. - /// In both cases, a snapshot of the state from **after** the transition is - /// returned. - /// + /// Returns `true` if the transition to `Idle` is successful, `false` otherwise. /// The transition to `Idle` fails if the task has been flagged to be /// cancelled. - pub(super) fn transition_to_idle(&self) -> UpdateResult { - self.fetch_update(|curr| { + pub(super) fn transition_to_idle(&self) -> TransitionToIdle { + self.fetch_update_action(|curr| { assert!(curr.is_running()); if curr.is_cancelled() { - return None; + return (TransitionToIdle::Cancelled, None); } let mut next = curr; + let action; next.unset_running(); - if next.is_notified() { - // The caller needs to schedule the task. To do this, it needs a - // waker. The waker requires a ref count. + if !next.is_notified() { + // Polling the future consumes the ref-count of the Notified. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToIdle::OkDealloc; + } else { + action = TransitionToIdle::Ok; + } + } else { + // The caller will schedule a new notification, so we create a + // new ref-count for the notification. Our own ref-count is kept + // for now, and the caller will drop it shortly. next.ref_inc(); + action = TransitionToIdle::OkNotified; } - Some(next) + (action, Some(next)) }) } @@ -146,38 +191,119 @@ impl State { } /// Transition from `Complete` -> `Terminal`, decrementing the reference - /// count by 1. + /// count the specified number of times. /// - /// When `ref_dec` is set, an additional ref count decrement is performed. - /// This is used to batch atomic ops when possible. - pub(super) fn transition_to_terminal(&self, complete: bool, ref_dec: bool) -> Snapshot { - self.fetch_update(|mut snapshot| { - if complete { - snapshot.set_complete(); - } else { - assert!(snapshot.is_complete()); - } + /// Returns true if the task should be deallocated. + pub(super) fn transition_to_terminal(&self, count: usize) -> bool { + let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); + assert!( + prev.ref_count() >= count, + "current: {}, sub: {}", + prev.ref_count(), + count + ); + prev.ref_count() == count + } + + /// Transitions the state to `NOTIFIED`. + /// + /// If no task needs to be submitted, a ref-count is consumed. + /// + /// If a task needs to be submitted, the ref-count is incremented for the + /// new Notified. + pub(super) fn transition_to_notified_by_val(&self) -> TransitionToNotifiedByVal { + self.fetch_update_action(|mut snapshot| { + let action; + + if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit anything as the thread currently running the + // future is responsible for that. + snapshot.set_notified(); + snapshot.ref_dec(); - // Decrement the primary handle - snapshot.ref_dec(); + // The thread that set the running bit also holds a ref-count. + assert!(snapshot.ref_count() > 0); - if ref_dec { - // Decrement a second time + action = TransitionToNotifiedByVal::DoNothing; + } else if snapshot.is_complete() || snapshot.is_notified() { + // We do not need to submit any notifications, but we have to + // decrement the ref-count. snapshot.ref_dec(); + + if snapshot.ref_count() == 0 { + action = TransitionToNotifiedByVal::Dealloc; + } else { + action = TransitionToNotifiedByVal::DoNothing; + } + } else { + // We create a new notified that we can submit. The caller + // retains ownership of the ref-count they passed in. + snapshot.set_notified(); + snapshot.ref_inc(); + action = TransitionToNotifiedByVal::Submit; } - Some(snapshot) + (action, Some(snapshot)) }) - .unwrap() } /// Transitions the state to `NOTIFIED`. + pub(super) fn transition_to_notified_by_ref(&self) -> TransitionToNotifiedByRef { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_complete() || snapshot.is_notified() { + // There is nothing to do in this case. + (TransitionToNotifiedByRef::DoNothing, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit as the thread currently running the future is + // responsible for that. + snapshot.set_notified(); + (TransitionToNotifiedByRef::DoNothing, Some(snapshot)) + } else { + // The task is idle and not notified. We should submit a + // notification. + snapshot.set_notified(); + snapshot.ref_inc(); + (TransitionToNotifiedByRef::Submit, Some(snapshot)) + } + }) + } + + /// Set the cancelled bit and transition the state to `NOTIFIED` if idle. /// /// Returns `true` if the task needs to be submitted to the pool for /// execution - pub(super) fn transition_to_notified(&self) -> bool { - let prev = Snapshot(self.val.fetch_or(NOTIFIED, AcqRel)); - prev.will_need_queueing() + pub(super) fn transition_to_notified_and_cancel(&self) -> bool { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_cancelled() || snapshot.is_complete() { + // Aborts to completed or cancelled tasks are no-ops. + (false, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as cancelled. The thread + // running the task will notice the cancelled bit when it + // stops polling and it will kill the task. + // + // The set_notified() call is not strictly necessary but it will + // in some cases let a wake_by_ref call return without having + // to perform a compare_exchange. + snapshot.set_notified(); + snapshot.set_cancelled(); + (false, Some(snapshot)) + } else { + // The task is idle. We set the cancelled and notified bits and + // submit a notification if the notified bit was not already + // set. + snapshot.set_cancelled(); + if !snapshot.is_notified() { + snapshot.set_notified(); + snapshot.ref_inc(); + (true, Some(snapshot)) + } else { + (false, Some(snapshot)) + } + } + }) } /// Set the `CANCELLED` bit and attempt to transition to `Running`. @@ -191,17 +317,11 @@ impl State { if snapshot.is_idle() { snapshot.set_running(); - - if snapshot.is_notified() { - // If the task is idle and notified, this indicates the task is - // in the run queue and is considered owned by the scheduler. - // The shutdown operation claims ownership of the task, which - // means we need to assign an additional ref-count to the task - // in the queue. - snapshot.ref_inc(); - } } + // If the task was not idle, the thread currently running the task + // will notice the cancelled bit and cancel it once the poll + // completes. snapshot.set_cancelled(); Some(snapshot) }); @@ -317,9 +437,39 @@ impl State { /// Returns `true` if the task should be released. pub(super) fn ref_dec(&self) -> bool { let prev = Snapshot(self.val.fetch_sub(REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 1); prev.ref_count() == 1 } + /// Returns `true` if the task should be released. + pub(super) fn ref_dec_twice(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 2); + prev.ref_count() == 2 + } + + fn fetch_update_action<F, T>(&self, mut f: F) -> T + where + F: FnMut(Snapshot) -> (T, Option<Snapshot>), + { + let mut curr = self.load(); + + loop { + let (output, next) = f(curr); + let next = match next { + Some(next) => next, + None => return output, + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return output, + Err(actual) => curr = Snapshot(actual), + } + } + } + fn fetch_update<F>(&self, mut f: F) -> Result<Snapshot, Snapshot> where F: FnMut(Snapshot) -> Option<Snapshot>, @@ -359,6 +509,10 @@ impl Snapshot { self.0 &= !NOTIFIED } + fn set_notified(&mut self) { + self.0 |= NOTIFIED + } + pub(super) fn is_running(self) -> bool { self.0 & RUNNING == RUNNING } @@ -379,10 +533,6 @@ impl Snapshot { self.0 |= CANCELLED; } - fn set_complete(&mut self) { - self.0 |= COMPLETE; - } - /// Returns `true` if the task's future has completed execution. pub(super) fn is_complete(self) -> bool { self.0 & COMPLETE == COMPLETE @@ -421,10 +571,6 @@ impl Snapshot { assert!(self.ref_count() > 0); self.0 -= REF_ONE } - - fn will_need_queueing(self) -> bool { - !self.is_notified() && self.is_idle() - } } impl fmt::Debug for State { diff --git a/src/runtime/tests/loom_local.rs b/src/runtime/tests/loom_local.rs new file mode 100644 index 0000000..d9a07a4 --- /dev/null +++ b/src/runtime/tests/loom_local.rs @@ -0,0 +1,47 @@ +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::Builder; +use crate::task::LocalSet; + +use std::task::Poll; + +/// Waking a runtime will attempt to push a task into a queue of notifications +/// in the runtime, however the tasks in such a queue usually have a reference +/// to the runtime itself. This means that if they are not properly removed at +/// runtime shutdown, this will cause a memory leak. +/// +/// This test verifies that waking something during shutdown of a LocalSet does +/// not result in tasks lingering in the queue once shutdown is complete. This +/// is verified using loom's leak finder. +#[test] +fn wake_during_shutdown() { + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + let ls = LocalSet::new(); + + let (send, recv) = oneshot::channel(); + + ls.spawn_local(async move { + let mut send = Some(send); + + let () = futures::future::poll_fn(|cx| { + if let Some(send) = send.take() { + send.send(cx.waker().clone()); + } + + Poll::Pending + }) + .await; + }); + + let handle = loom::thread::spawn(move || { + let waker = recv.recv(); + waker.wake(); + }); + + ls.block_on(&rt, crate::task::yield_now()); + + drop(ls); + handle.join().unwrap(); + drop(rt); + }); +} diff --git a/src/runtime/tests/loom_oneshot.rs b/src/runtime/tests/loom_oneshot.rs index c126fe4..87eb638 100644 --- a/src/runtime/tests/loom_oneshot.rs +++ b/src/runtime/tests/loom_oneshot.rs @@ -1,7 +1,6 @@ +use crate::loom::sync::{Arc, Mutex}; use loom::sync::Notify; -use std::sync::{Arc, Mutex}; - pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) { let inner = Arc::new(Inner { notify: Notify::new(), @@ -31,7 +30,7 @@ struct Inner<T> { impl<T> Sender<T> { pub(crate) fn send(self, value: T) { - *self.inner.value.lock().unwrap() = Some(value); + *self.inner.value.lock() = Some(value); self.inner.notify.notify(); } } @@ -39,7 +38,7 @@ impl<T> Sender<T> { impl<T> Receiver<T> { pub(crate) fn recv(self) -> T { loop { - if let Some(v) = self.inner.value.lock().unwrap().take() { + if let Some(v) = self.inner.value.lock().take() { return v; } diff --git a/src/runtime/tests/loom_pool.rs b/src/runtime/tests/loom_pool.rs index 06ad641..b3ecd43 100644 --- a/src/runtime/tests/loom_pool.rs +++ b/src/runtime/tests/loom_pool.rs @@ -11,7 +11,7 @@ use crate::{spawn, task}; use tokio_test::assert_ok; use loom::sync::atomic::{AtomicBool, AtomicUsize}; -use loom::sync::{Arc, Mutex}; +use loom::sync::Arc; use pin_project_lite::pin_project; use std::future::Future; @@ -19,6 +19,57 @@ use std::pin::Pin; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; use std::task::{Context, Poll}; +mod atomic_take { + use loom::sync::atomic::AtomicBool; + use std::mem::MaybeUninit; + use std::sync::atomic::Ordering::SeqCst; + + pub(super) struct AtomicTake<T> { + inner: MaybeUninit<T>, + taken: AtomicBool, + } + + impl<T> AtomicTake<T> { + pub(super) fn new(value: T) -> Self { + Self { + inner: MaybeUninit::new(value), + taken: AtomicBool::new(false), + } + } + + pub(super) fn take(&self) -> Option<T> { + // safety: Only one thread will see the boolean change from false + // to true, so that thread is able to take the value. + match self.taken.fetch_or(true, SeqCst) { + false => unsafe { Some(std::ptr::read(self.inner.as_ptr())) }, + true => None, + } + } + } + + impl<T> Drop for AtomicTake<T> { + fn drop(&mut self) { + drop(self.take()); + } + } +} + +#[derive(Clone)] +struct AtomicOneshot<T> { + value: std::sync::Arc<atomic_take::AtomicTake<oneshot::Sender<T>>>, +} +impl<T> AtomicOneshot<T> { + fn new(sender: oneshot::Sender<T>) -> Self { + Self { + value: std::sync::Arc::new(atomic_take::AtomicTake::new(sender)), + } + } + + fn assert_send(&self, value: T) { + self.value.take().unwrap().send(value); + } +} + /// Tests are divided into groups to make the runs faster on CI. mod group_a { use super::*; @@ -52,7 +103,7 @@ mod group_a { let c1 = Arc::new(AtomicUsize::new(0)); let (tx, rx) = oneshot::channel(); - let tx1 = Arc::new(Mutex::new(Some(tx))); + let tx1 = AtomicOneshot::new(tx); // Spawn a task let c2 = c1.clone(); @@ -60,7 +111,7 @@ mod group_a { pool.spawn(track(async move { spawn(track(async move { if 1 == c1.fetch_add(1, Relaxed) { - tx1.lock().unwrap().take().unwrap().send(()); + tx1.assert_send(()); } })); })); @@ -69,7 +120,7 @@ mod group_a { pool.spawn(track(async move { spawn(track(async move { if 1 == c2.fetch_add(1, Relaxed) { - tx2.lock().unwrap().take().unwrap().send(()); + tx2.assert_send(()); } })); })); @@ -119,7 +170,7 @@ mod group_b { let (block_tx, block_rx) = oneshot::channel(); let (done_tx, done_rx) = oneshot::channel(); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); + let done_tx = AtomicOneshot::new(done_tx); pool.spawn(track(async move { crate::task::block_in_place(move || { @@ -136,7 +187,7 @@ mod group_b { pool.spawn(track(async move { if NUM == cnt.fetch_add(1, Relaxed) + 1 { - done_tx.lock().unwrap().take().unwrap().send(()); + done_tx.assert_send(()); } })); } @@ -159,23 +210,6 @@ mod group_b { } #[test] - fn pool_shutdown() { - loom::model(|| { - let pool = mk_pool(2); - - pool.spawn(track(async move { - gated2(true).await; - })); - - pool.spawn(track(async move { - gated2(false).await; - })); - - drop(pool); - }); - } - - #[test] fn join_output() { loom::model(|| { let rt = mk_pool(1); @@ -223,10 +257,6 @@ mod group_b { }); }); } -} - -mod group_c { - use super::*; #[test] fn shutdown_with_notification() { @@ -255,6 +285,27 @@ mod group_c { } } +mod group_c { + use super::*; + + #[test] + fn pool_shutdown() { + loom::model(|| { + let pool = mk_pool(2); + + pool.spawn(track(async move { + gated2(true).await; + })); + + pool.spawn(track(async move { + gated2(false).await; + })); + + drop(pool); + }); + } +} + mod group_d { use super::*; @@ -266,17 +317,17 @@ mod group_d { let c1 = Arc::new(AtomicUsize::new(0)); let (done_tx, done_rx) = oneshot::channel(); - let done_tx1 = Arc::new(Mutex::new(Some(done_tx))); + let done_tx1 = AtomicOneshot::new(done_tx); + let done_tx2 = done_tx1.clone(); // Spawn a task let c2 = c1.clone(); - let done_tx2 = done_tx1.clone(); pool.spawn(track(async move { gated().await; gated().await; if 1 == c1.fetch_add(1, Relaxed) { - done_tx1.lock().unwrap().take().unwrap().send(()); + done_tx1.assert_send(()); } })); @@ -286,7 +337,7 @@ mod group_d { gated().await; if 1 == c2.fetch_add(1, Relaxed) { - done_tx2.lock().unwrap().take().unwrap().send(()); + done_tx2.assert_send(()); } })); diff --git a/src/runtime/tests/loom_queue.rs b/src/runtime/tests/loom_queue.rs index de02610..a1ed171 100644 --- a/src/runtime/tests/loom_queue.rs +++ b/src/runtime/tests/loom_queue.rs @@ -1,5 +1,6 @@ +use crate::runtime::blocking::NoopSchedule; use crate::runtime::queue; -use crate::runtime::task::{self, Schedule, Task}; +use crate::runtime::task::Inject; use loom::thread; @@ -7,7 +8,7 @@ use loom::thread; fn basic() { loom::model(|| { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { let (_, mut local) = queue::local(); @@ -30,7 +31,7 @@ fn basic() { for _ in 0..2 { for _ in 0..2 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -39,7 +40,7 @@ fn basic() { } // Push another task - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); while local.pop().is_some() { @@ -61,7 +62,7 @@ fn basic() { fn steal_overflow() { loom::model(|| { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { let (_, mut local) = queue::local(); @@ -81,7 +82,7 @@ fn steal_overflow() { let mut n = 0; // push a task, pop a task - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); if local.pop().is_some() { @@ -89,7 +90,7 @@ fn steal_overflow() { } for _ in 0..6 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -111,7 +112,7 @@ fn steal_overflow() { fn multi_stealer() { const NUM_TASKS: usize = 5; - fn steal_tasks(steal: queue::Steal<Runtime>) -> usize { + fn steal_tasks(steal: queue::Steal<NoopSchedule>) -> usize { let (_, mut local) = queue::local(); if steal.steal_into(&mut local).is_none() { @@ -129,11 +130,11 @@ fn multi_stealer() { loom::model(|| { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); // Push work for _ in 0..NUM_TASKS { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -166,14 +167,14 @@ fn chained_steal() { loom::model(|| { let (s1, mut l1) = queue::local(); let (s2, mut l2) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); // Load up some tasks for _ in 0..4 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); l1.push_back(task, &inject); - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); l2.push_back(task, &inject); } @@ -197,20 +198,3 @@ fn chained_steal() { while inject.pop().is_some() {} }); } - -struct Runtime; - -impl Schedule for Runtime { - fn bind(task: Task<Self>) -> Runtime { - std::mem::forget(task); - Runtime - } - - fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { - None - } - - fn schedule(&self, _task: task::Notified<Self>) { - unreachable!(); - } -} diff --git a/src/runtime/tests/mod.rs b/src/runtime/tests/mod.rs index c84ba1b..be36d6f 100644 --- a/src/runtime/tests/mod.rs +++ b/src/runtime/tests/mod.rs @@ -1,5 +1,36 @@ +use self::unowned_wrapper::unowned; + +mod unowned_wrapper { + use crate::runtime::blocking::NoopSchedule; + use crate::runtime::task::{JoinHandle, Notified}; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + use tracing::Instrument; + let span = tracing::trace_span!("test_span"); + let task = task.instrument(span); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } +} + cfg_loom! { mod loom_basic_scheduler; + mod loom_local; mod loom_blocking; mod loom_oneshot; mod loom_pool; @@ -10,6 +41,9 @@ cfg_loom! { cfg_not_loom! { mod queue; + #[cfg(not(miri))] + mod task_combinations; + #[cfg(miri)] mod task; } diff --git a/src/runtime/tests/queue.rs b/src/runtime/tests/queue.rs index d228d5d..428b002 100644 --- a/src/runtime/tests/queue.rs +++ b/src/runtime/tests/queue.rs @@ -1,5 +1,5 @@ use crate::runtime::queue; -use crate::runtime::task::{self, Schedule, Task}; +use crate::runtime::task::{self, Inject, Schedule, Task}; use std::thread; use std::time::Duration; @@ -7,10 +7,10 @@ use std::time::Duration; #[test] fn fits_256() { let (_, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); for _ in 0..256 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -22,10 +22,10 @@ fn fits_256() { #[test] fn overflow() { let (_, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); for _ in 0..257 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -46,10 +46,10 @@ fn overflow() { fn steal_batch() { let (steal1, mut local1) = queue::local(); let (_, mut local2) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); for _ in 0..4 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local1.push_back(task, &inject); } @@ -78,7 +78,7 @@ fn stress1() { for _ in 0..NUM_ITER { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { let (_, mut local) = queue::local(); @@ -103,7 +103,7 @@ fn stress1() { for _ in 0..NUM_LOCAL { for _ in 0..NUM_PUSH { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -134,7 +134,7 @@ fn stress2() { for _ in 0..NUM_ITER { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { let (_, mut local) = queue::local(); @@ -158,7 +158,7 @@ fn stress2() { let mut num_pop = 0; for i in 0..NUM_TASKS { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); if i % 128 == 0 && local.pop().is_some() { @@ -187,11 +187,6 @@ fn stress2() { struct Runtime; impl Schedule for Runtime { - fn bind(task: Task<Self>) -> Runtime { - std::mem::forget(task); - Runtime - } - fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { None } diff --git a/src/runtime/tests/task.rs b/src/runtime/tests/task.rs index 45a3e99..e93a1ef 100644 --- a/src/runtime/tests/task.rs +++ b/src/runtime/tests/task.rs @@ -1,44 +1,185 @@ -use crate::runtime::task::{self, Schedule, Task}; -use crate::util::linked_list::{Link, LinkedList}; +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; use crate::util::TryLock; use std::collections::VecDeque; +use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +struct AssertDropHandle { + is_dropped: Arc<AtomicBool>, +} +impl AssertDropHandle { + #[track_caller] + fn assert_dropped(&self) { + assert!(self.is_dropped.load(Ordering::SeqCst)); + } + + #[track_caller] + fn assert_not_dropped(&self) { + assert!(!self.is_dropped.load(Ordering::SeqCst)); + } +} + +struct AssertDrop { + is_dropped: Arc<AtomicBool>, +} +impl AssertDrop { + fn new() -> (Self, AssertDropHandle) { + let shared = Arc::new(AtomicBool::new(false)); + ( + AssertDrop { + is_dropped: shared.clone(), + }, + AssertDropHandle { + is_dropped: shared.clone(), + }, + ) + } +} +impl Drop for AssertDrop { + fn drop(&mut self) { + self.is_dropped.store(true, Ordering::SeqCst); + } +} + +// A Notified does not shut down on drop, but it is dropped once the ref-count +// hits zero. +#[test] +fn create_drop1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(notified); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + +#[test] +fn create_drop2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_dropped(); +} + +// Shutting down through Notified works #[test] -fn create_drop() { - let _ = task::joinable::<_, Runtime>(async { unreachable!() }); +fn create_shutdown1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); +} + +#[test] +fn create_shutdown2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); + drop(join); } #[test] fn schedule() { with(|rt| { - let (task, _) = task::joinable(async { + rt.spawn(async { crate::task::yield_now().await; }); - rt.schedule(task); - assert_eq!(2, rt.tick()); + rt.shutdown(); }) } #[test] fn shutdown() { with(|rt| { - let (task, _) = task::joinable(async { + rt.spawn(async { loop { crate::task::yield_now().await; } }); - rt.schedule(task); rt.tick_max(1); rt.shutdown(); }) } +#[test] +fn shutdown_immediately() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.shutdown(); + }) +} + +#[test] +fn spawn_during_shutdown() { + static DID_SPAWN: AtomicBool = AtomicBool::new(false); + + struct SpawnOnDrop(Runtime); + impl Drop for SpawnOnDrop { + fn drop(&mut self) { + DID_SPAWN.store(true, Ordering::SeqCst); + self.0.spawn(async {}); + } + } + + with(|rt| { + let rt2 = rt.clone(); + rt.spawn(async move { + let _spawn_on_drop = SpawnOnDrop(rt2); + + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + rt.shutdown(); + }); + + assert!(DID_SPAWN.load(Ordering::SeqCst)); +} + fn with(f: impl FnOnce(Runtime)) { struct Reset; @@ -51,10 +192,9 @@ fn with(f: impl FnOnce(Runtime)) { let _reset = Reset; let rt = Runtime(Arc::new(Inner { - released: task::TransferStack::new(), + owned: OwnedTasks::new(), core: TryLock::new(Core { queue: VecDeque::new(), - tasks: LinkedList::new(), }), })); @@ -66,18 +206,31 @@ fn with(f: impl FnOnce(Runtime)) { struct Runtime(Arc<Inner>); struct Inner { - released: task::TransferStack<Runtime>, core: TryLock<Core>, + owned: OwnedTasks<Runtime>, } struct Core { queue: VecDeque<task::Notified<Runtime>>, - tasks: LinkedList<Task<Runtime>, <Task<Runtime> as Link>::Target>, } static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); impl Runtime { + fn spawn<T>(&self, future: T) -> JoinHandle<T::Output> + where + T: 'static + Send + Future, + T::Output: 'static + Send, + { + let (handle, notified) = self.0.owned.bind(future, self.clone()); + + if let Some(notified) = notified { + self.schedule(notified); + } + + handle + } + fn tick(&self) -> usize { self.tick_max(usize::MAX) } @@ -88,11 +241,10 @@ impl Runtime { while !self.is_empty() && n < max { let task = self.next_task(); n += 1; + let task = self.0.owned.assert_owner(task); task.run(); } - self.0.maintenance(); - n } @@ -107,50 +259,21 @@ impl Runtime { fn shutdown(&self) { let mut core = self.0.core.try_lock().unwrap(); - for task in core.tasks.iter() { - task.shutdown(); - } + self.0.owned.close_and_shutdown_all(); while let Some(task) = core.queue.pop_back() { - task.shutdown(); + drop(task); } drop(core); - while !self.0.core.try_lock().unwrap().tasks.is_empty() { - self.0.maintenance(); - } - } -} - -impl Inner { - fn maintenance(&self) { - use std::mem::ManuallyDrop; - - for task in self.released.drain() { - let task = ManuallyDrop::new(task); - - // safety: see worker.rs - unsafe { - let ptr = task.header().into(); - self.core.try_lock().unwrap().tasks.remove(ptr); - } - } + assert!(self.0.owned.is_empty()); } } impl Schedule for Runtime { - fn bind(task: Task<Self>) -> Runtime { - let rt = CURRENT.try_lock().unwrap().as_ref().unwrap().clone(); - rt.0.core.try_lock().unwrap().tasks.push_front(task); - rt - } - fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { - // safety: copying worker.rs - let task = unsafe { Task::from_raw(task.header().into()) }; - self.0.released.push(task); - None + self.0.owned.remove(task) } fn schedule(&self, task: task::Notified<Self>) { diff --git a/src/runtime/tests/task_combinations.rs b/src/runtime/tests/task_combinations.rs new file mode 100644 index 0000000..76ce233 --- /dev/null +++ b/src/runtime/tests/task_combinations.rs @@ -0,0 +1,380 @@ +use std::future::Future; +use std::panic; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Builder; +use crate::sync::oneshot; +use crate::task::JoinHandle; + +use futures::future::FutureExt; + +// Enums for each option in the combinations being tested + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiRuntime { + CurrentThread, + Multi1, + Multi2, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiLocalSet { + Yes, + No, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiTask { + PanicOnRun, + PanicOnDrop, + PanicOnRunAndDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiOutput { + PanicOnDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinInterest { + Polled, + NotPolled, +} +#[allow(clippy::enum_variant_names)] // we aren't using glob imports +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinHandle { + DropImmediately = 1, + DropFirstPoll = 2, + DropAfterNoConsume = 3, + DropAfterConsume = 4, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbort { + NotAborted = 0, + AbortedImmediately = 1, + AbortedFirstPoll = 2, + AbortedAfterFinish = 3, + AbortedAfterConsumeOutput = 4, +} + +#[test] +fn test_combinations() { + let mut rt = &[ + CombiRuntime::CurrentThread, + CombiRuntime::Multi1, + CombiRuntime::Multi2, + ][..]; + + if cfg!(miri) { + rt = &[CombiRuntime::CurrentThread]; + } + + let ls = [CombiLocalSet::Yes, CombiLocalSet::No]; + let task = [ + CombiTask::NoPanic, + CombiTask::PanicOnRun, + CombiTask::PanicOnDrop, + CombiTask::PanicOnRunAndDrop, + ]; + let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop]; + let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled]; + let jh = [ + CombiJoinHandle::DropImmediately, + CombiJoinHandle::DropFirstPoll, + CombiJoinHandle::DropAfterNoConsume, + CombiJoinHandle::DropAfterConsume, + ]; + let abort = [ + CombiAbort::NotAborted, + CombiAbort::AbortedImmediately, + CombiAbort::AbortedFirstPoll, + CombiAbort::AbortedAfterFinish, + CombiAbort::AbortedAfterConsumeOutput, + ]; + + for rt in rt.iter().copied() { + for ls in ls.iter().copied() { + for task in task.iter().copied() { + for output in output.iter().copied() { + for ji in ji.iter().copied() { + for jh in jh.iter().copied() { + for abort in abort.iter().copied() { + test_combination(rt, ls, task, output, ji, jh, abort); + } + } + } + } + } + } + } +} + +fn test_combination( + rt: CombiRuntime, + ls: CombiLocalSet, + task: CombiTask, + output: CombiOutput, + ji: CombiJoinInterest, + jh: CombiJoinHandle, + abort: CombiAbort, +) { + if (jh as usize) < (abort as usize) { + // drop before abort not possible + return; + } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { + // this causes double panic + return; + } + if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) { + // this causes double panic + return; + } + + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + + // A runtime optionally with a LocalSet + struct Rt { + rt: crate::runtime::Runtime, + ls: Option<crate::task::LocalSet>, + } + impl Rt { + fn new(rt: CombiRuntime, ls: CombiLocalSet) -> Self { + let rt = match rt { + CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(), + CombiRuntime::Multi1 => Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(), + CombiRuntime::Multi2 => Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(), + }; + + let ls = match ls { + CombiLocalSet::Yes => Some(crate::task::LocalSet::new()), + CombiLocalSet::No => None, + }; + + Self { rt, ls } + } + fn block_on<T>(&self, task: T) -> T::Output + where + T: Future, + { + match &self.ls { + Some(ls) => ls.block_on(&self.rt, task), + None => self.rt.block_on(task), + } + } + fn spawn<T>(&self, task: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + match &self.ls { + Some(ls) => ls.spawn_local(task), + None => self.rt.spawn(task), + } + } + } + + // The type used for the output of the future + struct Output { + panic_on_drop: bool, + on_drop: Option<oneshot::Sender<()>>, + } + impl Output { + fn disarm(&mut self) { + self.panic_on_drop = false; + } + } + impl Drop for Output { + fn drop(&mut self) { + let _ = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in Output"); + } + } + } + + // A wrapper around the future that is spawned + struct FutWrapper<F> { + inner: F, + on_drop: Option<oneshot::Sender<()>>, + panic_on_drop: bool, + } + impl<F: Future> Future for FutWrapper<F> { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> { + unsafe { + let me = Pin::into_inner_unchecked(self); + let inner = Pin::new_unchecked(&mut me.inner); + inner.poll(cx) + } + } + } + impl<F> Drop for FutWrapper<F> { + fn drop(&mut self) { + let _: Result<(), ()> = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in FutWrapper"); + } + } + } + + // The channels passed to the task + struct Signals { + on_first_poll: Option<oneshot::Sender<()>>, + wait_complete: Option<oneshot::Receiver<()>>, + on_output_drop: Option<oneshot::Sender<()>>, + } + + // The task we will spawn + async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output { + // Signal that we have been polled once + let _ = signal.on_first_poll.take().unwrap().send(()); + + // Wait for a signal, then complete the future + let _ = signal.wait_complete.take().unwrap().await; + + // If the task gets past wait_complete without yielding, then aborts + // may not be caught without this yield_now. + crate::task::yield_now().await; + + if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop { + panic!("Panicking in my_task on {:?}", std::thread::current().id()); + } + + Output { + panic_on_drop: out == CombiOutput::PanicOnDrop, + on_drop: signal.on_output_drop.take(), + } + } + + let rt = Rt::new(rt, ls); + + let (on_first_poll, wait_first_poll) = oneshot::channel(); + let (on_complete, wait_complete) = oneshot::channel(); + let (on_future_drop, wait_future_drop) = oneshot::channel(); + let (on_output_drop, wait_output_drop) = oneshot::channel(); + let signal = Signals { + on_first_poll: Some(on_first_poll), + wait_complete: Some(wait_complete), + on_output_drop: Some(on_output_drop), + }; + + // === Spawn task === + let mut handle = Some(rt.spawn(FutWrapper { + inner: my_task(signal, task, output), + on_drop: Some(on_future_drop), + panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop, + })); + + // Keep track of whether the task has been killed with an abort + let mut aborted = false; + + // If we want to poll the JoinHandle, do it now + if ji == CombiJoinInterest::Polled { + assert!( + handle.as_mut().unwrap().now_or_never().is_none(), + "Polling handle succeeded" + ); + } + + if abort == CombiAbort::AbortedImmediately { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropImmediately { + drop(handle.take().unwrap()); + } + + // === Wait for first poll === + let got_polled = rt.block_on(wait_first_poll).is_ok(); + if !got_polled { + // it's possible that we are aborted but still got polled + assert!( + aborted, + "Task completed without ever being polled but was not aborted." + ); + } + + if abort == CombiAbort::AbortedFirstPoll { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropFirstPoll { + drop(handle.take().unwrap()); + } + + // Signal the future that it can return now + let _ = on_complete.send(()); + // === Wait for future to be dropped === + assert!( + rt.block_on(wait_future_drop).is_ok(), + "The future should always be dropped." + ); + + if abort == CombiAbort::AbortedAfterFinish { + // Don't set aborted to true here as the task already finished + handle.as_mut().unwrap().abort(); + } + if jh == CombiJoinHandle::DropAfterNoConsume { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } + } + + // Check whether we drop after consuming the output + if jh == CombiJoinHandle::DropAfterConsume { + // Using as_mut() to not immediately drop the handle + let result = rt.block_on(handle.as_mut().unwrap()); + + match result { + Ok(mut output) => { + // Don't panic here. + output.disarm(); + assert!(!aborted, "Task was aborted but returned output"); + } + Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"), + Err(err) if err.is_panic() => { + assert!( + (task == CombiTask::PanicOnRun) + || (task == CombiTask::PanicOnDrop) + || (task == CombiTask::PanicOnRunAndDrop) + || (output == CombiOutput::PanicOnDrop), + "Panic but nothing should panic" + ); + } + _ => unreachable!(), + } + + let handle = handle.take().unwrap(); + if abort == CombiAbort::AbortedAfterConsumeOutput { + handle.abort(); + } + drop(handle); + } + + // The output should have been dropped now. Check whether the output + // object was created at all. + let output_created = rt.block_on(wait_output_drop).is_ok(); + assert_eq!( + output_created, + (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted, + "Creation of output object" + ); +} diff --git a/src/runtime/thread_pool/idle.rs b/src/runtime/thread_pool/idle.rs index b77cce5..2cac30e 100644 --- a/src/runtime/thread_pool/idle.rs +++ b/src/runtime/thread_pool/idle.rs @@ -42,11 +42,11 @@ impl Idle { /// worker currently sleeping. pub(super) fn worker_to_notify(&self) -> Option<usize> { // If at least one worker is spinning, work being notified will - // eventully be found. A searching thread will find **some** work and + // eventually be found. A searching thread will find **some** work and // notify another worker, eventually leading to our work being found. // // For this to happen, this load must happen before the thread - // transitioning `num_searching` to zero. Acquire / Relese does not + // transitioning `num_searching` to zero. Acquire / Release does not // provide sufficient guarantees, so this load is done with `SeqCst` and // will pair with the `fetch_sub(1)` when transitioning out of // searching. diff --git a/src/runtime/thread_pool/mod.rs b/src/runtime/thread_pool/mod.rs index 96312d3..3808aa2 100644 --- a/src/runtime/thread_pool/mod.rs +++ b/src/runtime/thread_pool/mod.rs @@ -12,7 +12,7 @@ pub(crate) use worker::Launch; pub(crate) use worker::block_in_place; use crate::loom::sync::Arc; -use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::task::JoinHandle; use crate::runtime::Parker; use std::fmt; @@ -30,7 +30,7 @@ pub(crate) struct ThreadPool { /// /// The `Spawner` handle is *only* used for spawning new futures. It does not /// impact the lifecycle of the thread pool in any way. The thread pool may -/// shutdown while there are outstanding `Spawner` instances. +/// shut down while there are outstanding `Spawner` instances. /// /// `Spawner` instances are obtained by calling [`ThreadPool::spawner`]. /// @@ -93,15 +93,7 @@ impl Spawner { F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (task, handle) = task::joinable(future); - - if let Err(task) = self.shared.schedule(task, false) { - // The newly spawned task could not be scheduled because the runtime - // is shutting down. The task must be explicitly shutdown at this point. - task.shutdown(); - } - - handle + worker::Shared::bind_new_task(&self.shared, future) } pub(crate) fn shutdown(&mut self) { diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs index 70cbddb..f5004c0 100644 --- a/src/runtime/thread_pool/worker.rs +++ b/src/runtime/thread_pool/worker.rs @@ -3,17 +3,70 @@ //! run queue and other state. When `block_in_place` is called, the worker's //! "core" is handed off to a new thread allowing the scheduler to continue to //! make progress while the originating thread blocks. +//! +//! # Shutdown +//! +//! Shutting down the runtime involves the following steps: +//! +//! 1. The Shared::close method is called. This closes the inject queue and +//! OwnedTasks instance and wakes up all worker threads. +//! +//! 2. Each worker thread observes the close signal next time it runs +//! Core::maintenance by checking whether the inject queue is closed. +//! The Core::is_shutdown flag is set to true. +//! +//! 3. The worker thread calls `pre_shutdown` in parallel. Here, the worker +//! will keep removing tasks from OwnedTasks until it is empty. No new +//! tasks can be pushed to the OwnedTasks during or after this step as it +//! was closed in step 1. +//! +//! 5. The workers call Shared::shutdown to enter the single-threaded phase of +//! shutdown. These calls will push their core to Shared::shutdown_cores, +//! and the last thread to push its core will finish the shutdown procedure. +//! +//! 6. The local run queue of each core is emptied, then the inject queue is +//! emptied. +//! +//! At this point, shutdown has completed. It is not possible for any of the +//! collections to contain any tasks at this point, as each collection was +//! closed first, then emptied afterwards. +//! +//! ## Spawns during shutdown +//! +//! When spawning tasks during shutdown, there are two cases: +//! +//! * The spawner observes the OwnedTasks being open, and the inject queue is +//! closed. +//! * The spawner observes the OwnedTasks being closed and doesn't check the +//! inject queue. +//! +//! The first case can only happen if the OwnedTasks::bind call happens before +//! or during step 1 of shutdown. In this case, the runtime will clean up the +//! task in step 3 of shutdown. +//! +//! In the latter case, the task was not spawned and the task is immediately +//! cancelled by the spawner. +//! +//! The correctness of shutdown requires both the inject queue and OwnedTasks +//! collection to have a closed bit. With a close bit on only the inject queue, +//! spawning could run in to a situation where a task is successfully bound long +//! after the runtime has shut down. With a close bit on only the OwnedTasks, +//! the first spawning situation could result in the notification being pushed +//! to the inject queue after step 6 of shutdown, which would leave a task in +//! the inject queue indefinitely. This would be a ref-count cycle and a memory +//! leak. use crate::coop; +use crate::future::Future; use crate::loom::rand::seed; use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; use crate::runtime; use crate::runtime::enter::EnterContext; use crate::runtime::park::{Parker, Unparker}; +use crate::runtime::task::{Inject, JoinHandle, OwnedTasks}; use crate::runtime::thread_pool::{AtomicCell, Idle}; use crate::runtime::{queue, task}; -use crate::util::linked_list::{Link, LinkedList}; use crate::util::FastRand; use std::cell::RefCell; @@ -44,7 +97,7 @@ struct Core { lifo_slot: Option<Notified>, /// The worker-local run queue. - run_queue: queue::Local<Arc<Worker>>, + run_queue: queue::Local<Arc<Shared>>, /// True if the worker is currently searching for more work. Searching /// involves attempting to steal from other workers. @@ -53,9 +106,6 @@ struct Core { /// True if the scheduler is being shutdown is_shutdown: bool, - /// Tasks owned by the core - tasks: LinkedList<Task, <Task as Link>::Target>, - /// Parker /// /// Stored in an `Option` as the parker is added / removed to make the @@ -73,11 +123,14 @@ pub(super) struct Shared { remotes: Box<[Remote]>, /// Submit work to the scheduler while **not** currently on a worker thread. - inject: queue::Inject<Arc<Worker>>, + inject: Inject<Arc<Shared>>, /// Coordinates idle workers idle: Idle, + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, + /// Cores that have observed the shutdown signal /// /// The core is **not** placed back in the worker to avoid it from being @@ -89,11 +142,7 @@ pub(super) struct Shared { /// Used to communicate with a worker from other threads. struct Remote { /// Steal tasks from this worker. - steal: queue::Steal<Arc<Worker>>, - - /// Transfers tasks to be released. Any worker pushes tasks, only the owning - /// worker pops. - pending_drop: task::TransferStack<Arc<Worker>>, + steal: queue::Steal<Arc<Shared>>, /// Unparks the associated worker thread unpark: Unparker, @@ -117,10 +166,10 @@ pub(crate) struct Launch(Vec<Arc<Worker>>); type RunResult = Result<Box<Core>, ()>; /// A task handle -type Task = task::Task<Arc<Worker>>; +type Task = task::Task<Arc<Shared>>; /// A notified task handle -type Notified = task::Notified<Arc<Worker>>; +type Notified = task::Notified<Arc<Shared>>; // Tracks thread-local state scoped_thread_local!(static CURRENT: Context); @@ -142,22 +191,18 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc<Shared>, Launch) { run_queue, is_searching: false, is_shutdown: false, - tasks: LinkedList::new(), park: Some(park), rand: FastRand::new(seed()), })); - remotes.push(Remote { - steal, - pending_drop: task::TransferStack::new(), - unpark, - }); + remotes.push(Remote { steal, unpark }); } let shared = Arc::new(Shared { remotes: remotes.into_boxed_slice(), - inject: queue::Inject::new(), + inject: Inject::new(), idle: Idle::new(size), + owned: OwnedTasks::new(), shutdown_cores: Mutex::new(vec![]), }); @@ -203,18 +248,20 @@ where CURRENT.with(|maybe_cx| { match (crate::runtime::enter::context(), maybe_cx.is_some()) { (EnterContext::Entered { .. }, true) => { - // We are on a thread pool runtime thread, so we just need to set up blocking. + // We are on a thread pool runtime thread, so we just need to + // set up blocking. had_entered = true; } (EnterContext::Entered { allow_blocking }, false) => { - // We are on an executor, but _not_ on the thread pool. - // That is _only_ okay if we are in a thread pool runtime's block_on method: + // We are on an executor, but _not_ on the thread pool. That is + // _only_ okay if we are in a thread pool runtime's block_on + // method: if allow_blocking { had_entered = true; return; } else { - // This probably means we are on the basic_scheduler or in a LocalSet, - // where it is _not_ okay to block. + // This probably means we are on the basic_scheduler or in a + // LocalSet, where it is _not_ okay to block. panic!("can call blocking only when running on the multi-threaded runtime"); } } @@ -337,6 +384,8 @@ impl Context { } fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult { + let task = self.worker.shared.owned.assert_owner(task); + // Make sure the worker is not in the **searching** state. This enables // another idle worker to try to steal work. core.transition_from_searching(&self.worker); @@ -367,6 +416,7 @@ impl Context { if coop::has_budget_remaining() { // Run the LIFO task, then loop *self.core.borrow_mut() = Some(core); + let task = self.worker.shared.owned.assert_owner(task); task.run(); } else { // Not enough budget left to run the LIFO task, push it to @@ -538,42 +588,23 @@ impl Core { true } - /// Runs maintenance work such as free pending tasks and check the pool's - /// state. + /// Runs maintenance work such as checking the pool's state. fn maintenance(&mut self, worker: &Worker) { - self.drain_pending_drop(worker); - if !self.is_shutdown { // Check if the scheduler has been shutdown self.is_shutdown = worker.inject().is_closed(); } } - // Signals all tasks to shut down, and waits for them to complete. Must run - // before we enter the single-threaded phase of shutdown processing. + /// Signals all tasks to shut down, and waits for them to complete. Must run + /// before we enter the single-threaded phase of shutdown processing. fn pre_shutdown(&mut self, worker: &Worker) { // Signal to all tasks to shut down. - for header in self.tasks.iter() { - header.shutdown(); - } - - loop { - self.drain_pending_drop(worker); - - if self.tasks.is_empty() { - break; - } - - // Wait until signalled - let park = self.park.as_mut().expect("park missing"); - park.park().expect("park failed"); - } + worker.shared.owned.close_and_shutdown_all(); } - // Shutdown the core + /// Shutdown the core fn shutdown(&mut self) { - assert!(self.tasks.is_empty()); - // Take the core let mut park = self.park.take().expect("park missing"); @@ -582,149 +613,45 @@ impl Core { park.shutdown(); } - - fn drain_pending_drop(&mut self, worker: &Worker) { - use std::mem::ManuallyDrop; - - for task in worker.remote().pending_drop.drain() { - let task = ManuallyDrop::new(task); - - // safety: tasks are only pushed into the `pending_drop` stacks that - // are associated with the list they are inserted into. When a task - // is pushed into `pending_drop`, the ref-inc is skipped, so we must - // not ref-dec here. - // - // See `bind` and `release` implementations. - unsafe { - self.tasks.remove(task.header().into()); - } - } - } } impl Worker { /// Returns a reference to the scheduler's injection queue - fn inject(&self) -> &queue::Inject<Arc<Worker>> { + fn inject(&self) -> &Inject<Arc<Shared>> { &self.shared.inject } - - /// Return a reference to this worker's remote data - fn remote(&self) -> &Remote { - &self.shared.remotes[self.index] - } - - fn eq(&self, other: &Worker) -> bool { - self.shared.ptr_eq(&other.shared) && self.index == other.index - } } -impl task::Schedule for Arc<Worker> { - fn bind(task: Task) -> Arc<Worker> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - - // Track the task - cx.core - .borrow_mut() - .as_mut() - .expect("scheduler core missing") - .tasks - .push_front(task); - - // Return a clone of the worker - cx.worker.clone() - }) - } - +impl task::Schedule for Arc<Shared> { fn release(&self, task: &Task) -> Option<Task> { - use std::ptr::NonNull; - - enum Immediate { - // Task has been synchronously removed from the Core owned by the - // current thread - Removed(Option<Task>), - // Task is owned by another thread, so we need to notify it to clean - // up the task later. - MaybeRemote, - } - - let immediate = CURRENT.with(|maybe_cx| { - let cx = match maybe_cx { - Some(cx) => cx, - None => return Immediate::MaybeRemote, - }; - - if !self.eq(&cx.worker) { - // Task owned by another core, so we need to notify it. - return Immediate::MaybeRemote; - } - - let mut maybe_core = cx.core.borrow_mut(); - - if let Some(core) = &mut *maybe_core { - // Directly remove the task - // - // safety: the task is inserted in the list in `bind`. - unsafe { - let ptr = NonNull::from(task.header()); - return Immediate::Removed(core.tasks.remove(ptr)); - } - } - - Immediate::MaybeRemote - }); - - // Checks if we were called from within a worker, allowing for immediate - // removal of a scheduled task. Else we have to go through the slower - // process below where we remotely mark a task as dropped. - match immediate { - Immediate::Removed(task) => return task, - Immediate::MaybeRemote => (), - }; - - // Track the task to be released by the worker that owns it - // - // Safety: We get a new handle without incrementing the ref-count. - // A ref-count is held by the "owned" linked list and it is only - // ever removed from that list as part of the release process: this - // method or popping the task from `pending_drop`. Thus, we can rely - // on the ref-count held by the linked-list to keep the memory - // alive. - // - // When the task is removed from the stack, it is forgotten instead - // of dropped. - let task = unsafe { Task::from_raw(task.header().into()) }; - - self.remote().pending_drop.push(task); - - // The worker core has been handed off to another thread. In the - // event that the scheduler is currently shutting down, the thread - // that owns the task may be waiting on the release to complete - // shutdown. - if self.inject().is_closed() { - self.remote().unpark.unpark(); - } - - None + self.owned.remove(task) } fn schedule(&self, task: Notified) { - // Because this is not a newly spawned task, if scheduling fails due to - // the runtime shutting down, there is no special work that must happen - // here. - let _ = self.shared.schedule(task, false); + (**self).schedule(task, false); } fn yield_now(&self, task: Notified) { - // Because this is not a newly spawned task, if scheduling fails due to - // the runtime shutting down, there is no special work that must happen - // here. - let _ = self.shared.schedule(task, true); + (**self).schedule(task, true); } } impl Shared { - pub(super) fn schedule(&self, task: Notified, is_yield: bool) -> Result<(), Notified> { + pub(super) fn bind_new_task<T>(me: &Arc<Self>, future: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (handle, notified) = me.owned.bind(future, me.clone()); + + if let Some(notified) = notified { + me.schedule(notified, false); + } + + handle + } + + pub(super) fn schedule(&self, task: Notified, is_yield: bool) { CURRENT.with(|maybe_cx| { if let Some(cx) = maybe_cx { // Make sure the task is part of the **current** scheduler. @@ -732,15 +659,14 @@ impl Shared { // And the current thread still holds a core if let Some(core) = cx.core.borrow_mut().as_mut() { self.schedule_local(core, task, is_yield); - return Ok(()); + return; } } } - // Otherwise, use the inject queue - self.inject.push(task)?; + // Otherwise, use the inject queue. + self.inject.push(task); self.notify_parked(); - Ok(()) }) } @@ -825,13 +751,17 @@ impl Shared { return; } + debug_assert!(self.owned.is_empty()); + for mut core in cores.drain(..) { core.shutdown(); } // Drain the injection queue + // + // We already shut down every task, so we can simply drop the tasks. while let Some(task) = self.inject.pop() { - task.shutdown(); + drop(task); } } diff --git a/src/signal/registry.rs b/src/signal/registry.rs index 8b89108..e0a2df9 100644 --- a/src/signal/registry.rs +++ b/src/signal/registry.rs @@ -240,17 +240,17 @@ mod tests { let registry = Registry::new(vec![EventInfo::default(), EventInfo::default()]); registry.record_event(0); - assert_eq!(false, registry.broadcast()); + assert!(!registry.broadcast()); let first = registry.register_listener(0); let second = registry.register_listener(1); registry.record_event(0); - assert_eq!(true, registry.broadcast()); + assert!(registry.broadcast()); drop(first); registry.record_event(0); - assert_eq!(false, registry.broadcast()); + assert!(!registry.broadcast()); drop(second); } diff --git a/src/signal/unix.rs b/src/signal/unix.rs index f96b2f4..86ea9a9 100644 --- a/src/signal/unix.rs +++ b/src/signal/unix.rs @@ -4,6 +4,7 @@ //! `Signal` type for receiving notifications of signals. #![cfg(unix)] +#![cfg_attr(docsrs, doc(cfg(all(unix, feature = "signal"))))] use crate::signal::registry::{globals, EventId, EventInfo, Globals, Init, Storage}; use crate::signal::RxFuture; diff --git a/src/signal/unix/driver.rs b/src/signal/unix/driver.rs index 315f3bd..5fe7c35 100644 --- a/src/signal/unix/driver.rs +++ b/src/signal/unix/driver.rs @@ -47,7 +47,7 @@ impl Driver { use std::mem::ManuallyDrop; use std::os::unix::io::{AsRawFd, FromRawFd}; - // NB: We give each driver a "fresh" reciever file descriptor to avoid + // NB: We give each driver a "fresh" receiver file descriptor to avoid // the issues described in alexcrichton/tokio-process#42. // // In the past we would reuse the actual receiver file descriptor and diff --git a/src/signal/windows.rs b/src/signal/windows.rs index c231d62..11ec6cb 100644 --- a/src/signal/windows.rs +++ b/src/signal/windows.rs @@ -5,127 +5,22 @@ //! `SetConsoleCtrlHandler` function which receives events of the type //! `CTRL_C_EVENT` and `CTRL_BREAK_EVENT`. -#![cfg(windows)] +#![cfg(any(windows, docsrs))] +#![cfg_attr(docsrs, doc(cfg(all(windows, feature = "signal"))))] -use crate::signal::registry::{globals, EventId, EventInfo, Init, Storage}; use crate::signal::RxFuture; - -use std::convert::TryFrom; use std::io; -use std::sync::Once; use std::task::{Context, Poll}; -use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE}; -use winapi::um::consoleapi::SetConsoleCtrlHandler; -use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT}; - -#[derive(Debug)] -pub(crate) struct OsStorage { - ctrl_c: EventInfo, - ctrl_break: EventInfo, -} - -impl Init for OsStorage { - fn init() -> Self { - Self { - ctrl_c: EventInfo::default(), - ctrl_break: EventInfo::default(), - } - } -} - -impl Storage for OsStorage { - fn event_info(&self, id: EventId) -> Option<&EventInfo> { - match DWORD::try_from(id) { - Ok(CTRL_C_EVENT) => Some(&self.ctrl_c), - Ok(CTRL_BREAK_EVENT) => Some(&self.ctrl_break), - _ => None, - } - } - - fn for_each<'a, F>(&'a self, mut f: F) - where - F: FnMut(&'a EventInfo), - { - f(&self.ctrl_c); - f(&self.ctrl_break); - } -} - -#[derive(Debug)] -pub(crate) struct OsExtraData {} -impl Init for OsExtraData { - fn init() -> Self { - Self {} - } -} - -/// Stream of events discovered via `SetConsoleCtrlHandler`. -/// -/// This structure can be used to listen for events of the type `CTRL_C_EVENT` -/// and `CTRL_BREAK_EVENT`. The `Stream` trait is implemented for this struct -/// and will resolve for each notification received by the process. Note that -/// there are few limitations with this as well: -/// -/// * A notification to this process notifies *all* `Event` streams for that -/// event type. -/// * Notifications to an `Event` stream **are coalesced** if they aren't -/// processed quickly enough. This means that if two notifications are -/// received back-to-back, then the stream may only receive one item about the -/// two notifications. -#[must_use = "streams do nothing unless polled"] -#[derive(Debug)] -pub(crate) struct Event { - inner: RxFuture, -} - -impl Event { - fn new(signum: DWORD) -> io::Result<Self> { - global_init()?; - - let rx = globals().register_listener(signum as EventId); - - Ok(Self { - inner: RxFuture::new(rx), - }) - } -} +#[cfg(not(docsrs))] +#[path = "windows/sys.rs"] +mod imp; +#[cfg(not(docsrs))] +pub(crate) use self::imp::{OsExtraData, OsStorage}; -fn global_init() -> io::Result<()> { - static INIT: Once = Once::new(); - - let mut init = None; - - INIT.call_once(|| unsafe { - let rc = SetConsoleCtrlHandler(Some(handler), TRUE); - let ret = if rc == 0 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - }; - - init = Some(ret); - }); - - init.unwrap_or_else(|| Ok(())) -} - -unsafe extern "system" fn handler(ty: DWORD) -> BOOL { - let globals = globals(); - globals.record_event(ty as EventId); - - // According to https://docs.microsoft.com/en-us/windows/console/handlerroutine - // the handler routine is always invoked in a new thread, thus we don't - // have the same restrictions as in Unix signal handlers, meaning we can - // go ahead and perform the broadcast here. - if globals.broadcast() { - TRUE - } else { - // No one is listening for this notification any more - // let the OS fire the next (possibly the default) handler. - FALSE - } -} +#[cfg(docsrs)] +#[path = "windows/stub.rs"] +mod imp; /// Creates a new stream which receives "ctrl-c" notifications sent to the /// process. @@ -150,7 +45,9 @@ unsafe extern "system" fn handler(ty: DWORD) -> BOOL { /// } /// ``` pub fn ctrl_c() -> io::Result<CtrlC> { - Event::new(CTRL_C_EVENT).map(|inner| CtrlC { inner }) + Ok(CtrlC { + inner: self::imp::ctrl_c()?, + }) } /// Represents a stream which receives "ctrl-c" notifications sent to the process @@ -163,7 +60,7 @@ pub fn ctrl_c() -> io::Result<CtrlC> { #[must_use = "streams do nothing unless polled"] #[derive(Debug)] pub struct CtrlC { - inner: Event, + inner: RxFuture, } impl CtrlC { @@ -191,7 +88,7 @@ impl CtrlC { /// } /// ``` pub async fn recv(&mut self) -> Option<()> { - self.inner.inner.recv().await + self.inner.recv().await } /// Polls to receive the next signal notification event, outside of an @@ -223,7 +120,7 @@ impl CtrlC { /// } /// ``` pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { - self.inner.inner.poll_recv(cx) + self.inner.poll_recv(cx) } } @@ -237,7 +134,7 @@ impl CtrlC { #[must_use = "streams do nothing unless polled"] #[derive(Debug)] pub struct CtrlBreak { - inner: Event, + inner: RxFuture, } impl CtrlBreak { @@ -263,7 +160,7 @@ impl CtrlBreak { /// } /// ``` pub async fn recv(&mut self) -> Option<()> { - self.inner.inner.recv().await + self.inner.recv().await } /// Polls to receive the next signal notification event, outside of an @@ -295,7 +192,7 @@ impl CtrlBreak { /// } /// ``` pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { - self.inner.inner.poll_recv(cx) + self.inner.poll_recv(cx) } } @@ -320,56 +217,7 @@ impl CtrlBreak { /// } /// ``` pub fn ctrl_break() -> io::Result<CtrlBreak> { - Event::new(CTRL_BREAK_EVENT).map(|inner| CtrlBreak { inner }) -} - -#[cfg(all(test, not(loom)))] -mod tests { - use super::*; - use crate::runtime::Runtime; - - use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; - - #[test] - fn ctrl_c() { - let rt = rt(); - let _enter = rt.enter(); - - let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); - - assert_pending!(ctrl_c.poll()); - - // Windows doesn't have a good programmatic way of sending events - // like sending signals on Unix, so we'll stub out the actual OS - // integration and test that our handling works. - unsafe { - super::handler(CTRL_C_EVENT); - } - - assert_ready_ok!(ctrl_c.poll()); - } - - #[test] - fn ctrl_break() { - let rt = rt(); - - rt.block_on(async { - let mut ctrl_break = assert_ok!(super::ctrl_break()); - - // Windows doesn't have a good programmatic way of sending events - // like sending signals on Unix, so we'll stub out the actual OS - // integration and test that our handling works. - unsafe { - super::handler(CTRL_BREAK_EVENT); - } - - ctrl_break.recv().await.unwrap(); - }); - } - - fn rt() -> Runtime { - crate::runtime::Builder::new_current_thread() - .build() - .unwrap() - } + Ok(CtrlBreak { + inner: self::imp::ctrl_break()?, + }) } diff --git a/src/signal/windows/stub.rs b/src/signal/windows/stub.rs new file mode 100644 index 0000000..8863054 --- /dev/null +++ b/src/signal/windows/stub.rs @@ -0,0 +1,13 @@ +//! Stub implementations for the platform API so that rustdoc can build linkable +//! documentation on non-windows platforms. + +use crate::signal::RxFuture; +use std::io; + +pub(super) fn ctrl_c() -> io::Result<RxFuture> { + panic!() +} + +pub(super) fn ctrl_break() -> io::Result<RxFuture> { + panic!() +} diff --git a/src/signal/windows/sys.rs b/src/signal/windows/sys.rs new file mode 100644 index 0000000..8d29c35 --- /dev/null +++ b/src/signal/windows/sys.rs @@ -0,0 +1,153 @@ +use std::convert::TryFrom; +use std::io; +use std::sync::Once; + +use crate::signal::registry::{globals, EventId, EventInfo, Init, Storage}; +use crate::signal::RxFuture; + +use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE}; +use winapi::um::consoleapi::SetConsoleCtrlHandler; +use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT}; + +pub(super) fn ctrl_c() -> io::Result<RxFuture> { + new(CTRL_C_EVENT) +} + +pub(super) fn ctrl_break() -> io::Result<RxFuture> { + new(CTRL_BREAK_EVENT) +} + +fn new(signum: DWORD) -> io::Result<RxFuture> { + global_init()?; + let rx = globals().register_listener(signum as EventId); + Ok(RxFuture::new(rx)) +} + +#[derive(Debug)] +pub(crate) struct OsStorage { + ctrl_c: EventInfo, + ctrl_break: EventInfo, +} + +impl Init for OsStorage { + fn init() -> Self { + Self { + ctrl_c: EventInfo::default(), + ctrl_break: EventInfo::default(), + } + } +} + +impl Storage for OsStorage { + fn event_info(&self, id: EventId) -> Option<&EventInfo> { + match DWORD::try_from(id) { + Ok(CTRL_C_EVENT) => Some(&self.ctrl_c), + Ok(CTRL_BREAK_EVENT) => Some(&self.ctrl_break), + _ => None, + } + } + + fn for_each<'a, F>(&'a self, mut f: F) + where + F: FnMut(&'a EventInfo), + { + f(&self.ctrl_c); + f(&self.ctrl_break); + } +} + +#[derive(Debug)] +pub(crate) struct OsExtraData {} + +impl Init for OsExtraData { + fn init() -> Self { + Self {} + } +} + +fn global_init() -> io::Result<()> { + static INIT: Once = Once::new(); + + let mut init = None; + + INIT.call_once(|| unsafe { + let rc = SetConsoleCtrlHandler(Some(handler), TRUE); + let ret = if rc == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + }; + + init = Some(ret); + }); + + init.unwrap_or_else(|| Ok(())) +} + +unsafe extern "system" fn handler(ty: DWORD) -> BOOL { + let globals = globals(); + globals.record_event(ty as EventId); + + // According to https://docs.microsoft.com/en-us/windows/console/handlerroutine + // the handler routine is always invoked in a new thread, thus we don't + // have the same restrictions as in Unix signal handlers, meaning we can + // go ahead and perform the broadcast here. + if globals.broadcast() { + TRUE + } else { + // No one is listening for this notification any more + // let the OS fire the next (possibly the default) handler. + FALSE + } +} + +#[cfg(all(test, not(loom)))] +mod tests { + use super::*; + use crate::runtime::Runtime; + + use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; + + #[test] + fn ctrl_c() { + let rt = rt(); + let _enter = rt.enter(); + + let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); + + assert_pending!(ctrl_c.poll()); + + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_C_EVENT); + } + + assert_ready_ok!(ctrl_c.poll()); + } + + #[test] + fn ctrl_break() { + let rt = rt(); + + rt.block_on(async { + let mut ctrl_break = assert_ok!(crate::signal::windows::ctrl_break()); + + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_BREAK_EVENT); + } + + ctrl_break.recv().await.unwrap(); + }); + } + + fn rt() -> Runtime { + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + } +} diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs index e3c95f6..0e39dac 100644 --- a/src/sync/barrier.rs +++ b/src/sync/barrier.rs @@ -1,7 +1,6 @@ +use crate::loom::sync::Mutex; use crate::sync::watch; -use std::sync::Mutex; - /// A barrier enables multiple tasks to synchronize the beginning of some computation. /// /// ``` @@ -94,7 +93,7 @@ impl Barrier { // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across // a yield point, and thus marks the returned future as !Send. let generation = { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); let generation = state.generation; state.arrived += 1; if state.arrived == self.n { diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs index a0bf5ef..872d53e 100644 --- a/src/sync/batch_semaphore.rs +++ b/src/sync/batch_semaphore.rs @@ -478,7 +478,7 @@ impl<'a> Acquire<'a> { let this = self.get_unchecked_mut(); ( Pin::new_unchecked(&mut this.node), - &this.semaphore, + this.semaphore, this.num_permits, &mut this.queued, ) diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index 3ef8f84..a2ca445 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -824,6 +824,13 @@ impl<T: Clone> Receiver<T> { /// the channel. A subsequent call to [`recv`] will return this value /// **unless** it has been since overwritten. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// /// [`Receiver`]: crate::sync::broadcast::Receiver /// [`recv`]: crate::sync::broadcast::Receiver::recv /// diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs index 1c9ab14..7a0873b 100644 --- a/src/sync/mpsc/block.rs +++ b/src/sync/mpsc/block.rs @@ -1,6 +1,5 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; use std::mem::MaybeUninit; use std::ops; @@ -344,8 +343,13 @@ impl<T> Block<T> { Err(curr) => curr, }; - // When running outside of loom, this calls `spin_loop_hint`. - thread::yield_now(); + #[cfg(all(test, loom))] + crate::loom::thread::yield_now(); + + // TODO: once we bump MSRV to 1.49+, use `hint::spin_loop` instead. + #[cfg(not(all(test, loom)))] + #[allow(deprecated)] + std::sync::atomic::spin_loop_hint(); } } } diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index cfd8da0..d7af172 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -134,11 +134,16 @@ impl<T> Receiver<T> { /// /// If there are no messages in the channel's buffer, but the channel has /// not yet been closed, this method will sleep until a message is sent or - /// the channel is closed. + /// the channel is closed. Note that if [`close`] is called, but there are + /// still outstanding [`Permits`] from before it was closed, the channel is + /// not considered closed by `recv` until the permits are released. /// - /// Note that if [`close`] is called, but there are still outstanding - /// [`Permits`] from before it was closed, the channel is not considered - /// closed by `recv` until the permits are released. + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. /// /// [`close`]: Self::close /// [`Permits`]: struct@crate::sync::mpsc::Permit @@ -335,6 +340,16 @@ impl<T> Sender<T> { /// [`close`]: Receiver::close /// [`Receiver`]: Receiver /// + /// # Cancel safety + /// + /// If `send` is used as the event in a [`tokio::select!`](crate::select) + /// statement and some other branch completes first, then it is guaranteed + /// that the message was not sent. + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `send` makes you lose your place in the queue. + /// /// # Examples /// /// In the following example, each call to `send` will block until the @@ -376,6 +391,11 @@ impl<T> Sender<T> { /// This allows the producers to get notified when interest in the produced /// values is canceled and immediately stop doing work. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// /// # Examples /// /// ``` @@ -617,6 +637,12 @@ impl<T> Sender<T> { /// [`Permit`]: Permit /// [`send`]: Permit::send /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve` makes you lose your place in the queue. + /// /// # Examples /// /// ``` @@ -666,6 +692,12 @@ impl<T> Sender<T> { /// Dropping the [`OwnedPermit`] without sending a message releases the /// capacity back to the channel. /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve_owned` makes you lose your place in the queue. + /// /// # Examples /// Sending a message using an [`OwnedPermit`]: /// ``` diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs index ffdb34c..23c80f6 100644 --- a/src/sync/mpsc/unbounded.rs +++ b/src/sync/mpsc/unbounded.rs @@ -82,6 +82,13 @@ impl<T> UnboundedReceiver<T> { /// `None` is returned when all `Sender` halves have dropped, indicating /// that no further values can be sent on the channel. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// /// # Examples /// /// ``` @@ -241,6 +248,11 @@ impl<T> UnboundedSender<T> { /// This allows the producers to get notified when interest in the produced /// values is canceled and immediately stop doing work. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// /// # Examples /// /// ``` @@ -270,6 +282,7 @@ impl<T> UnboundedSender<T> { pub async fn closed(&self) { self.chan.closed().await } + /// Checks if the channel has been closed. This happens when the /// [`UnboundedReceiver`] is dropped, or when the /// [`UnboundedReceiver::close`] method is called. diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index 9fd7c91..6acd28b 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -273,9 +273,15 @@ impl<T: ?Sized> Mutex<T> { } } - /// Locks this mutex, causing the current task - /// to yield until the lock has been acquired. - /// When the lock has been acquired, function returns a [`MutexGuard`]. + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock` makes you lose your place in + /// the queue. /// /// # Examples /// @@ -305,6 +311,12 @@ impl<T: ?Sized> Mutex<T> { /// method, and the guard will live for the `'static` lifetime, as it keeps /// the `Mutex` alive by holding an `Arc`. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock_owned` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -561,6 +573,32 @@ impl<'a, T: ?Sized> MutexGuard<'a, T> { marker: marker::PhantomData, }) } + + /// Returns a reference to the original `Mutex`. + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// async fn unlock_and_relock<'l>(guard: MutexGuard<'l, u32>) -> MutexGuard<'l, u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex = MutexGuard::mutex(&guard); + /// drop(guard); + /// let guard = mutex.lock().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Mutex::new(0u32); + /// # let guard = mutex.lock().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &'a Mutex<T> { + this.lock + } } impl<T: ?Sized> Drop for MutexGuard<'_, T> { @@ -596,6 +634,35 @@ impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> { // === impl OwnedMutexGuard === +impl<T: ?Sized> OwnedMutexGuard<T> { + /// Returns a reference to the original `Arc<Mutex>`. + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// + /// async fn unlock_and_relock(guard: OwnedMutexGuard<u32>) -> OwnedMutexGuard<u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex: Arc<Mutex<u32>> = OwnedMutexGuard::mutex(&guard).clone(); + /// drop(guard); + /// let guard = mutex.lock_owned().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Arc::new(Mutex::new(0u32)); + /// # let guard = mutex.lock_owned().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &Arc<Mutex<T>> { + &this.lock + } +} + impl<T: ?Sized> Drop for OwnedMutexGuard<T> { fn drop(&mut self) { self.lock.s.release(1) diff --git a/src/sync/notify.rs b/src/sync/notify.rs index 07be759..2ea6359 100644 --- a/src/sync/notify.rs +++ b/src/sync/notify.rs @@ -246,6 +246,12 @@ impl Notify { /// /// [`notify_one()`]: Notify::notify_one /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute notifications in the order + /// they were requested. Cancelling a call to `notified` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -522,7 +528,7 @@ impl Notified<'_> { is_unpin::<AtomicUsize>(); let me = self.get_unchecked_mut(); - (&me.notify, &mut me.state, &me.waiter) + (me.notify, &mut me.state, &me.waiter) } } } diff --git a/src/sync/once_cell.rs b/src/sync/once_cell.rs index fa9b1f1..91705a5 100644 --- a/src/sync/once_cell.rs +++ b/src/sync/once_cell.rs @@ -1,4 +1,4 @@ -use super::Semaphore; +use super::{Semaphore, SemaphorePermit, TryAcquireError}; use crate::loom::cell::UnsafeCell; use std::error::Error; use std::fmt; @@ -8,15 +8,30 @@ use std::ops::Drop; use std::ptr; use std::sync::atomic::{AtomicBool, Ordering}; -/// A thread-safe cell which can be written to only once. +// This file contains an implementation of an OnceCell. The principle +// behind the safety the of the cell is that any thread with an `&OnceCell` may +// access the `value` field according the following rules: +// +// 1. When `value_set` is false, the `value` field may be modified by the +// thread holding the permit on the semaphore. +// 2. When `value_set` is true, the `value` field may be accessed immutably by +// any thread. +// +// It is an invariant that if the semaphore is closed, then `value_set` is true. +// The reverse does not necessarily hold — but if not, the semaphore may not +// have any available permits. +// +// A thread with a `&mut OnceCell` may modify the value in any way it wants as +// long as the invariants are upheld. + +/// A thread-safe cell that can be written to only once. /// -/// Provides the functionality to either set the value, in case `OnceCell` -/// is uninitialized, or get the already initialized value by using an async -/// function via [`OnceCell::get_or_init`]. -/// -/// [`OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init +/// A `OnceCell` is typically used for global variables that need to be +/// initialized once on first use, but need no further changes. The `OnceCell` +/// in Tokio allows the initialization procedure to be asynchronous. /// /// # Examples +/// /// ``` /// use tokio::sync::OnceCell; /// @@ -28,8 +43,28 @@ use std::sync::atomic::{AtomicBool, Ordering}; /// /// #[tokio::main] /// async fn main() { -/// let result1 = ONCE.get_or_init(some_computation).await; -/// assert_eq!(*result1, 2); +/// let result = ONCE.get_or_init(some_computation).await; +/// assert_eq!(*result, 2); +/// } +/// ``` +/// +/// It is often useful to write a wrapper method for accessing the value. +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// async fn get_global_integer() -> &'static u32 { +/// ONCE.get_or_init(|| async { +/// 1 + 1 +/// }).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = get_global_integer().await; +/// assert_eq!(*result, 2); /// } /// ``` pub struct OnceCell<T> { @@ -68,7 +103,7 @@ impl<T: Eq> Eq for OnceCell<T> {} impl<T> Drop for OnceCell<T> { fn drop(&mut self) { - if self.initialized() { + if self.initialized_mut() { unsafe { self.value .with_mut(|ptr| ptr::drop_in_place((&mut *ptr).as_mut_ptr())); @@ -77,8 +112,20 @@ impl<T> Drop for OnceCell<T> { } } +impl<T> From<T> for OnceCell<T> { + fn from(value: T) -> Self { + let semaphore = Semaphore::new(0); + semaphore.close(); + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(value)), + semaphore, + } + } +} + impl<T> OnceCell<T> { - /// Creates a new uninitialized OnceCell instance. + /// Creates a new empty `OnceCell` instance. pub fn new() -> Self { OnceCell { value_set: AtomicBool::new(false), @@ -87,26 +134,44 @@ impl<T> OnceCell<T> { } } - /// Creates a new initialized OnceCell instance if `value` is `Some`, otherwise - /// has the same functionality as [`OnceCell::new`]. + /// Creates a new `OnceCell` that contains the provided value, if any. + /// + /// If the `Option` is `None`, this is equivalent to `OnceCell::new`. /// /// [`OnceCell::new`]: crate::sync::OnceCell::new pub fn new_with(value: Option<T>) -> Self { if let Some(v) = value { - let semaphore = Semaphore::new(0); - semaphore.close(); - OnceCell { - value_set: AtomicBool::new(true), - value: UnsafeCell::new(MaybeUninit::new(v)), - semaphore, - } + OnceCell::from(v) } else { OnceCell::new() } } - /// Creates a new uninitialized OnceCell instance. - #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + /// Creates a new empty `OnceCell` instance. + /// + /// Equivalent to `OnceCell::new`, except that it can be used in static + /// variables. + /// + /// # Example + /// + /// ``` + /// use tokio::sync::OnceCell; + /// + /// static ONCE: OnceCell<u32> = OnceCell::const_new(); + /// + /// async fn get_global_integer() -> &'static u32 { + /// ONCE.get_or_init(|| async { + /// 1 + 1 + /// }).await + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let result = get_global_integer().await; + /// assert_eq!(*result, 2); + /// } + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new() -> Self { OnceCell { @@ -116,33 +181,48 @@ impl<T> OnceCell<T> { } } - /// Whether the value of the OnceCell is set or not. + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. pub fn initialized(&self) -> bool { + // Using acquire ordering so any threads that read a true from this + // atomic is able to read the value. self.value_set.load(Ordering::Acquire) } - // SAFETY: safe to call only once self.initialized() is true + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + fn initialized_mut(&mut self) -> bool { + *self.value_set.get_mut() + } + + // SAFETY: The OnceCell must not be empty. unsafe fn get_unchecked(&self) -> &T { &*self.value.with(|ptr| (*ptr).as_ptr()) } - // SAFETY: safe to call only once self.initialized() is true. Safe because - // because of the mutable reference. + // SAFETY: The OnceCell must not be empty. unsafe fn get_unchecked_mut(&mut self) -> &mut T { &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) } - // SAFETY: safe to call only once a permit on the semaphore has been - // acquired - unsafe fn set_value(&self, value: T) { - self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { + // SAFETY: We are holding the only permit on the semaphore. + unsafe { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + } + + // Using release ordering so any threads that read a true from this + // atomic is able to read the value we just stored. self.value_set.store(true, Ordering::Release); self.semaphore.close(); + permit.forget(); + + // SAFETY: We just initialized the cell. + unsafe { self.get_unchecked() } } - /// Tries to get a reference to the value of the OnceCell. - /// - /// Returns None if the value of the OnceCell hasn't previously been initialized. + /// Returns a reference to the value currently stored in the `OnceCell`, or + /// `None` if the `OnceCell` is empty. pub fn get(&self) -> Option<&T> { if self.initialized() { Some(unsafe { self.get_unchecked() }) @@ -151,179 +231,161 @@ impl<T> OnceCell<T> { } } - /// Tries to return a mutable reference to the value of the cell. + /// Returns a mutable reference to the value currently stored in the + /// `OnceCell`, or `None` if the `OnceCell` is empty. /// - /// Returns None if the cell hasn't previously been initialized. + /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the + /// value inside the `OnceCell` — the mutable borrow statically guarantees + /// no other references exist. pub fn get_mut(&mut self) -> Option<&mut T> { - if self.initialized() { + if self.initialized_mut() { Some(unsafe { self.get_unchecked_mut() }) } else { None } } - /// Sets the value of the OnceCell to the argument value. + /// Set the value of the `OnceCell` to the given value if the `OnceCell` is + /// empty. + /// + /// If the `OnceCell` already has a value, this call will fail with an + /// [`SetError::AlreadyInitializedError`]. /// - /// If the value of the OnceCell was already set prior to this call - /// then [`SetError::AlreadyInitializedError`] is returned. If another thread - /// is initializing the cell while this method is called, - /// [`SetError::InitializingError`] is returned. In order to wait - /// for an ongoing initialization to finish, call - /// [`OnceCell::get_or_init`] instead. + /// If the `OnceCell` is empty, but some other task is currently trying to + /// set the value, this call will fail with [`SetError::InitializingError`]. /// /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError - /// ['OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init pub fn set(&self, value: T) -> Result<(), SetError<T>> { - if !self.initialized() { - // Another thread might be initializing the cell, in which case `try_acquire` will - // return an error - match self.semaphore.try_acquire() { - Ok(_permit) => { - if !self.initialized() { - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - return Ok(()); - } else { - unreachable!( - "acquired the permit after OnceCell value was already initialized." - ); - } - } - _ => { - // Couldn't acquire the permit, look if initializing process is already completed - if !self.initialized() { - return Err(SetError::InitializingError(value)); - } - } - } + if self.initialized() { + return Err(SetError::AlreadyInitializedError(value)); } - Err(SetError::AlreadyInitializedError(value)) + // Another task might be initializing the cell, in which case + // `try_acquire` will return an error. If we succeed to acquire the + // permit, then we can set the value. + match self.semaphore.try_acquire() { + Ok(permit) => { + debug_assert!(!self.initialized()); + self.set_value(value, permit); + Ok(()) + } + Err(TryAcquireError::NoPermits) => { + // Some other task is holding the permit. That task is + // currently trying to initialize the value. + Err(SetError::InitializingError(value)) + } + Err(TryAcquireError::Closed) => { + // The semaphore was closed. Some other task has initialized + // the value. + Err(SetError::AlreadyInitializedError(value)) + } + } } - /// Tries to initialize the value of the OnceCell using the async function `f`. - /// If the value of the OnceCell was already initialized prior to this call, - /// a reference to that initialized value is returned. If some other thread - /// initiated the initialization prior to this call and the initialization - /// hasn't completed, this call waits until the initialization is finished. + /// Get the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation is cancelled or panics, the initialization + /// attempt is cancelled. If there are other tasks waiting for the value to + /// be initialized, one of them will start another attempt at initializing + /// the value. /// - /// This will deadlock if `f` tries to initialize the cell itself. + /// This will deadlock if `f` tries to initialize the cell recursively. pub async fn get_or_init<F, Fut>(&self, f: F) -> &T where F: FnOnce() -> Fut, Fut: Future<Output = T>, { if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references + // SAFETY: The OnceCell has been fully initialized. unsafe { self.get_unchecked() } } else { - // After acquire().await we have either acquired a permit while self.value - // is still uninitialized, or the current thread is awoken after another thread - // has intialized the value and closed the semaphore, in which case self.initialized - // is true and we don't set the value here + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. match self.semaphore.acquire().await { - Ok(_permit) => { - if !self.initialized() { - // If `f()` panics or `select!` is called, this `get_or_init` call - // is aborted and the semaphore permit is dropped. - let value = f().await; - - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { self.get_unchecked() } - } else { - unreachable!("acquired semaphore after value was already initialized."); - } + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_init` call is aborted and the semaphore permit is + // dropped. + let value = f().await; + + self.set_value(value, permit) } Err(_) => { - if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { self.get_unchecked() } - } else { - unreachable!( - "Semaphore closed, but the OnceCell has not been initialized." - ); - } + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { self.get_unchecked() } } } } } - /// Tries to initialize the value of the OnceCell using the async function `f`. - /// If the value of the OnceCell was already initialized prior to this call, - /// a reference to that initialized value is returned. If some other thread - /// initiated the initialization prior to this call and the initialization - /// hasn't completed, this call waits until the initialization is finished. - /// If the function argument `f` returns an error, `get_or_try_init` - /// returns that error, otherwise the result of `f` will be stored in the cell. + /// Get the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. /// - /// This will deadlock if `f` tries to initialize the cell itself. + /// If the provided operation returns an error, is cancelled or panics, the + /// initialization attempt is cancelled. If there are other tasks waiting + /// for the value to be initialized, one of them will start another attempt + /// at initializing the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. pub async fn get_or_try_init<E, F, Fut>(&self, f: F) -> Result<&T, E> where F: FnOnce() -> Fut, Fut: Future<Output = Result<T, E>>, { if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references + // SAFETY: The OnceCell has been fully initialized. unsafe { Ok(self.get_unchecked()) } } else { - // After acquire().await we have either acquired a permit while self.value - // is still uninitialized, or the current thread is awoken after another thread - // has intialized the value and closed the semaphore, in which case self.initialized - // is true and we don't set the value here + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. match self.semaphore.acquire().await { - Ok(_permit) => { - if !self.initialized() { - // If `f()` panics or `select!` is called, this `get_or_try_init` call - // is aborted and the semaphore permit is dropped. - let value = f().await; - - match value { - Ok(value) => { - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { Ok(self.get_unchecked()) } - } - Err(e) => Err(e), - } - } else { - unreachable!("acquired semaphore after value was already initialized."); + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_try_init` call is aborted and the semaphore + // permit is dropped. + let value = f().await; + + match value { + Ok(value) => Ok(self.set_value(value, permit)), + Err(e) => Err(e), } } Err(_) => { - if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { Ok(self.get_unchecked()) } - } else { - unreachable!( - "Semaphore closed, but the OnceCell has not been initialized." - ); - } + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { Ok(self.get_unchecked()) } } } } } - /// Moves the value out of the cell, destroying the cell in the process. - /// - /// Returns `None` if the cell is uninitialized. + /// Take the value from the cell, destroying the cell in the process. + /// Returns `None` if the cell is empty. pub fn into_inner(mut self) -> Option<T> { - if self.initialized() { + if self.initialized_mut() { // Set to uninitialized for the destructor of `OnceCell` to work properly *self.value_set.get_mut() = false; Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) @@ -332,20 +394,18 @@ impl<T> OnceCell<T> { } } - /// Takes ownership of the current value, leaving the cell uninitialized. - /// - /// Returns `None` if the cell is uninitialized. + /// Takes ownership of the current value, leaving the cell empty. Returns + /// `None` if the cell is empty. pub fn take(&mut self) -> Option<T> { std::mem::take(self).into_inner() } } -// Since `get` gives us access to immutable references of the -// OnceCell, OnceCell can only be Sync if T is Sync, otherwise -// OnceCell would allow sharing references of !Sync values across -// threads. We need T to be Send in order for OnceCell to by Sync -// because we can use `set` on `&OnceCell<T>` to send -// values (of type T) across threads. +// Since `get` gives us access to immutable references of the OnceCell, OnceCell +// can only be Sync if T is Sync, otherwise OnceCell would allow sharing +// references of !Sync values across threads. We need T to be Send in order for +// OnceCell to by Sync because we can use `set` on `&OnceCell<T>` to send values +// (of type T) across threads. unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} // Access to OnceCell's value is guarded by the semaphore permit @@ -353,20 +413,17 @@ unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} // it's safe to send it to another thread unsafe impl<T: Send> Send for OnceCell<T> {} -/// Errors that can be returned from [`OnceCell::set`] +/// Errors that can be returned from [`OnceCell::set`]. /// /// [`OnceCell::set`]: crate::sync::OnceCell::set #[derive(Debug, PartialEq)] pub enum SetError<T> { - /// Error resulting from [`OnceCell::set`] calls if the cell was previously initialized. + /// The cell was already initialized when [`OnceCell::set`] was called. /// /// [`OnceCell::set`]: crate::sync::OnceCell::set AlreadyInitializedError(T), - /// Error resulting from [`OnceCell::set`] calls when the cell is currently being - /// inintialized during the calls to that method. - /// - /// [`OnceCell::set`]: crate::sync::OnceCell::set + /// The cell is currently being initialized. InitializingError(T), } diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index 6f0c011..120bc72 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -299,6 +299,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop this read access of the `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read` makes you lose your place in + /// the queue. + /// /// # Examples /// /// ``` @@ -357,6 +363,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop this read access of the `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read_owned` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -501,6 +513,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop the write access of this `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write` makes you lose your place + /// in the queue. + /// /// # Examples /// /// ``` @@ -543,6 +561,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop the write access of this `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write_owned` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs index 5d42d1c..839b523 100644 --- a/src/sync/semaphore.rs +++ b/src/sync/semaphore.rs @@ -162,6 +162,12 @@ impl Semaphore { /// Otherwise, this returns a [`SemaphorePermit`] representing the /// acquired permit. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire` makes you lose your place + /// in the queue. + /// /// # Examples /// /// ``` @@ -187,7 +193,7 @@ impl Semaphore { pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { self.ll_sem.acquire(1).await?; Ok(SemaphorePermit { - sem: &self, + sem: self, permits: 1, }) } @@ -198,6 +204,12 @@ impl Semaphore { /// Otherwise, this returns a [`SemaphorePermit`] representing the /// acquired permits. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -217,7 +229,7 @@ impl Semaphore { pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { self.ll_sem.acquire(n).await?; Ok(SemaphorePermit { - sem: &self, + sem: self, permits: n, }) } @@ -302,6 +314,12 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_owned` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -346,6 +364,12 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many_owned` makes you lose + /// your place in the queue. + /// /// # Examples /// /// ``` diff --git a/src/sync/watch.rs b/src/sync/watch.rs index 42d417a..96d1d16 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -56,7 +56,7 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::atomic::Ordering::{Relaxed, SeqCst}; +use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::ops; @@ -74,7 +74,7 @@ pub struct Receiver<T> { shared: Arc<Shared<T>>, /// Last observed version - version: usize, + version: Version, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -104,7 +104,7 @@ struct Shared<T> { /// /// The lowest bit represents a "closed" state. The rest of the bits /// represent the current version. - version: AtomicUsize, + state: AtomicState, /// Tracks the number of `Receiver` instances ref_count_rx: AtomicUsize, @@ -152,7 +152,69 @@ pub mod error { impl std::error::Error for RecvError {} } -const CLOSED: usize = 1; +use self::state::{AtomicState, Version}; +mod state { + use crate::loom::sync::atomic::AtomicUsize; + use crate::loom::sync::atomic::Ordering::SeqCst; + + const CLOSED: usize = 1; + + /// The version part of the state. The lowest bit is always zero. + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct Version(usize); + + /// Snapshot of the state. The first bit is used as the CLOSED bit. + /// The remaining bits are used as the version. + #[derive(Copy, Clone, Debug)] + pub(super) struct StateSnapshot(usize); + + /// The state stored in an atomic integer. + #[derive(Debug)] + pub(super) struct AtomicState(AtomicUsize); + + impl Version { + /// Get the initial version when creating the channel. + pub(super) fn initial() -> Self { + Version(0) + } + } + + impl StateSnapshot { + /// Extract the version from the state. + pub(super) fn version(self) -> Version { + Version(self.0 & !CLOSED) + } + + /// Is the closed bit set? + pub(super) fn is_closed(self) -> bool { + (self.0 & CLOSED) == CLOSED + } + } + + impl AtomicState { + /// Create a new `AtomicState` that is not closed and which has the + /// version set to `Version::initial()`. + pub(super) fn new() -> Self { + AtomicState(AtomicUsize::new(0)) + } + + /// Load the current value of the state. + pub(super) fn load(&self) -> StateSnapshot { + StateSnapshot(self.0.load(SeqCst)) + } + + /// Increment the version counter. + pub(super) fn increment_version(&self) { + // Increment by two to avoid touching the CLOSED bit. + self.0.fetch_add(2, SeqCst); + } + + /// Set the closed bit in the state. + pub(super) fn set_closed(&self) { + self.0.fetch_or(CLOSED, SeqCst); + } + } +} /// Creates a new watch channel, returning the "send" and "receive" handles. /// @@ -184,7 +246,7 @@ const CLOSED: usize = 1; pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { let shared = Arc::new(Shared { value: RwLock::new(init), - version: AtomicUsize::new(0), + state: AtomicState::new(), ref_count_rx: AtomicUsize::new(1), notify_rx: Notify::new(), notify_tx: Notify::new(), @@ -194,13 +256,16 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { shared: shared.clone(), }; - let rx = Receiver { shared, version: 0 }; + let rx = Receiver { + shared, + version: Version::initial(), + }; (tx, rx) } impl<T> Receiver<T> { - fn from_shared(version: usize, shared: Arc<Shared<T>>) -> Self { + fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self { // No synchronization necessary as this is only used as a counter and // not memory access. shared.ref_count_rx.fetch_add(1, Relaxed); @@ -208,12 +273,18 @@ impl<T> Receiver<T> { Self { shared, version } } - /// Returns a reference to the most recently sent value + /// Returns a reference to the most recently sent value. + /// + /// This method does not mark the returned value as seen, so future calls to + /// [`changed`] may return immediately even if you have already seen the + /// value with a call to `borrow`. /// /// Outstanding borrows hold a read lock. This means that long lived borrows /// could cause the send half to block. It is recommended to keep the borrow /// as short lived as possible. /// + /// [`changed`]: Receiver::changed + /// /// # Examples /// /// ``` @@ -227,11 +298,40 @@ impl<T> Receiver<T> { Ref { inner } } - /// Wait for a change notification + /// Returns a reference to the most recently sent value and mark that value + /// as seen. + /// + /// This method marks the value as seen, so [`changed`] will not return + /// immediately if the newest value is one previously returned by + /// `borrow_and_update`. + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// [`changed`]: Receiver::changed + pub fn borrow_and_update(&mut self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + self.version = self.shared.state.load().version(); + Ref { inner } + } + + /// Wait for a change notification, then mark the newest value as seen. + /// + /// If the newest value in the channel has not yet been marked seen when + /// this method is called, the method marks that value seen and returns + /// immediately. If the newest value has already been marked seen, then the + /// method sleeps until a new message is sent by the [`Sender`] connected to + /// this `Receiver`, or until the [`Sender`] is dropped. + /// + /// This method returns an error if and only if the [`Sender`] is dropped. + /// + /// # Cancel safety /// - /// Returns when a new value has been sent by the [`Sender`] since the last - /// time `changed()` was called. When the `Sender` half is dropped, `Err` is - /// returned. + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no values have been marked + /// seen by this call to `changed`. /// /// [`Sender`]: struct@Sender /// @@ -280,11 +380,11 @@ impl<T> Receiver<T> { fn maybe_changed<T>( shared: &Shared<T>, - version: &mut usize, + version: &mut Version, ) -> Option<Result<(), error::RecvError>> { // Load the version from the state - let state = shared.version.load(SeqCst); - let new_version = state & !CLOSED; + let state = shared.state.load(); + let new_version = state.version(); if *version != new_version { // Observe the new version and return @@ -292,7 +392,7 @@ fn maybe_changed<T>( return Some(Ok(())); } - if CLOSED == state & CLOSED { + if state.is_closed() { // All receivers have dropped. return Some(Err(error::RecvError(()))); } @@ -322,16 +422,29 @@ impl<T> Drop for Receiver<T> { impl<T> Sender<T> { /// Sends a new value via the channel, notifying all receivers. + /// + /// This method fails if the channel has been closed, which happens when + /// every receiver has been dropped. pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { // This is pretty much only useful as a hint anyway, so synchronization isn't critical. if 0 == self.shared.ref_count_rx.load(Relaxed) { return Err(error::SendError { inner: value }); } - *self.shared.value.write().unwrap() = value; + { + // Acquire the write lock and update the value. + let mut lock = self.shared.value.write().unwrap(); + *lock = value; + + self.shared.state.increment_version(); - // Update the version. 2 is used so that the CLOSED bit is not set. - self.shared.version.fetch_add(2, SeqCst); + // Release the write lock. + // + // Incrementing the version counter while holding the lock ensures + // that receivers are able to figure out the version number of the + // value they are currently looking at. + drop(lock); + } // Notify all watchers self.shared.notify_rx.notify_waiters(); @@ -379,6 +492,11 @@ impl<T> Sender<T> { /// This allows the producer to get notified when interest in the produced /// values is canceled and immediately stop doing work. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// /// # Examples /// /// ``` @@ -412,7 +530,7 @@ impl<T> Sender<T> { cfg_signal_internal! { pub(crate) fn subscribe(&self) -> Receiver<T> { let shared = self.shared.clone(); - let version = shared.version.load(SeqCst); + let version = shared.state.load().version(); Receiver::from_shared(version, shared) } @@ -443,7 +561,7 @@ impl<T> Sender<T> { impl<T> Drop for Sender<T> { fn drop(&mut self) { - self.shared.version.fetch_or(CLOSED, SeqCst); + self.shared.state.set_closed(); self.shared.notify_rx.notify_waiters(); } } diff --git a/src/task/blocking.rs b/src/task/blocking.rs index e4fe254..806dbbd 100644 --- a/src/task/blocking.rs +++ b/src/task/blocking.rs @@ -89,13 +89,14 @@ cfg_rt! { /// /// Tokio will spawn more blocking threads when they are requested through this /// function until the upper limit configured on the [`Builder`] is reached. - /// This limit is very large by default, because `spawn_blocking` is often used - /// for various kinds of IO operations that cannot be performed asynchronously. - /// When you run CPU-bound code using `spawn_blocking`, you should keep this - /// large upper limit in mind. When running many CPU-bound computations, a - /// semaphore or some other synchronization primitive should be used to limit - /// the number of computation executed in parallel. Specialized CPU-bound - /// executors, such as [rayon], may also be a good fit. + /// After reaching the upper limit, the tasks are put in a queue. + /// The thread limit is very large by default, because `spawn_blocking` is often + /// used for various kinds of IO operations that cannot be performed + /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you + /// should keep this large upper limit in mind. When running many CPU-bound + /// computations, a semaphore or some other synchronization primitive should be + /// used to limit the number of computation executed in parallel. Specialized + /// CPU-bound executors, such as [rayon], may also be a good fit. /// /// This function is intended for non-async operations that eventually finish on /// their own. If you want to spawn an ordinary thread, you should use diff --git a/src/task/builder.rs b/src/task/builder.rs new file mode 100644 index 0000000..e46bdef --- /dev/null +++ b/src/task/builder.rs @@ -0,0 +1,105 @@ +#![allow(unreachable_pub)] +use crate::util::error::CONTEXT_MISSING_ERROR; +use crate::{runtime::context, task::JoinHandle}; +use std::future::Future; + +/// Factory which is used to configure the properties of a new task. +/// +/// Methods can be chained in order to configure it. +/// +/// Currently, there is only one configuration option: +/// +/// - [`name`], which specifies an associated name for +/// the task +/// +/// There are three types of task that can be spawned from a Builder: +/// - [`spawn_local`] for executing futures on the current thread +/// - [`spawn`] for executing [`Send`] futures on the runtime +/// - [`spawn_blocking`] for executing blocking code in the +/// blocking thread pool. +/// +/// ## Example +/// +/// ```no_run +/// use tokio::net::{TcpListener, TcpStream}; +/// +/// use std::io; +/// +/// async fn process(socket: TcpStream) { +/// // ... +/// # drop(socket); +/// } +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let listener = TcpListener::bind("127.0.0.1:8080").await?; +/// +/// loop { +/// let (socket, _) = listener.accept().await?; +/// +/// tokio::task::Builder::new() +/// .name("tcp connection handler") +/// .spawn(async move { +/// // Process each socket concurrently. +/// process(socket).await +/// }); +/// } +/// } +/// ``` +#[derive(Default, Debug)] +pub struct Builder<'a> { + name: Option<&'a str>, +} + +impl<'a> Builder<'a> { + /// Creates a new task builder. + pub fn new() -> Self { + Self::default() + } + + /// Assigns a name to the task which will be spawned. + pub fn name(&self, name: &'a str) -> Self { + Self { name: Some(name) } + } + + /// Spawns a task on the executor. + /// + /// See [`task::spawn`](crate::task::spawn) for + /// more details. + #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + super::spawn::spawn_inner(future, self.name) + } + + /// Spawns a task on the current thread. + /// + /// See [`task::spawn_local`](crate::task::spawn_local) + /// for more details. + #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> + where + Fut: Future + 'static, + Fut::Output: 'static, + { + super::local::spawn_local_inner(future, self.name) + } + + /// Spawns blocking code on the blocking threadpool. + /// + /// See [`task::spawn_blocking`](crate::task::spawn_blocking) + /// for more details. + #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output> + where + Function: FnOnce() -> Output + Send + 'static, + Output: Send + 'static, + { + context::current() + .expect(CONTEXT_MISSING_ERROR) + .spawn_blocking_inner(function, self.name) + } +} diff --git a/src/task/local.rs b/src/task/local.rs index 64f1ac5..a28d793 100644 --- a/src/task/local.rs +++ b/src/task/local.rs @@ -1,15 +1,15 @@ //! Runs `!Send` futures on the current thread. -use crate::runtime::task::{self, JoinHandle, Task}; +use crate::loom::sync::{Arc, Mutex}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; -use crate::util::linked_list::{Link, LinkedList}; +use crate::util::VecDequeCell; -use std::cell::{Cell, RefCell}; +use std::cell::Cell; use std::collections::VecDeque; use std::fmt; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::sync::{Arc, Mutex}; use std::task::Poll; use pin_project_lite::pin_project; @@ -224,25 +224,20 @@ cfg_rt! { /// State available from the thread-local struct Context { - /// Owned task set and local run queue - tasks: RefCell<Tasks>, - - /// State shared between threads. - shared: Arc<Shared>, -} - -struct Tasks { /// Collection of all active tasks spawned onto this executor. - owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>, + owned: LocalOwnedTasks<Arc<Shared>>, /// Local run queue sender and receiver. - queue: VecDeque<task::Notified<Arc<Shared>>>, + queue: VecDequeCell<task::Notified<Arc<Shared>>>, + + /// State shared between threads. + shared: Arc<Shared>, } /// LocalSet state shared between threads. struct Shared { /// Remote run queue sender - queue: Mutex<VecDeque<task::Notified<Arc<Shared>>>>, + queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>, /// Wake the `LocalSet` task waker: AtomicWaker, @@ -297,15 +292,24 @@ cfg_rt! { F: Future + 'static, F::Output: 'static, { - let future = crate::util::trace::task(future, "local"); + spawn_local_inner(future, None) + } + + pub(super) fn spawn_local_inner<F>(future: F, name: Option<&str>) -> JoinHandle<F::Output> + where F: Future + 'static, + F::Output: 'static + { + let future = crate::util::trace::task(future, "local", name); CURRENT.with(|maybe_cx| { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); - // Safety: Tasks are only polled and dropped from the thread that - // spawns them. - let (task, handle) = unsafe { task::joinable_local(future) }; - cx.tasks.borrow_mut().queue.push_back(task); + let (handle, notified) = cx.owned.bind(future, cx.shared.clone()); + + if let Some(notified) = notified { + cx.shared.schedule(notified); + } + handle }) } @@ -326,12 +330,10 @@ impl LocalSet { LocalSet { tick: Cell::new(0), context: Context { - tasks: RefCell::new(Tasks { - owned: LinkedList::new(), - queue: VecDeque::with_capacity(INITIAL_CAPACITY), - }), + owned: LocalOwnedTasks::new(), + queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), shared: Arc::new(Shared { - queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), waker: AtomicWaker::new(), }), }, @@ -381,9 +383,14 @@ impl LocalSet { F: Future + 'static, F::Output: 'static, { - let future = crate::util::trace::task(future, "local"); - let (task, handle) = unsafe { task::joinable_local(future) }; - self.context.tasks.borrow_mut().queue.push_back(task); + let future = crate::util::trace::task(future, "local", None); + + let (handle, notified) = self.context.owned.bind(future, self.context.shared.clone()); + + if let Some(notified) = notified { + self.context.shared.schedule(notified); + } + self.context.shared.waker.wake(); handle } @@ -522,26 +529,30 @@ impl LocalSet { true } - fn next_task(&self) -> Option<task::Notified<Arc<Shared>>> { + fn next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>> { let tick = self.tick.get(); self.tick.set(tick.wrapping_add(1)); - if tick % REMOTE_FIRST_INTERVAL == 0 { + let task = if tick % REMOTE_FIRST_INTERVAL == 0 { self.context .shared .queue .lock() - .unwrap() - .pop_front() - .or_else(|| self.context.tasks.borrow_mut().queue.pop_front()) + .as_mut() + .and_then(|queue| queue.pop_front()) + .or_else(|| self.context.queue.pop_front()) } else { - self.context - .tasks - .borrow_mut() - .queue - .pop_front() - .or_else(|| self.context.shared.queue.lock().unwrap().pop_front()) - } + self.context.queue.pop_front().or_else(|| { + self.context + .shared + .queue + .lock() + .as_mut() + .and_then(|queue| queue.pop_front()) + }) + }; + + task.map(|task| self.context.owned.assert_owner(task)) } fn with<T>(&self, f: impl FnOnce() -> T) -> T { @@ -567,7 +578,7 @@ impl Future for LocalSet { // there are still tasks remaining in the run queue. cx.waker().wake_by_ref(); Poll::Pending - } else if self.context.tasks.borrow().owned.is_empty() { + } else if self.context.owned.is_empty() { // If the scheduler has no remaining futures, we're done! Poll::Ready(()) } else { @@ -588,27 +599,24 @@ impl Default for LocalSet { impl Drop for LocalSet { fn drop(&mut self) { self.with(|| { - // Loop required here to ensure borrow is dropped between iterations - #[allow(clippy::while_let_loop)] - loop { - let task = match self.context.tasks.borrow_mut().owned.pop_back() { - Some(task) => task, - None => break, - }; - - // Safety: same as `run_unchecked`. - task.shutdown(); - } - - for task in self.context.tasks.borrow_mut().queue.drain(..) { - task.shutdown(); + // Shut down all tasks in the LocalOwnedTasks and close it to + // prevent new tasks from ever being added. + self.context.owned.close_and_shutdown_all(); + + // We already called shutdown on all tasks above, so there is no + // need to call shutdown. + for task in self.context.queue.take() { + drop(task); } - for task in self.context.shared.queue.lock().unwrap().drain(..) { - task.shutdown(); + // Take the queue from the Shared object to prevent pushing + // notifications to it in the future. + let queue = self.context.shared.queue.lock().take().unwrap(); + for task in queue { + drop(task); } - assert!(self.context.tasks.borrow().owned.is_empty()); + assert!(self.context.owned.is_empty()); }); } } @@ -651,11 +659,19 @@ impl Shared { fn schedule(&self, task: task::Notified<Arc<Self>>) { CURRENT.with(|maybe_cx| match maybe_cx { Some(cx) if cx.shared.ptr_eq(self) => { - cx.tasks.borrow_mut().queue.push_back(task); + cx.queue.push_back(task); } _ => { - self.queue.lock().unwrap().push_back(task); - self.waker.wake(); + // First check whether the queue is still there (if not, the + // LocalSet is dropped). Then push to it if so, and if not, + // do nothing. + let mut lock = self.queue.lock(); + + if let Some(queue) = lock.as_mut() { + queue.push_back(task); + drop(lock); + self.waker.wake(); + } } }); } @@ -666,26 +682,11 @@ impl Shared { } impl task::Schedule for Arc<Shared> { - fn bind(task: Task<Self>) -> Arc<Shared> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - cx.tasks.borrow_mut().owned.push_front(task); - cx.shared.clone() - }) - } - fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { - use std::ptr::NonNull; - CURRENT.with(|maybe_cx| { let cx = maybe_cx.expect("scheduler context missing"); - assert!(cx.shared.ptr_eq(self)); - - let ptr = NonNull::from(task.header()); - // safety: task must be contained by list. It is inserted into the - // list in `bind`. - unsafe { cx.tasks.borrow_mut().owned.remove(ptr) } + cx.owned.remove(task) }) } diff --git a/src/task/mod.rs b/src/task/mod.rs index 25dab0c..ea98787 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -299,4 +299,14 @@ cfg_rt! { mod unconstrained; pub use unconstrained::{unconstrained, Unconstrained}; + + cfg_trace! { + mod builder; + pub use builder::Builder; + } + + /// Task-related futures. + pub mod futures { + pub use super::task_local::TaskLocalFuture; + } } diff --git a/src/task/spawn.rs b/src/task/spawn.rs index d846fb4..065d38f 100644 --- a/src/task/spawn.rs +++ b/src/task/spawn.rs @@ -1,6 +1,4 @@ -use crate::runtime; -use crate::task::JoinHandle; -use crate::util::error::CONTEXT_MISSING_ERROR; +use crate::{task::JoinHandle, util::error::CONTEXT_MISSING_ERROR}; use std::future::Future; @@ -124,14 +122,28 @@ cfg_rt! { /// error[E0391]: cycle detected when processing `main` /// ``` #[cfg_attr(tokio_track_caller, track_caller)] - pub fn spawn<T>(task: T) -> JoinHandle<T::Output> + pub fn spawn<T>(future: T) -> JoinHandle<T::Output> where T: Future + Send + 'static, T::Output: Send + 'static, { - let spawn_handle = runtime::context::spawn_handle() - .expect(CONTEXT_MISSING_ERROR); - let task = crate::util::trace::task(task, "task"); + // preventing stack overflows on debug mode, by quickly sending the + // task to the heap. + if cfg!(debug_assertions) && std::mem::size_of::<T>() > 2048 { + spawn_inner(Box::pin(future), None) + } else { + spawn_inner(future, None) + } + } + + #[cfg_attr(tokio_track_caller, track_caller)] + pub(super) fn spawn_inner<T>(future: T, name: Option<&str>) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let spawn_handle = crate::runtime::context::spawn_handle().expect(CONTEXT_MISSING_ERROR); + let task = crate::util::trace::task(future, "task", name); spawn_handle.spawn(task) } } diff --git a/src/task/task_local.rs b/src/task/task_local.rs index 6571ffd..b6e7df4 100644 --- a/src/task/task_local.rs +++ b/src/task/task_local.rs @@ -2,6 +2,7 @@ use pin_project_lite::pin_project; use std::cell::RefCell; use std::error::Error; use std::future::Future; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, thread}; @@ -115,16 +116,16 @@ impl<T: 'static> LocalKey<T> { /// }).await; /// # } /// ``` - pub async fn scope<F>(&'static self, value: T, f: F) -> F::Output + pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> where F: Future, { TaskLocalFuture { - local: &self, + local: self, slot: Some(value), future: f, + _pinned: PhantomPinned, } - .await } /// Sets a value `T` as the task-local value for the closure `F`. @@ -148,12 +149,14 @@ impl<T: 'static> LocalKey<T> { where F: FnOnce() -> R, { - let mut scope = TaskLocalFuture { - local: &self, + let scope = TaskLocalFuture { + local: self, slot: Some(value), future: (), + _pinned: PhantomPinned, }; - Pin::new(&mut scope).with_task(|_| f()) + crate::pin!(scope); + scope.with_task(|_| f()) } /// Accesses the current task-local and runs the provided closure. @@ -206,11 +209,37 @@ impl<T: 'static> fmt::Debug for LocalKey<T> { } pin_project! { - struct TaskLocalFuture<T: StaticLifetime, F> { + /// A future that sets a value `T` of a task local for the future `F` during + /// its execution. + /// + /// The value of the task-local must be `'static` and will be dropped on the + /// completion of the future. + /// + /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.scope(1, async move { + /// println!("task local value: {}", NUMBER.get()); + /// }).await; + /// # } + /// ``` + pub struct TaskLocalFuture<T, F> + where + T: 'static + { local: &'static LocalKey<T>, slot: Option<T>, #[pin] future: F, + #[pin] + _pinned: PhantomPinned, } } @@ -252,10 +281,6 @@ impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { } } -// Required to make `pin_project` happy. -trait StaticLifetime: 'static {} -impl<T: 'static> StaticLifetime for T {} - /// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with). #[derive(Clone, Copy, Eq, PartialEq)] pub struct AccessError { diff --git a/src/time/clock.rs b/src/time/clock.rs index a0ff621..a44d75f 100644 --- a/src/time/clock.rs +++ b/src/time/clock.rs @@ -29,7 +29,7 @@ cfg_not_test_util! { cfg_test_util! { use crate::time::{Duration, Instant}; - use std::sync::{Arc, Mutex}; + use crate::loom::sync::{Arc, Mutex}; cfg_rt! { fn clock() -> Option<Clock> { @@ -77,6 +77,15 @@ cfg_test_util! { /// /// Panics if time is already frozen or if called from outside of a /// `current_thread` Tokio runtime. + /// + /// # Auto-advance + /// + /// If time is paused and the runtime has no work to do, the clock is + /// auto-advanced to the next pending timer. This means that [`Sleep`] or + /// other timer-backed primitives can cause the runtime to advance the + /// current time when awaited. + /// + /// [`Sleep`]: crate::time::Sleep pub fn pause() { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.pause(); @@ -93,7 +102,7 @@ cfg_test_util! { /// runtime. pub fn resume() { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); - let mut inner = clock.inner.lock().unwrap(); + let mut inner = clock.inner.lock(); if inner.unfrozen.is_some() { panic!("time is not frozen"); @@ -111,6 +120,12 @@ cfg_test_util! { /// /// Panics if time is not frozen or if called from outside of the Tokio /// runtime. + /// + /// # Auto-advance + /// + /// If the time is paused and there is no work to do, the runtime advances + /// time to the next timer. See [`pause`](pause#auto-advance) for more + /// details. pub async fn advance(duration: Duration) { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.advance(duration); @@ -149,7 +164,7 @@ cfg_test_util! { } pub(crate) fn pause(&self) { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock(); if !inner.enable_pausing { drop(inner); // avoid poisoning the lock @@ -163,12 +178,12 @@ cfg_test_util! { } pub(crate) fn is_paused(&self) -> bool { - let inner = self.inner.lock().unwrap(); + let inner = self.inner.lock(); inner.unfrozen.is_none() } pub(crate) fn advance(&self, duration: Duration) { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock(); if inner.unfrozen.is_some() { panic!("time is not frozen"); @@ -178,7 +193,7 @@ cfg_test_util! { } pub(crate) fn now(&self) -> Instant { - let inner = self.inner.lock().unwrap(); + let inner = self.inner.lock(); let mut ret = inner.base; diff --git a/src/time/driver/sleep.rs b/src/time/driver/sleep.rs index 8658813..40f745a 100644 --- a/src/time/driver/sleep.rs +++ b/src/time/driver/sleep.rs @@ -57,6 +57,7 @@ pub fn sleep_until(deadline: Instant) -> Sleep { /// [`interval`]: crate::time::interval() // Alias for old name in 0.x #[cfg_attr(docsrs, doc(alias = "delay_for"))] +#[cfg_attr(docsrs, doc(alias = "wait"))] pub fn sleep(duration: Duration) -> Sleep { match Instant::now().checked_add(duration) { Some(deadline) => sleep_until(deadline), diff --git a/src/time/instant.rs b/src/time/instant.rs index 1f8e663..f7cf12d 100644 --- a/src/time/instant.rs +++ b/src/time/instant.rs @@ -98,7 +98,7 @@ impl Instant { } /// Returns the amount of time elapsed from another instant to this one, or - /// zero duration if that instant is earlier than this one. + /// zero duration if that instant is later than this one. /// /// # Examples /// diff --git a/src/time/interval.rs b/src/time/interval.rs index 4b1c6f6..a63e47b 100644 --- a/src/time/interval.rs +++ b/src/time/interval.rs @@ -1,17 +1,20 @@ use crate::future::poll_fn; use crate::time::{sleep_until, Duration, Instant, Sleep}; -use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::{convert::TryInto, future::Future}; -/// Creates new `Interval` that yields with interval of `duration`. The first -/// tick completes immediately. +/// Creates new [`Interval`] that yields with interval of `period`. The first +/// tick completes immediately. The default [`MissedTickBehavior`] is +/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). /// -/// An interval will tick indefinitely. At any time, the `Interval` value can be -/// dropped. This cancels the interval. +/// An interval will tick indefinitely. At any time, the [`Interval`] value can +/// be dropped. This cancels the interval. /// -/// This function is equivalent to `interval_at(Instant::now(), period)`. +/// This function is equivalent to +/// [`interval_at(Instant::now(), period)`](interval_at). /// /// # Panics /// @@ -26,9 +29,9 @@ use std::task::{Context, Poll}; /// async fn main() { /// let mut interval = time::interval(Duration::from_millis(10)); /// -/// interval.tick().await; -/// interval.tick().await; -/// interval.tick().await; +/// interval.tick().await; // ticks immediately +/// interval.tick().await; // ticks after 10ms +/// interval.tick().await; // ticks after 10ms /// /// // approximately 20ms have elapsed. /// } @@ -36,10 +39,10 @@ use std::task::{Context, Poll}; /// /// A simple example using `interval` to execute a task every two seconds. /// -/// The difference between `interval` and [`sleep`] is that an `interval` -/// measures the time since the last tick, which means that `.tick().await` +/// The difference between `interval` and [`sleep`] is that an [`Interval`] +/// measures the time since the last tick, which means that [`.tick().await`] /// may wait for a shorter time than the duration specified for the interval -/// if some time has passed between calls to `.tick().await`. +/// if some time has passed between calls to [`.tick().await`]. /// /// If the tick in the example below was replaced with [`sleep`], the task /// would only be executed once every three seconds, and not every two @@ -64,17 +67,20 @@ use std::task::{Context, Poll}; /// ``` /// /// [`sleep`]: crate::time::sleep() +/// [`.tick().await`]: Interval::tick pub fn interval(period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); interval_at(Instant::now(), period) } -/// Creates new `Interval` that yields with interval of `period` with the -/// first tick completing at `start`. +/// Creates new [`Interval`] that yields with interval of `period` with the +/// first tick completing at `start`. The default [`MissedTickBehavior`] is +/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). /// -/// An interval will tick indefinitely. At any time, the `Interval` value can be -/// dropped. This cancels the interval. +/// An interval will tick indefinitely. At any time, the [`Interval`] value can +/// be dropped. This cancels the interval. /// /// # Panics /// @@ -90,9 +96,9 @@ pub fn interval(period: Duration) -> Interval { /// let start = Instant::now() + Duration::from_millis(50); /// let mut interval = interval_at(start, Duration::from_millis(10)); /// -/// interval.tick().await; -/// interval.tick().await; -/// interval.tick().await; +/// interval.tick().await; // ticks after 50ms +/// interval.tick().await; // ticks after 10ms +/// interval.tick().await; // ticks after 10ms /// /// // approximately 70ms have elapsed. /// } @@ -103,19 +109,249 @@ pub fn interval_at(start: Instant, period: Duration) -> Interval { Interval { delay: Box::pin(sleep_until(start)), period, + missed_tick_behavior: Default::default(), } } -/// Interval returned by [`interval`](interval) and [`interval_at`](interval_at). +/// Defines the behavior of an [`Interval`] when it misses a tick. +/// +/// Sometimes, an [`Interval`]'s tick is missed. For example, consider the +/// following: +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// # async fn task_that_takes_one_to_three_millis() {} +/// +/// #[tokio::main] +/// async fn main() { +/// // ticks every 2 seconds +/// let mut interval = time::interval(Duration::from_millis(2)); +/// for _ in 0..5 { +/// interval.tick().await; +/// // if this takes more than 2 milliseconds, a tick will be delayed +/// task_that_takes_one_to_three_millis().await; +/// } +/// } +/// ``` +/// +/// Generally, a tick is missed if too much time is spent without calling +/// [`Interval::tick()`]. +/// +/// By default, when a tick is missed, [`Interval`] fires ticks as quickly as it +/// can until it is "caught up" in time to where it should be. +/// `MissedTickBehavior` can be used to specify a different behavior for +/// [`Interval`] to exhibit. Each variant represents a different strategy. +/// +/// Note that because the executor cannot guarantee exact precision with timers, +/// these strategies will only apply when the delay is greater than 5 +/// milliseconds. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MissedTickBehavior { + /// Tick as fast as possible until caught up. + /// + /// When this strategy is used, [`Interval`] schedules ticks "normally" (the + /// same as it would have if the ticks hadn't been delayed), which results + /// in it firing ticks as fast as possible until it is caught up in time to + /// where it should be. Unlike [`Delay`] and [`Skip`], the ticks yielded + /// when `Burst` is used (the [`Instant`]s that [`tick`](Interval::tick) + /// yields) aren't different than they would have been if a tick had not + /// been missed. Like [`Skip`], and unlike [`Delay`], the ticks may be + /// shortened. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work | work | work -| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration}; + /// # async fn task_that_takes_200_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// + /// task_that_takes_200_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // Since we are more than 100ms after the start of `interval`, this will + /// // also resolve immediately. + /// interval.tick().await; + /// + /// // Also resolves immediately, because it was supposed to resolve at + /// // 150ms after the start of `interval` + /// interval.tick().await; + /// + /// // Resolves immediately + /// interval.tick().await; + /// + /// // Since we have gotten to 200ms after the start of `interval`, this + /// // will resolve after 50ms + /// interval.tick().await; + /// # } + /// ``` + /// + /// This is the default behavior when [`Interval`] is created with + /// [`interval`] and [`interval_at`]. + /// + /// [`Delay`]: MissedTickBehavior::Delay + /// [`Skip`]: MissedTickBehavior::Skip + Burst, + + /// Tick at multiples of `period` from when [`tick`] was called, rather than + /// from `start`. + /// + /// When this strategy is used and [`Interval`] has missed a tick, instead + /// of scheduling ticks to fire at multiples of `period` from `start` (the + /// time when the first tick was fired), it schedules all future ticks to + /// happen at a regular `period` from the point when [`tick`] was called. + /// Unlike [`Burst`] and [`Skip`], ticks are not shortened, and they aren't + /// guaranteed to happen at a multiple of `period` from `start` any longer. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work -----| work -----| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration, MissedTickBehavior}; + /// # async fn task_that_takes_more_than_50_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + /// + /// task_that_takes_more_than_50_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // But this one, rather than also resolving immediately, as might happen + /// // with the `Burst` or `Skip` behaviors, will not resolve until + /// // 50ms after the call to `tick` up above. That is, in `tick`, when we + /// // recognize that we missed a tick, we schedule the next tick to happen + /// // 50ms (or whatever the `period` is) from right then, not from when + /// // were were *supposed* to tick + /// interval.tick().await; + /// # } + /// ``` + /// + /// [`Burst`]: MissedTickBehavior::Burst + /// [`Skip`]: MissedTickBehavior::Skip + /// [`tick`]: Interval::tick + Delay, + + /// Skip missed ticks and tick on the next multiple of `period` from + /// `start`. + /// + /// When this strategy is used, [`Interval`] schedules the next tick to fire + /// at the next-closest tick that is a multiple of `period` away from + /// `start` (the point where [`Interval`] first ticked). Like [`Burst`], all + /// ticks remain multiples of `period` away from `start`, but unlike + /// [`Burst`], the ticks may not be *one* multiple of `period` away from the + /// last tick. Like [`Delay`], the ticks are no longer the same as they + /// would have been if ticks had not been missed, but unlike [`Delay`], and + /// like [`Burst`], the ticks may be shortened to be less than one `period` + /// away from each other. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work ---| work -----| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration, MissedTickBehavior}; + /// # async fn task_that_takes_75_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + /// + /// task_that_takes_75_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // This one will resolve after 25ms, 100ms after the start of + /// // `interval`, which is the closest multiple of `period` from the start + /// // of `interval` after the call to `tick` up above. + /// interval.tick().await; + /// # } + /// ``` + /// + /// [`Burst`]: MissedTickBehavior::Burst + /// [`Delay`]: MissedTickBehavior::Delay + Skip, +} + +impl MissedTickBehavior { + /// If a tick is missed, this method is called to determine when the next tick should happen. + fn next_timeout(&self, timeout: Instant, now: Instant, period: Duration) -> Instant { + match self { + Self::Burst => timeout + period, + Self::Delay => now + period, + Self::Skip => { + now + period + - Duration::from_nanos( + ((now - timeout).as_nanos() % period.as_nanos()) + .try_into() + // This operation is practically guaranteed not to + // fail, as in order for it to fail, `period` would + // have to be longer than `now - timeout`, and both + // would have to be longer than 584 years. + // + // If it did fail, there's not a good way to pass + // the error along to the user, so we just panic. + .expect( + "too much time has elapsed since the interval was supposed to tick", + ), + ) + } + } + } +} + +impl Default for MissedTickBehavior { + /// Returns [`MissedTickBehavior::Burst`]. + /// + /// For most usecases, the [`Burst`] strategy is what is desired. + /// Additionally, to preserve backwards compatibility, the [`Burst`] + /// strategy must be the default. For these reasons, + /// [`MissedTickBehavior::Burst`] is the default for [`MissedTickBehavior`]. + /// See [`Burst`] for more details. + /// + /// [`Burst`]: MissedTickBehavior::Burst + fn default() -> Self { + Self::Burst + } +} + +/// Interval returned by [`interval`] and [`interval_at`] /// /// This type allows you to wait on a sequence of instants with a certain -/// duration between each instant. Unlike calling [`sleep`](crate::time::sleep) -/// in a loop, this lets you count the time spent between the calls to `sleep` -/// as well. +/// duration between each instant. Unlike calling [`sleep`] in a loop, this lets +/// you count the time spent between the calls to [`sleep`] as well. /// /// An `Interval` can be turned into a `Stream` with [`IntervalStream`]. /// -/// [`IntervalStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.IntervalStream.html +/// [`IntervalStream`]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.IntervalStream.html +/// [`sleep`]: crate::time::sleep #[derive(Debug)] pub struct Interval { /// Future that completes the next time the `Interval` yields a value. @@ -123,6 +359,9 @@ pub struct Interval { /// The duration between values yielded by `Interval`. period: Duration, + + /// The strategy `Interval` should use when a tick is missed. + missed_tick_behavior: MissedTickBehavior, } impl Interval { @@ -159,22 +398,46 @@ impl Interval { /// /// When this method returns `Poll::Pending`, the current task is scheduled /// to receive a wakeup when the instant has elapsed. Note that on multiple - /// calls to `poll_tick`, only the `Waker` from the `Context` passed to the - /// most recent call is scheduled to receive a wakeup. + /// calls to `poll_tick`, only the [`Waker`](std::task::Waker) from the + /// [`Context`] passed to the most recent call is scheduled to receive a + /// wakeup. pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { // Wait for the delay to be done ready!(Pin::new(&mut self.delay).poll(cx)); - // Get the `now` by looking at the `delay` deadline - let now = self.delay.deadline(); + // Get the time when we were scheduled to tick + let timeout = self.delay.deadline(); + + let now = Instant::now(); + + // If a tick was not missed, and thus we are being called before the + // next tick is due, just schedule the next tick normally, one `period` + // after `timeout` + // + // However, if a tick took excessively long and we are now behind, + // schedule the next tick according to how the user specified with + // `MissedTickBehavior` + let next = if now > timeout + Duration::from_millis(5) { + self.missed_tick_behavior + .next_timeout(timeout, now, self.period) + } else { + timeout + self.period + }; - // The next interval value is `duration` after the one that just - // yielded. - let next = now + self.period; self.delay.as_mut().reset(next); - // Return the current instant - Poll::Ready(now) + // Return the time when we were scheduled to tick + Poll::Ready(timeout) + } + + /// Returns the [`MissedTickBehavior`] strategy currently being used. + pub fn missed_tick_behavior(&self) -> MissedTickBehavior { + self.missed_tick_behavior + } + + /// Sets the [`MissedTickBehavior`] strategy that should be used. + pub fn set_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) { + self.missed_tick_behavior = behavior; } /// Returns the period of the interval. diff --git a/src/time/mod.rs b/src/time/mod.rs index 98bb2af..281990e 100644 --- a/src/time/mod.rs +++ b/src/time/mod.rs @@ -3,21 +3,21 @@ //! This module provides a number of types for executing code after a set period //! of time. //! -//! * `Sleep` is a future that does no work and completes at a specific `Instant` +//! * [`Sleep`] is a future that does no work and completes at a specific [`Instant`] //! in time. //! -//! * `Interval` is a stream yielding a value at a fixed period. It is -//! initialized with a `Duration` and repeatedly yields each time the duration +//! * [`Interval`] is a stream yielding a value at a fixed period. It is +//! initialized with a [`Duration`] and repeatedly yields each time the duration //! elapses. //! -//! * `Timeout`: Wraps a future or stream, setting an upper bound to the amount +//! * [`Timeout`]: Wraps a future or stream, setting an upper bound to the amount //! of time it is allowed to execute. If the future or stream does not //! complete in time, then it is canceled and an error is returned. //! //! These types are sufficient for handling a large number of scenarios //! involving time. //! -//! These types must be used from within the context of the `Runtime`. +//! These types must be used from within the context of the [`Runtime`](crate::runtime::Runtime). //! //! # Examples //! @@ -55,8 +55,8 @@ //! A simple example using [`interval`] to execute a task every two seconds. //! //! The difference between [`interval`] and [`sleep`] is that an [`interval`] -//! measures the time since the last tick, which means that `.tick().await` -//! may wait for a shorter time than the duration specified for the interval +//! measures the time since the last tick, which means that `.tick().await` may +//! wait for a shorter time than the duration specified for the interval //! if some time has passed between calls to `.tick().await`. //! //! If the tick in the example below was replaced with [`sleep`], the task @@ -81,7 +81,6 @@ //! } //! ``` //! -//! [`sleep`]: crate::time::sleep() //! [`interval`]: crate::time::interval() mod clock; @@ -100,7 +99,7 @@ mod instant; pub use self::instant::Instant; mod interval; -pub use interval::{interval, interval_at, Interval}; +pub use interval::{interval, interval_at, Interval, MissedTickBehavior}; mod timeout; #[doc(inline)] diff --git a/src/util/linked_list.rs b/src/util/linked_list.rs index a74f562..1eab81c 100644 --- a/src/util/linked_list.rs +++ b/src/util/linked_list.rs @@ -236,37 +236,6 @@ impl<L: Link> Default for LinkedList<L, L::Target> { } } -// ===== impl Iter ===== - -cfg_rt_multi_thread! { - pub(crate) struct Iter<'a, T: Link> { - curr: Option<NonNull<T::Target>>, - _p: core::marker::PhantomData<&'a T>, - } - - impl<L: Link> LinkedList<L, L::Target> { - pub(crate) fn iter(&self) -> Iter<'_, L> { - Iter { - curr: self.head, - _p: core::marker::PhantomData, - } - } - } - - impl<'a, T: Link> Iterator for Iter<'a, T> { - type Item = &'a T::Target; - - fn next(&mut self) -> Option<&'a T::Target> { - let curr = self.curr?; - // safety: the pointer references data contained by the list - self.curr = unsafe { T::pointers(curr).as_ref() }.get_next(); - - // safety: the value is still owned by the linked list. - Some(unsafe { &*curr.as_ptr() }) - } - } -} - // ===== impl DrainFilter ===== cfg_io_readiness! { @@ -645,24 +614,6 @@ mod tests { } } - #[test] - fn iter() { - let a = entry(5); - let b = entry(7); - - let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); - - assert_eq!(0, list.iter().count()); - - list.push_front(a.as_ref()); - list.push_front(b.as_ref()); - - let mut i = list.iter(); - assert_eq!(7, i.next().unwrap().val); - assert_eq!(5, i.next().unwrap().val); - assert!(i.next().is_none()); - } - proptest::proptest! { #[test] fn fuzz_linked_list(ops: Vec<usize>) { diff --git a/src/util/mod.rs b/src/util/mod.rs index b267125..9065f50 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -21,6 +21,12 @@ cfg_rt! { mod wake; pub(crate) use wake::WakerRef; pub(crate) use wake::{waker_ref, Wake}; + + mod sync_wrapper; + pub(crate) use sync_wrapper::SyncWrapper; + + mod vec_deque_cell; + pub(crate) use vec_deque_cell::VecDequeCell; } cfg_rt_multi_thread! { diff --git a/src/util/sync_wrapper.rs b/src/util/sync_wrapper.rs new file mode 100644 index 0000000..5ffc8f9 --- /dev/null +++ b/src/util/sync_wrapper.rs @@ -0,0 +1,26 @@ +//! This module contains a type that can make `Send + !Sync` types `Sync` by +//! disallowing all immutable access to the value. +//! +//! A similar primitive is provided in the `sync_wrapper` crate. + +pub(crate) struct SyncWrapper<T> { + value: T, +} + +// safety: The SyncWrapper being send allows you to send the inner value across +// thread boundaries. +unsafe impl<T: Send> Send for SyncWrapper<T> {} + +// safety: An immutable reference to a SyncWrapper is useless, so moving such an +// immutable reference across threads is safe. +unsafe impl<T> Sync for SyncWrapper<T> {} + +impl<T> SyncWrapper<T> { + pub(crate) fn new(value: T) -> Self { + Self { value } + } + + pub(crate) fn into_inner(self) -> T { + self.value + } +} diff --git a/src/util/trace.rs b/src/util/trace.rs index 96a9db9..c51a5a7 100644 --- a/src/util/trace.rs +++ b/src/util/trace.rs @@ -4,7 +4,7 @@ cfg_trace! { #[inline] #[cfg_attr(tokio_track_caller, track_caller)] - pub(crate) fn task<F>(task: F, kind: &'static str) -> Instrumented<F> { + pub(crate) fn task<F>(task: F, kind: &'static str, name: Option<&str>) -> Instrumented<F> { use tracing::instrument::Instrument; #[cfg(tokio_track_caller)] let location = std::panic::Location::caller(); @@ -14,12 +14,14 @@ cfg_trace! { "task", %kind, spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), + task.name = %name.unwrap_or_default() ); #[cfg(not(tokio_track_caller))] let span = tracing::trace_span!( target: "tokio::task", "task", %kind, + task.name = %name.unwrap_or_default() ); task.instrument(span) } @@ -29,7 +31,7 @@ cfg_trace! { cfg_not_trace! { cfg_rt! { #[inline] - pub(crate) fn task<F>(task: F, _: &'static str) -> F { + pub(crate) fn task<F>(task: F, _: &'static str, _name: Option<&str>) -> F { // nop task } diff --git a/src/util/vec_deque_cell.rs b/src/util/vec_deque_cell.rs new file mode 100644 index 0000000..12883ab --- /dev/null +++ b/src/util/vec_deque_cell.rs @@ -0,0 +1,53 @@ +use crate::loom::cell::UnsafeCell; + +use std::collections::VecDeque; +use std::marker::PhantomData; + +/// This type is like VecDeque, except that it is not Sync and can be modified +/// through immutable references. +pub(crate) struct VecDequeCell<T> { + inner: UnsafeCell<VecDeque<T>>, + _not_sync: PhantomData<*const ()>, +} + +// This is Send for the same reasons that RefCell<VecDeque<T>> is Send. +unsafe impl<T: Send> Send for VecDequeCell<T> {} + +impl<T> VecDequeCell<T> { + pub(crate) fn with_capacity(cap: usize) -> Self { + Self { + inner: UnsafeCell::new(VecDeque::with_capacity(cap)), + _not_sync: PhantomData, + } + } + + /// Safety: This method may not be called recursively. + #[inline] + unsafe fn with_inner<F, R>(&self, f: F) -> R + where + F: FnOnce(&mut VecDeque<T>) -> R, + { + // safety: This type is not Sync, so concurrent calls of this method + // cannot happen. Furthermore, the caller guarantees that the method is + // not called recursively. Finally, this is the only place that can + // create mutable references to the inner VecDeque. This ensures that + // any mutable references created here are exclusive. + self.inner.with_mut(|ptr| f(&mut *ptr)) + } + + pub(crate) fn pop_front(&self) -> Option<T> { + unsafe { self.with_inner(VecDeque::pop_front) } + } + + pub(crate) fn push_back(&self, item: T) { + unsafe { + self.with_inner(|inner| inner.push_back(item)); + } + } + + /// Replace the inner VecDeque with an empty VecDeque and return the current + /// contents. + pub(crate) fn take(&self) -> VecDeque<T> { + unsafe { self.with_inner(|inner| std::mem::take(inner)) } + } +} diff --git a/tests/async_send_sync.rs b/tests/async_send_sync.rs index 01e6081..aa14970 100644 --- a/tests/async_send_sync.rs +++ b/tests/async_send_sync.rs @@ -4,13 +4,30 @@ use std::cell::Cell; use std::future::Future; -use std::io::{Cursor, SeekFrom}; +use std::io::SeekFrom; use std::net::SocketAddr; use std::pin::Pin; use std::rc::Rc; use tokio::net::TcpStream; use tokio::time::{Duration, Instant}; +// The names of these structs behaves better when sorted. +// Send: Yes, Sync: Yes +#[derive(Clone)] +struct YY {} + +// Send: Yes, Sync: No +#[derive(Clone)] +struct YN { + _value: Cell<u8>, +} + +// Send: No, Sync: No +#[derive(Clone)] +struct NN { + _value: Rc<u8>, +} + #[allow(dead_code)] type BoxFutureSync<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + Sync>>; #[allow(dead_code)] @@ -19,11 +36,11 @@ type BoxFutureSend<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T>>>; #[allow(dead_code)] -type BoxAsyncRead = std::pin::Pin<Box<dyn tokio::io::AsyncBufRead>>; +type BoxAsyncRead = std::pin::Pin<Box<dyn tokio::io::AsyncBufRead + Send + Sync>>; #[allow(dead_code)] -type BoxAsyncSeek = std::pin::Pin<Box<dyn tokio::io::AsyncSeek>>; +type BoxAsyncSeek = std::pin::Pin<Box<dyn tokio::io::AsyncSeek + Send + Sync>>; #[allow(dead_code)] -type BoxAsyncWrite = std::pin::Pin<Box<dyn tokio::io::AsyncWrite>>; +type BoxAsyncWrite = std::pin::Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>; #[allow(dead_code)] fn require_send<T: Send>(_t: &T) {} @@ -59,310 +76,594 @@ macro_rules! into_todo { x }}; } -macro_rules! assert_value { - ($type:ty: Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - require_send(&f); - require_sync(&f); - }; - }; - ($type:ty: !Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - AmbiguousIfSend::some_item(&f); - require_sync(&f); - }; - }; - ($type:ty: Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - require_send(&f); - AmbiguousIfSync::some_item(&f); - }; + +macro_rules! async_assert_fn_send { + (Send & $(!)?Sync & $(!)?Unpin, $value:expr) => { + require_send(&$value); }; - ($type:ty: !Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - AmbiguousIfSend::some_item(&f); - AmbiguousIfSync::some_item(&f); - }; - }; - ($type:ty: Unpin) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - require_unpin(&f); - }; + (!Send & $(!)?Sync & $(!)?Unpin, $value:expr) => { + AmbiguousIfSend::some_item(&$value); }; } -macro_rules! async_assert_fn { - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - require_send(&f); - require_sync(&f); - }; +macro_rules! async_assert_fn_sync { + ($(!)?Send & Sync & $(!)?Unpin, $value:expr) => { + require_sync(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - require_send(&f); - AmbiguousIfSync::some_item(&f); - }; + ($(!)?Send & !Sync & $(!)?Unpin, $value:expr) => { + AmbiguousIfSync::some_item(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - AmbiguousIfSend::some_item(&f); - require_sync(&f); - }; +} +macro_rules! async_assert_fn_unpin { + ($(!)?Send & $(!)?Sync & Unpin, $value:expr) => { + require_unpin(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - AmbiguousIfSend::some_item(&f); - AmbiguousIfSync::some_item(&f); - }; + ($(!)?Send & $(!)?Sync & !Unpin, $value:expr) => { + AmbiguousIfUnpin::some_item(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Unpin) => { +} + +macro_rules! async_assert_fn { + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): $($tok:tt)*) => { #[allow(unreachable_code)] #[allow(unused_variables)] const _: fn() = || { let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - AmbiguousIfUnpin::some_item(&f); + async_assert_fn_send!($($tok)*, f); + async_assert_fn_sync!($($tok)*, f); + async_assert_fn_unpin!($($tok)*, f); }; }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Unpin) => { +} +macro_rules! assert_value { + ($type:ty: $($tok:tt)*) => { #[allow(unreachable_code)] #[allow(unused_variables)] const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - require_unpin(&f); + let f: $type = todo!(); + async_assert_fn_send!($($tok)*, f); + async_assert_fn_sync!($($tok)*, f); + async_assert_fn_unpin!($($tok)*, f); }; }; } -async_assert_fn!(tokio::io::copy(&mut TcpStream, &mut TcpStream): Send & Sync); -async_assert_fn!(tokio::io::empty(): Send & Sync); -async_assert_fn!(tokio::io::repeat(u8): Send & Sync); -async_assert_fn!(tokio::io::sink(): Send & Sync); -async_assert_fn!(tokio::io::split(TcpStream): Send & Sync); -async_assert_fn!(tokio::io::stderr(): Send & Sync); -async_assert_fn!(tokio::io::stdin(): Send & Sync); -async_assert_fn!(tokio::io::stdout(): Send & Sync); -async_assert_fn!(tokio::io::Split<Cursor<Vec<u8>>>::next_segment(_): Send & Sync); - -async_assert_fn!(tokio::fs::canonicalize(&str): Send & Sync); -async_assert_fn!(tokio::fs::copy(&str, &str): Send & Sync); -async_assert_fn!(tokio::fs::create_dir(&str): Send & Sync); -async_assert_fn!(tokio::fs::create_dir_all(&str): Send & Sync); -async_assert_fn!(tokio::fs::hard_link(&str, &str): Send & Sync); -async_assert_fn!(tokio::fs::metadata(&str): Send & Sync); -async_assert_fn!(tokio::fs::read(&str): Send & Sync); -async_assert_fn!(tokio::fs::read_dir(&str): Send & Sync); -async_assert_fn!(tokio::fs::read_link(&str): Send & Sync); -async_assert_fn!(tokio::fs::read_to_string(&str): Send & Sync); -async_assert_fn!(tokio::fs::remove_dir(&str): Send & Sync); -async_assert_fn!(tokio::fs::remove_dir_all(&str): Send & Sync); -async_assert_fn!(tokio::fs::remove_file(&str): Send & Sync); -async_assert_fn!(tokio::fs::rename(&str, &str): Send & Sync); -async_assert_fn!(tokio::fs::set_permissions(&str, std::fs::Permissions): Send & Sync); -async_assert_fn!(tokio::fs::symlink_metadata(&str): Send & Sync); -async_assert_fn!(tokio::fs::write(&str, Vec<u8>): Send & Sync); -async_assert_fn!(tokio::fs::ReadDir::next_entry(_): Send & Sync); -async_assert_fn!(tokio::fs::OpenOptions::open(_, &str): Send & Sync); -async_assert_fn!(tokio::fs::DirEntry::metadata(_): Send & Sync); -async_assert_fn!(tokio::fs::DirEntry::file_type(_): Send & Sync); +assert_value!(tokio::fs::DirBuilder: Send & Sync & Unpin); +assert_value!(tokio::fs::DirEntry: Send & Sync & Unpin); +assert_value!(tokio::fs::File: Send & Sync & Unpin); +assert_value!(tokio::fs::OpenOptions: Send & Sync & Unpin); +assert_value!(tokio::fs::ReadDir: Send & Sync & Unpin); -async_assert_fn!(tokio::fs::File::open(&str): Send & Sync); -async_assert_fn!(tokio::fs::File::create(&str): Send & Sync); -async_assert_fn!(tokio::fs::File::sync_all(_): Send & Sync); -async_assert_fn!(tokio::fs::File::sync_data(_): Send & Sync); -async_assert_fn!(tokio::fs::File::set_len(_, u64): Send & Sync); -async_assert_fn!(tokio::fs::File::metadata(_): Send & Sync); -async_assert_fn!(tokio::fs::File::try_clone(_): Send & Sync); -async_assert_fn!(tokio::fs::File::into_std(_): Send & Sync); -async_assert_fn!(tokio::fs::File::set_permissions(_, std::fs::Permissions): Send & Sync); +async_assert_fn!(tokio::fs::canonicalize(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::copy(&str, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::create_dir(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::create_dir_all(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::hard_link(&str, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::metadata(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read_dir(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read_link(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read_to_string(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::remove_dir(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::remove_dir_all(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::remove_file(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::rename(&str, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::set_permissions(&str, std::fs::Permissions): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::symlink_metadata(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::write(&str, Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::ReadDir::next_entry(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::OpenOptions::open(_, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::DirBuilder::create(_, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::DirEntry::metadata(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::DirEntry::file_type(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::open(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::create(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::sync_all(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::sync_data(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::set_len(_, u64): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::metadata(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::try_clone(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::into_std(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::set_permissions(_, std::fs::Permissions): Send & Sync & !Unpin); -async_assert_fn!(tokio::net::lookup_host(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::TcpListener::bind(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::TcpListener::accept(_): Send & Sync); -async_assert_fn!(tokio::net::TcpStream::connect(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::TcpStream::peek(_, &mut [u8]): Send & Sync); -async_assert_fn!(tokio::net::tcp::ReadHalf::peek(_, &mut [u8]): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::bind(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::connect(_, SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::send(_, &[u8]): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::recv(_, &mut [u8]): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::send_to(_, &[u8], SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::recv_from(_, &mut [u8]): Send & Sync); +assert_value!(tokio::net::TcpListener: Send & Sync & Unpin); +assert_value!(tokio::net::TcpSocket: Send & Sync & Unpin); +assert_value!(tokio::net::TcpStream: Send & Sync & Unpin); +assert_value!(tokio::net::UdpSocket: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::OwnedReadHalf: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::OwnedWriteHalf: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::ReadHalf<'_>: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::ReuniteError: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::WriteHalf<'_>: Send & Sync & Unpin); +async_assert_fn!(tokio::net::TcpListener::accept(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpListener::bind(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::connect(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::peek(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::readable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::ready(_, tokio::io::Interest): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::writable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::bind(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::connect(_, SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::peek_from(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::readable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::ready(_, tokio::io::Interest): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::recv(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::recv_from(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::send(_, &[u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::send_to(_, &[u8], SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::writable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::lookup_host(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::tcp::ReadHalf::peek(_, &mut [u8]): Send & Sync & !Unpin); #[cfg(unix)] mod unix_datagram { use super::*; - async_assert_fn!(tokio::net::UnixListener::bind(&str): Send & Sync); - async_assert_fn!(tokio::net::UnixListener::accept(_): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::send(_, &[u8]): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::recv(_, &mut [u8]): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::send_to(_, &[u8], &str): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::recv_from(_, &mut [u8]): Send & Sync); - async_assert_fn!(tokio::net::UnixStream::connect(&str): Send & Sync); + use tokio::net::*; + assert_value!(UnixDatagram: Send & Sync & Unpin); + assert_value!(UnixListener: Send & Sync & Unpin); + assert_value!(UnixStream: Send & Sync & Unpin); + assert_value!(unix::OwnedReadHalf: Send & Sync & Unpin); + assert_value!(unix::OwnedWriteHalf: Send & Sync & Unpin); + assert_value!(unix::ReadHalf<'_>: Send & Sync & Unpin); + assert_value!(unix::ReuniteError: Send & Sync & Unpin); + assert_value!(unix::SocketAddr: Send & Sync & Unpin); + assert_value!(unix::UCred: Send & Sync & Unpin); + assert_value!(unix::WriteHalf<'_>: Send & Sync & Unpin); + async_assert_fn!(UnixDatagram::readable(_): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::recv(_, &mut [u8]): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::recv_from(_, &mut [u8]): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::send(_, &[u8]): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::send_to(_, &[u8], &str): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::writable(_): Send & Sync & !Unpin); + async_assert_fn!(UnixListener::accept(_): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::connect(&str): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::readable(_): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::writable(_): Send & Sync & !Unpin); } -async_assert_fn!(tokio::process::Child::wait_with_output(_): Send & Sync); -async_assert_fn!(tokio::signal::ctrl_c(): Send & Sync); -#[cfg(unix)] -async_assert_fn!(tokio::signal::unix::Signal::recv(_): Send & Sync); +#[cfg(windows)] +mod windows_named_pipe { + use super::*; + use tokio::net::windows::named_pipe::*; + assert_value!(ClientOptions: Send & Sync & Unpin); + assert_value!(NamedPipeClient: Send & Sync & Unpin); + assert_value!(NamedPipeServer: Send & Sync & Unpin); + assert_value!(PipeEnd: Send & Sync & Unpin); + assert_value!(PipeInfo: Send & Sync & Unpin); + assert_value!(PipeMode: Send & Sync & Unpin); + assert_value!(ServerOptions: Send & Sync & Unpin); + async_assert_fn!(NamedPipeClient::readable(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeClient::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeClient::writable(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::connect(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::readable(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::writable(_): Send & Sync & !Unpin); +} -async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<u8>::lock(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Cell<u8>>::lock(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Rc<u8>>::lock(_): !Send & !Sync); -async_assert_fn!(tokio::sync::Mutex<u8>::lock_owned(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Cell<u8>>::lock_owned(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Rc<u8>>::lock_owned(_): !Send & !Sync); -async_assert_fn!(tokio::sync::Notify::notified(_): Send & Sync); -async_assert_fn!(tokio::sync::RwLock<u8>::read(_): Send & Sync); -async_assert_fn!(tokio::sync::RwLock<u8>::write(_): Send & Sync); -async_assert_fn!(tokio::sync::RwLock<Cell<u8>>::read(_): !Send & !Sync); -async_assert_fn!(tokio::sync::RwLock<Cell<u8>>::write(_): !Send & !Sync); -async_assert_fn!(tokio::sync::RwLock<Rc<u8>>::read(_): !Send & !Sync); -async_assert_fn!(tokio::sync::RwLock<Rc<u8>>::write(_): !Send & !Sync); -async_assert_fn!(tokio::sync::Semaphore::acquire(_): Send & Sync); +assert_value!(tokio::process::Child: Send & Sync & Unpin); +assert_value!(tokio::process::ChildStderr: Send & Sync & Unpin); +assert_value!(tokio::process::ChildStdin: Send & Sync & Unpin); +assert_value!(tokio::process::ChildStdout: Send & Sync & Unpin); +assert_value!(tokio::process::Command: Send & Sync & Unpin); +async_assert_fn!(tokio::process::Child::kill(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::process::Child::wait(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::process::Child::wait_with_output(_): Send & Sync & !Unpin); -async_assert_fn!(tokio::sync::broadcast::Receiver<u8>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::broadcast::Receiver<Cell<u8>>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::broadcast::Receiver<Rc<u8>>::recv(_): !Send & !Sync); +async_assert_fn!(tokio::signal::ctrl_c(): Send & Sync & !Unpin); +#[cfg(unix)] +mod unix_signal { + use super::*; + assert_value!(tokio::signal::unix::Signal: Send & Sync & Unpin); + assert_value!(tokio::signal::unix::SignalKind: Send & Sync & Unpin); + async_assert_fn!(tokio::signal::unix::Signal::recv(_): Send & Sync & !Unpin); +} +#[cfg(windows)] +mod windows_signal { + use super::*; + assert_value!(tokio::signal::windows::CtrlC: Send & Sync & Unpin); + assert_value!(tokio::signal::windows::CtrlBreak: Send & Sync & Unpin); + async_assert_fn!(tokio::signal::windows::CtrlC::recv(_): Send & Sync & !Unpin); + async_assert_fn!(tokio::signal::windows::CtrlBreak::recv(_): Send & Sync & !Unpin); +} -async_assert_fn!(tokio::sync::mpsc::Receiver<u8>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::Receiver<Cell<u8>>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::Receiver<Rc<u8>>::recv(_): !Send & !Sync); -async_assert_fn!(tokio::sync::mpsc::Sender<u8>::send(_, u8): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::Sender<Cell<u8>>::send(_, Cell<u8>): Send & !Sync); -async_assert_fn!(tokio::sync::mpsc::Sender<Rc<u8>>::send(_, Rc<u8>): !Send & !Sync); +assert_value!(tokio::sync::AcquireError: Send & Sync & Unpin); +assert_value!(tokio::sync::Barrier: Send & Sync & Unpin); +assert_value!(tokio::sync::BarrierWaitResult: Send & Sync & Unpin); +assert_value!(tokio::sync::MappedMutexGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::MappedMutexGuard<'_, YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::MappedMutexGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::Mutex<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::Mutex<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::Mutex<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::MutexGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::MutexGuard<'_, YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::MutexGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::Notify: Send & Sync & Unpin); +assert_value!(tokio::sync::OnceCell<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OnceCell<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::OnceCell<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedMutexGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedMutexGuard<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedMutexGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockMappedWriteGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockMappedWriteGuard<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockMappedWriteGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockReadGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockReadGuard<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockReadGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockWriteGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockWriteGuard<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockWriteGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedSemaphorePermit: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLock<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLock<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLock<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLockMappedWriteGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockMappedWriteGuard<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockMappedWriteGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLockReadGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockReadGuard<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockReadGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLockWriteGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockWriteGuard<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockWriteGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::Semaphore: Send & Sync & Unpin); +assert_value!(tokio::sync::SemaphorePermit<'_>: Send & Sync & Unpin); +assert_value!(tokio::sync::TryAcquireError: Send & Sync & Unpin); +assert_value!(tokio::sync::TryLockError: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::broadcast::Receiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::broadcast::Sender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Sender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::futures::Notified<'_>: Send & Sync & !Unpin); +assert_value!(tokio::sync::mpsc::OwnedPermit<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::OwnedPermit<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::OwnedPermit<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Permit<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::Permit<'_, YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Permit<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::Receiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::Sender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Sender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedReceiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedReceiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedReceiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedSender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedSender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedSender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendError<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendError<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendError<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendTimeoutError<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendTimeoutError<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendTimeoutError<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::TrySendError<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::TrySendError<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::TrySendError<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::oneshot::Receiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::oneshot::Sender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Sender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::watch::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Receiver<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::watch::Ref<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Ref<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Ref<'_, YY>: !Send & Sync & Unpin); +assert_value!(tokio::sync::watch::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Sender<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Sender<YY>: Send & Sync & Unpin); +async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<NN>::lock(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<NN>::lock_owned(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YN>::lock(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YN>::lock_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YY>::lock(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YY>::lock_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Notify::notified(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = NN> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = NN> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = NN>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<NN>> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<NN>> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<NN>>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YN> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YN> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YN>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YN>> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YN>> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YN>>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YY> + Send + Sync>>): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YY> + Send>>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YY>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YY>> + Send + Sync>>): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YY>> + Send>>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YY>>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<NN>::read(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<NN>::write(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YN>::read(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YN>::write(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YY>::read(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YY>::write(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire_many(_, u32): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire_many_owned(_, u32): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::broadcast::Receiver<NN>::recv(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::broadcast::Receiver<YN>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::broadcast::Receiver<YY>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Receiver<NN>::recv(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Receiver<YN>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Receiver<YY>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::reserve(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::reserve_owned(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::send(_, NN): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::send_timeout(_, NN, Duration): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::reserve(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::reserve_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::send(_, YN): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::send_timeout(_, YN, Duration): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::reserve(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::reserve_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::send(_, YY): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::send_timeout(_, YY, Duration): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<NN>::recv(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<YN>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<YY>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedSender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedSender<YN>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedSender<YY>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::oneshot::Sender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::oneshot::Sender<YN>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::oneshot::Sender<YY>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Receiver<NN>::changed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Receiver<YN>::changed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Receiver<YY>::changed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Sender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Sender<YN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Sender<YY>::closed(_): Send & Sync & !Unpin); -async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<u8>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<Cell<u8>>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<Rc<u8>>::recv(_): !Send & !Sync); +async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSync<()>): Send & Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSync<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSync<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSend<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::unconstrained(BoxFuture<()>): !Send & !Sync & Unpin); +async_assert_fn!(tokio::task::unconstrained(BoxFutureSend<()>): Send & !Sync & Unpin); +async_assert_fn!(tokio::task::unconstrained(BoxFutureSync<()>): Send & Sync & Unpin); +assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin); +assert_value!(tokio::task::JoinHandle<YY>: Send & Sync & Unpin); +assert_value!(tokio::task::JoinHandle<YN>: Send & Sync & Unpin); +assert_value!(tokio::task::JoinHandle<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::task::JoinError: Send & Sync & Unpin); -async_assert_fn!(tokio::sync::watch::Receiver<u8>::changed(_): Send & Sync); -async_assert_fn!(tokio::sync::watch::Sender<u8>::closed(_): Send & Sync); -async_assert_fn!(tokio::sync::watch::Sender<Cell<u8>>::closed(_): !Send & !Sync); -async_assert_fn!(tokio::sync::watch::Sender<Rc<u8>>::closed(_): !Send & !Sync); +assert_value!(tokio::runtime::Builder: Send & Sync & Unpin); +assert_value!(tokio::runtime::EnterGuard<'_>: Send & Sync & Unpin); +assert_value!(tokio::runtime::Handle: Send & Sync & Unpin); +assert_value!(tokio::runtime::Runtime: Send & Sync & Unpin); -async_assert_fn!(tokio::sync::OnceCell<u8>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = u8> + Send + Sync>>): Send & Sync); -async_assert_fn!(tokio::sync::OnceCell<u8>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = u8> + Send>>): Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<u8>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = u8>>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Cell<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Cell<u8>> + Send + Sync>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Cell<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Cell<u8>> + Send>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Cell<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Cell<u8>>>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Rc<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Rc<u8>> + Send + Sync>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Rc<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Rc<u8>> + Send>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Rc<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Rc<u8>>>>): !Send & !Sync); -assert_value!(tokio::sync::OnceCell<u8>: Send & Sync); -assert_value!(tokio::sync::OnceCell<Cell<u8>>: Send & !Sync); -assert_value!(tokio::sync::OnceCell<Rc<u8>>: !Send & !Sync); +assert_value!(tokio::time::Interval: Send & Sync & Unpin); +assert_value!(tokio::time::Instant: Send & Sync & Unpin); +assert_value!(tokio::time::Sleep: Send & Sync & !Unpin); +assert_value!(tokio::time::Timeout<BoxFutureSync<()>>: Send & Sync & !Unpin); +assert_value!(tokio::time::Timeout<BoxFutureSend<()>>: Send & !Sync & !Unpin); +assert_value!(tokio::time::Timeout<BoxFuture<()>>: !Send & !Sync & !Unpin); +assert_value!(tokio::time::error::Elapsed: Send & Sync & Unpin); +assert_value!(tokio::time::error::Error: Send & Sync & Unpin); +async_assert_fn!(tokio::time::advance(Duration): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::sleep(Duration): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::sleep_until(Instant): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSync<()>): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSync<()>): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSync<()>): Send & Sync); -async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSync<()>): Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSync<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSend<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync); -assert_value!(tokio::task::LocalSet: !Send & !Sync); +assert_value!(tokio::io::BufReader<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::BufStream<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::BufWriter<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::DuplexStream: Send & Sync & Unpin); +assert_value!(tokio::io::Empty: Send & Sync & Unpin); +assert_value!(tokio::io::Interest: Send & Sync & Unpin); +assert_value!(tokio::io::Lines<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::ReadBuf<'_>: Send & Sync & Unpin); +assert_value!(tokio::io::ReadHalf<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::Ready: Send & Sync & Unpin); +assert_value!(tokio::io::Repeat: Send & Sync & Unpin); +assert_value!(tokio::io::Sink: Send & Sync & Unpin); +assert_value!(tokio::io::Split<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::Stderr: Send & Sync & Unpin); +assert_value!(tokio::io::Stdin: Send & Sync & Unpin); +assert_value!(tokio::io::Stdout: Send & Sync & Unpin); +assert_value!(tokio::io::Take<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::WriteHalf<TcpStream>: Send & Sync & Unpin); +async_assert_fn!(tokio::io::copy(&mut TcpStream, &mut TcpStream): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::copy_bidirectional(&mut TcpStream, &mut TcpStream): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::copy_buf(&mut tokio::io::BufReader<TcpStream>, &mut TcpStream): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::empty(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::repeat(u8): Send & Sync & Unpin); +async_assert_fn!(tokio::io::sink(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::split(TcpStream): Send & Sync & Unpin); +async_assert_fn!(tokio::io::stderr(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::stdin(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::stdout(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::Split<tokio::io::BufReader<TcpStream>>::next_segment(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::Lines<tokio::io::BufReader<TcpStream>>::next_line(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncBufReadExt::read_until(&mut BoxAsyncRead, u8, &mut Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncBufReadExt::read_line(&mut BoxAsyncRead, &mut String): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncBufReadExt::fill_buf(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read(&mut BoxAsyncRead, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_buf(&mut BoxAsyncRead, &mut Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncReadExt::read_exact(&mut BoxAsyncRead, &mut [u8]): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncReadExt::read_u8(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i8(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u16(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i16(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u32(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i32(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u64(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i64(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u128(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i128(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f32(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f64(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u16_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i16_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u32_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i32_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u64_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i64_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u128_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i128_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f32_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f64_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_to_end(&mut BoxAsyncRead, &mut Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncReadExt::read_to_string(&mut BoxAsyncRead, &mut String): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncSeekExt::seek(&mut BoxAsyncSeek, SeekFrom): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncSeekExt::stream_position(&mut BoxAsyncSeek): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncWriteExt::write(&mut BoxAsyncWrite, &[u8]): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_vectored(&mut BoxAsyncWrite, _): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_buf(&mut BoxAsyncWrite, &mut bytes::Bytes): Send + & Sync + & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_all_buf(&mut BoxAsyncWrite, &mut bytes::Bytes): Send + & Sync + & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_all(&mut BoxAsyncWrite, &[u8]): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncWriteExt::write_u8(&mut BoxAsyncWrite, u8): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncWriteExt::write_i8(&mut BoxAsyncWrite, i8): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u16(&mut BoxAsyncWrite, u16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i16(&mut BoxAsyncWrite, i16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u32(&mut BoxAsyncWrite, u32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i32(&mut BoxAsyncWrite, i32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u64(&mut BoxAsyncWrite, u64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i64(&mut BoxAsyncWrite, i64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u128(&mut BoxAsyncWrite, u128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i128(&mut BoxAsyncWrite, i128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f32(&mut BoxAsyncWrite, f32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f64(&mut BoxAsyncWrite, f64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u16_le(&mut BoxAsyncWrite, u16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i16_le(&mut BoxAsyncWrite, i16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u32_le(&mut BoxAsyncWrite, u32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i32_le(&mut BoxAsyncWrite, i32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u64_le(&mut BoxAsyncWrite, u64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i64_le(&mut BoxAsyncWrite, i64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u128_le(&mut BoxAsyncWrite, u128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i128_le(&mut BoxAsyncWrite, i128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f32_le(&mut BoxAsyncWrite, f32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f64_le(&mut BoxAsyncWrite, f64): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncWriteExt::flush(&mut BoxAsyncWrite): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncWriteExt::shutdown(&mut BoxAsyncWrite): Send & Sync & !Unpin); -async_assert_fn!(tokio::time::advance(Duration): Send & Sync); -async_assert_fn!(tokio::time::sleep(Duration): Send & Sync); -async_assert_fn!(tokio::time::sleep_until(Instant): Send & Sync); -async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSync<()>): Send & Sync); -async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSync<()>): Send & Sync); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync); +#[cfg(unix)] +mod unix_asyncfd { + use super::*; + use tokio::io::unix::*; -assert_value!(tokio::time::Interval: Unpin); -async_assert_fn!(tokio::time::sleep(Duration): !Unpin); -async_assert_fn!(tokio::time::sleep_until(Instant): !Unpin); -async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Unpin); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Unpin); -async_assert_fn!(tokio::time::Interval::tick(_): !Unpin); -async_assert_fn!(tokio::io::AsyncBufReadExt::read_until(&mut BoxAsyncRead, u8, &mut Vec<u8>): !Unpin); -async_assert_fn!(tokio::io::AsyncBufReadExt::read_line(&mut BoxAsyncRead, &mut String): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read(&mut BoxAsyncRead, &mut [u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_exact(&mut BoxAsyncRead, &mut [u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u8(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i8(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u16(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i16(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u32(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i32(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u64(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i64(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u128(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i128(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u16_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i16_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u32_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i32_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u64_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i64_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u128_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i128_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_to_end(&mut BoxAsyncRead, &mut Vec<u8>): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_to_string(&mut BoxAsyncRead, &mut String): !Unpin); -async_assert_fn!(tokio::io::AsyncSeekExt::seek(&mut BoxAsyncSeek, SeekFrom): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write(&mut BoxAsyncWrite, &[u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_all(&mut BoxAsyncWrite, &[u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u8(&mut BoxAsyncWrite, u8): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i8(&mut BoxAsyncWrite, i8): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u16(&mut BoxAsyncWrite, u16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i16(&mut BoxAsyncWrite, i16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u32(&mut BoxAsyncWrite, u32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i32(&mut BoxAsyncWrite, i32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u64(&mut BoxAsyncWrite, u64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i64(&mut BoxAsyncWrite, i64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u128(&mut BoxAsyncWrite, u128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i128(&mut BoxAsyncWrite, i128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u16_le(&mut BoxAsyncWrite, u16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i16_le(&mut BoxAsyncWrite, i16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u32_le(&mut BoxAsyncWrite, u32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i32_le(&mut BoxAsyncWrite, i32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u64_le(&mut BoxAsyncWrite, u64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i64_le(&mut BoxAsyncWrite, i64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u128_le(&mut BoxAsyncWrite, u128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i128_le(&mut BoxAsyncWrite, i128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::flush(&mut BoxAsyncWrite): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::shutdown(&mut BoxAsyncWrite): !Unpin); + struct ImplsFd<T> { + _t: T, + } + impl<T> std::os::unix::io::AsRawFd for ImplsFd<T> { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + unreachable!() + } + } + + assert_value!(AsyncFd<ImplsFd<YY>>: Send & Sync & Unpin); + assert_value!(AsyncFd<ImplsFd<YN>>: Send & !Sync & Unpin); + assert_value!(AsyncFd<ImplsFd<NN>>: !Send & !Sync & Unpin); + assert_value!(AsyncFdReadyGuard<'_, ImplsFd<YY>>: Send & Sync & Unpin); + assert_value!(AsyncFdReadyGuard<'_, ImplsFd<YN>>: !Send & !Sync & Unpin); + assert_value!(AsyncFdReadyGuard<'_, ImplsFd<NN>>: !Send & !Sync & Unpin); + assert_value!(AsyncFdReadyMutGuard<'_, ImplsFd<YY>>: Send & Sync & Unpin); + assert_value!(AsyncFdReadyMutGuard<'_, ImplsFd<YN>>: Send & !Sync & Unpin); + assert_value!(AsyncFdReadyMutGuard<'_, ImplsFd<NN>>: !Send & !Sync & Unpin); + assert_value!(TryIoError: Send & Sync & Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::readable(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::readable_mut(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::writable(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::writable_mut(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::readable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::readable_mut(_): Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::writable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::writable_mut(_): Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::readable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::readable_mut(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::writable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::writable_mut(_): !Send & !Sync & !Unpin); +} diff --git a/tests/io_async_fd.rs b/tests/io_async_fd.rs index d1586bb..dc21e42 100644 --- a/tests/io_async_fd.rs +++ b/tests/io_async_fd.rs @@ -13,7 +13,6 @@ use std::{ task::{Context, Waker}, }; -use nix::errno::Errno; use nix::unistd::{close, read, write}; use futures::{poll, FutureExt}; @@ -56,10 +55,6 @@ impl TestWaker { } } -fn is_blocking(e: &nix::Error) -> bool { - Some(Errno::EAGAIN) == e.as_errno() -} - #[derive(Debug)] struct FileDescriptor { fd: RawFd, @@ -73,11 +68,7 @@ impl AsRawFd for FileDescriptor { impl Read for &FileDescriptor { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - match read(self.fd, buf) { - Ok(n) => Ok(n), - Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), - Err(e) => Err(io::Error::new(ErrorKind::Other, e)), - } + read(self.fd, buf).map_err(io::Error::from) } } @@ -89,11 +80,7 @@ impl Read for FileDescriptor { impl Write for &FileDescriptor { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - match write(self.fd, buf) { - Ok(n) => Ok(n), - Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), - Err(e) => Err(io::Error::new(ErrorKind::Other, e)), - } + write(self.fd, buf).map_err(io::Error::from) } fn flush(&mut self) -> io::Result<()> { diff --git a/tests/io_buf_reader.rs b/tests/io_buf_reader.rs index c72c058..0d3f6ba 100644 --- a/tests/io_buf_reader.rs +++ b/tests/io_buf_reader.rs @@ -8,9 +8,11 @@ use std::cmp; use std::io::{self, Cursor}; use std::pin::Pin; use tokio::io::{ - AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, BufReader, - ReadBuf, SeekFrom, + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWriteExt, + BufReader, ReadBuf, SeekFrom, }; +use tokio_test::task::spawn; +use tokio_test::{assert_pending, assert_ready}; macro_rules! run_fill_buf { ($reader:expr) => {{ @@ -348,3 +350,30 @@ async fn maybe_pending_seek() { Pin::new(&mut reader).consume(1); assert_eq!(reader.seek(SeekFrom::Current(-2)).await.unwrap(), 3); } + +// This tests the AsyncBufReadExt::fill_buf wrapper. +#[tokio::test] +async fn test_fill_buf_wrapper() { + let (mut write, read) = tokio::io::duplex(16); + + let mut read = BufReader::new(read); + write.write_all(b"hello world").await.unwrap(); + + assert_eq!(read.fill_buf().await.unwrap(), b"hello world"); + read.consume(b"hello ".len()); + assert_eq!(read.fill_buf().await.unwrap(), b"world"); + assert_eq!(read.fill_buf().await.unwrap(), b"world"); + read.consume(b"world".len()); + + let mut fill = spawn(read.fill_buf()); + assert_pending!(fill.poll()); + + write.write_all(b"foo bar").await.unwrap(); + assert_eq!(assert_ready!(fill.poll()).unwrap(), b"foo bar"); + drop(fill); + + drop(write); + assert_eq!(read.fill_buf().await.unwrap(), b"foo bar"); + read.consume(b"foo bar".len()); + assert_eq!(read.fill_buf().await.unwrap(), b""); +} diff --git a/tests/io_buf_writer.rs b/tests/io_buf_writer.rs index 6f4f10a..47a0d46 100644 --- a/tests/io_buf_writer.rs +++ b/tests/io_buf_writer.rs @@ -8,6 +8,17 @@ use std::io::{self, Cursor}; use std::pin::Pin; use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter, SeekFrom}; +use futures::future; +use tokio_test::assert_ok; + +use std::cmp; +use std::io::IoSlice; + +mod support { + pub(crate) mod io_vec; +} +use support::io_vec::IoBufs; + struct MaybePending { inner: Vec<u8>, ready: bool, @@ -47,6 +58,14 @@ impl AsyncWrite for MaybePending { } } +async fn write_vectored<W>(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result<usize> +where + W: AsyncWrite + Unpin, +{ + let mut writer = Pin::new(writer); + future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await +} + #[tokio::test] async fn buf_writer() { let mut writer = BufWriter::with_capacity(2, Vec::new()); @@ -249,3 +268,270 @@ async fn maybe_pending_buf_writer_seek() { &[0, 1, 8, 9, 4, 5, 6, 7] ); } + +struct MockWriter { + data: Vec<u8>, + write_len: usize, + vectored: bool, +} + +impl MockWriter { + fn new(write_len: usize) -> Self { + MockWriter { + data: Vec::new(), + write_len, + vectored: false, + } + } + + fn vectored(write_len: usize) -> Self { + MockWriter { + data: Vec::new(), + write_len, + vectored: true, + } + } + + fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize { + let len = cmp::min(buf.len(), limit); + self.data.extend_from_slice(&buf[..len]); + len + } +} + +impl AsyncWrite for MockWriter { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + let this = self.get_mut(); + let n = this.write_up_to(buf, this.write_len); + Ok(n).into() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<Result<usize, io::Error>> { + let this = self.get_mut(); + let mut total_written = 0; + for buf in bufs { + let n = this.write_up_to(buf, this.write_len - total_written); + total_written += n; + if total_written == this.write_len { + break; + } + } + Ok(total_written).into() + } + + fn is_write_vectored(&self) -> bool { + self.vectored + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ok(()).into() + } +} + +#[tokio::test] +async fn write_vectored_empty_on_non_vectored() { + let mut w = BufWriter::new(MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &[]).await); + assert_eq!(n, 0); + + let io_vec = [IoSlice::new(&[]); 3]; + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 0); + + assert_ok!(w.flush().await); + assert!(w.get_ref().data.is_empty()); +} + +#[tokio::test] +async fn write_vectored_empty_on_vectored() { + let mut w = BufWriter::new(MockWriter::vectored(4)); + let n = assert_ok!(write_vectored(&mut w, &[]).await); + assert_eq!(n, 0); + + let io_vec = [IoSlice::new(&[]); 3]; + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 0); + + assert_ok!(w.flush().await); + assert!(w.get_ref().data.is_empty()); +} + +#[tokio::test] +async fn write_vectored_basic_on_non_vectored() { + let msg = b"foo bar baz"; + let bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let mut w = BufWriter::new(MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &bufs).await); + assert_eq!(n, msg.len()); + assert!(w.buffer() == &msg[..]); + assert_ok!(w.flush().await); + assert_eq!(w.get_ref().data, msg); +} + +#[tokio::test] +async fn write_vectored_basic_on_vectored() { + let msg = b"foo bar baz"; + let bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let mut w = BufWriter::new(MockWriter::vectored(4)); + let n = assert_ok!(write_vectored(&mut w, &bufs).await); + assert_eq!(n, msg.len()); + assert!(w.buffer() == &msg[..]); + assert_ok!(w.flush().await); + assert_eq!(w.get_ref().data, msg); +} + +#[tokio::test] +async fn write_vectored_large_total_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let io_vec = IoBufs::new(&mut bufs); + let mut w = BufWriter::with_capacity(8, MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 8); + assert!(w.buffer() == &msg[..8]); + let io_vec = io_vec.advance(n); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 3); + assert!(w.get_ref().data.as_slice() == &msg[..8]); + assert!(w.buffer() == &msg[8..]); +} + +#[tokio::test] +async fn write_vectored_large_total_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let io_vec = IoBufs::new(&mut bufs); + let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10)); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 10); + assert!(w.buffer().is_empty()); + let io_vec = io_vec.advance(n); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 1); + assert!(w.get_ref().data.as_slice() == &msg[..10]); + assert!(w.buffer() == &msg[10..]); +} + +struct VectoredWriteHarness { + writer: BufWriter<MockWriter>, + buf_capacity: usize, +} + +impl VectoredWriteHarness { + fn new(buf_capacity: usize) -> Self { + VectoredWriteHarness { + writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)), + buf_capacity, + } + } + + fn with_vectored_backend(buf_capacity: usize) -> Self { + VectoredWriteHarness { + writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)), + buf_capacity, + } + } + + async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize { + let mut total_written = 0; + while !io_vec.is_empty() { + let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await); + assert!(n != 0); + assert!(self.writer.buffer().len() <= self.buf_capacity); + total_written += n; + io_vec = io_vec.advance(n); + } + total_written + } + + async fn flush(&mut self) -> &[u8] { + assert_ok!(self.writer.flush().await); + &self.writer.get_ref().data + } +} + +#[tokio::test] +async fn write_vectored_odd_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&[]), + IoSlice::new(&msg[4..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::new(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_odd_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&[]), + IoSlice::new(&msg[4..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::with_vectored_backend(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_large_slice_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&[]), + IoSlice::new(&msg[..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::new(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_large_slice_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&[]), + IoSlice::new(&msg[..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::with_vectored_backend(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} diff --git a/tests/io_copy.rs b/tests/io_copy.rs index 9ed7995..005e170 100644 --- a/tests/io_copy.rs +++ b/tests/io_copy.rs @@ -1,7 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{self, AsyncRead, ReadBuf}; +use bytes::BytesMut; +use futures::ready; +use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_test::assert_ok; use std::pin::Pin; @@ -34,3 +36,52 @@ async fn copy() { assert_eq!(n, 11); assert_eq!(wr, b"hello world"); } + +#[tokio::test] +async fn proxy() { + struct BufferedWd { + buf: BytesMut, + writer: io::DuplexStream, + } + + impl AsyncWrite for BufferedWd { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.get_mut().buf.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let this = self.get_mut(); + + while !this.buf.is_empty() { + let n = ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buf))?; + let _ = this.buf.split_to(n); + } + + Pin::new(&mut this.writer).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.writer).poll_shutdown(cx) + } + } + + let (rd, wd) = io::duplex(1024); + let mut rd = rd.take(1024); + let mut wd = BufferedWd { + buf: BytesMut::new(), + writer: wd, + }; + + // write start bytes + assert_ok!(wd.write_all(&[0x42; 512]).await); + assert_ok!(wd.flush().await); + + let n = assert_ok!(io::copy(&mut rd, &mut wd).await); + + assert_eq!(n, 1024); +} diff --git a/tests/io_copy_bidirectional.rs b/tests/io_copy_bidirectional.rs index 17c0597..0e82b29 100644 --- a/tests/io_copy_bidirectional.rs +++ b/tests/io_copy_bidirectional.rs @@ -26,7 +26,7 @@ async fn block_write(s: &mut TcpStream) -> usize { result = s.write(&BUF) => { copied += result.expect("write error") }, - _ = tokio::time::sleep(Duration::from_millis(100)) => { + _ = tokio::time::sleep(Duration::from_millis(10)) => { break; } } @@ -42,7 +42,7 @@ where { // We run the test twice, with streams passed to copy_bidirectional in // different orders, in order to ensure that the two arguments are - // interchangable. + // interchangeable. let (a, mut a1) = make_socketpair().await; let (b, mut b1) = make_socketpair().await; diff --git a/tests/io_split.rs b/tests/io_split.rs index db168e9..a012166 100644 --- a/tests/io_split.rs +++ b/tests/io_split.rs @@ -50,10 +50,10 @@ fn is_send_and_sync() { fn split_stream_id() { let (r1, w1) = split(RW); let (r2, w2) = split(RW); - assert_eq!(r1.is_pair_of(&w1), true); - assert_eq!(r1.is_pair_of(&w2), false); - assert_eq!(r2.is_pair_of(&w2), true); - assert_eq!(r2.is_pair_of(&w1), false); + assert!(r1.is_pair_of(&w1)); + assert!(!r1.is_pair_of(&w2)); + assert!(r2.is_pair_of(&w2)); + assert!(!r2.is_pair_of(&w1)); } #[test] diff --git a/tests/io_write_all_buf.rs b/tests/io_write_all_buf.rs index b49a58e..7c8b619 100644 --- a/tests/io_write_all_buf.rs +++ b/tests/io_write_all_buf.rs @@ -52,7 +52,7 @@ async fn write_all_buf() { assert_eq!(wr.buf, b"helloworld"[..]); // expect 4 writes, [hell],[o],[worl],[d] assert_eq!(wr.cnt, 4); - assert_eq!(buf.has_remaining(), false); + assert!(!buf.has_remaining()); } #[tokio::test] diff --git a/tests/macros_select.rs b/tests/macros_select.rs index a089602..4da88fb 100644 --- a/tests/macros_select.rs +++ b/tests/macros_select.rs @@ -360,7 +360,21 @@ async fn use_future_in_if_condition() { use tokio::time::{self, Duration}; tokio::select! { - _ = time::sleep(Duration::from_millis(50)), if false => { + _ = time::sleep(Duration::from_millis(10)), if false => { + panic!("if condition ignored") + } + _ = async { 1u32 } => { + } + } +} + +#[tokio::test] +async fn use_future_in_if_condition_biased() { + use tokio::time::{self, Duration}; + + tokio::select! { + biased; + _ = time::sleep(Duration::from_millis(10)), if false => { panic!("if condition ignored") } _ = async { 1u32 } => { @@ -456,10 +470,7 @@ async fn require_mutable(_: &mut i32) {} async fn async_noop() {} async fn async_never() -> ! { - use tokio::time::Duration; - loop { - tokio::time::sleep(Duration::from_millis(10)).await; - } + futures::future::pending().await } // From https://github.com/tokio-rs/tokio/issues/2857 diff --git a/tests/named_pipe.rs b/tests/named_pipe.rs index 3f26767..2055c3c 100644 --- a/tests/named_pipe.rs +++ b/tests/named_pipe.rs @@ -126,7 +126,7 @@ async fn test_named_pipe_multi_client() -> io::Result<()> { } // Wait for a named pipe to become available. - time::sleep(Duration::from_millis(50)).await; + time::sleep(Duration::from_millis(10)).await; }; let mut client = BufReader::new(client); @@ -148,6 +148,185 @@ async fn test_named_pipe_multi_client() -> io::Result<()> { Ok(()) } +#[tokio::test] +async fn test_named_pipe_multi_client_ready() -> io::Result<()> { + use tokio::io::Interest; + + const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-multi-client-ready"; + const N: usize = 10; + + // The first server needs to be constructed early so that clients can + // be correctly connected. Otherwise calling .wait will cause the client to + // error. + let mut server = ServerOptions::new().create(PIPE_NAME)?; + + let server = tokio::spawn(async move { + for _ in 0..N { + // Wait for client to connect. + server.connect().await?; + + let inner_server = server; + + // Construct the next server to be connected before sending the one + // we already have of onto a task. This ensures that the server + // isn't closed (after it's done in the task) before a new one is + // available. Otherwise the client might error with + // `io::ErrorKind::NotFound`. + server = ServerOptions::new().create(PIPE_NAME)?; + + let _ = tokio::spawn(async move { + let server = inner_server; + + { + let mut read_buf = [0u8; 5]; + let mut read_buf_cursor = 0; + + loop { + server.readable().await?; + + let buf = &mut read_buf[read_buf_cursor..]; + + match server.try_read(buf) { + Ok(n) => { + read_buf_cursor += n; + + if read_buf_cursor == read_buf.len() { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + }; + + { + let write_buf = b"pong\n"; + let mut write_buf_cursor = 0; + + loop { + server.writable().await?; + let buf = &write_buf[write_buf_cursor..]; + + match server.try_write(buf) { + Ok(n) => { + write_buf_cursor += n; + + if write_buf_cursor == write_buf.len() { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + } + + Ok::<_, io::Error>(()) + }); + } + + Ok::<_, io::Error>(()) + }); + + let mut clients = Vec::new(); + + for _ in 0..N { + clients.push(tokio::spawn(async move { + // This showcases a generic connect loop. + // + // We immediately try to create a client, if it's not found or the + // pipe is busy we use the specialized wait function on the client + // builder. + let client = loop { + match ClientOptions::new().open(PIPE_NAME) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + Err(e) if e.kind() == io::ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + + // Wait for a named pipe to become available. + time::sleep(Duration::from_millis(10)).await; + }; + + let mut read_buf = [0u8; 5]; + let mut read_buf_cursor = 0; + let write_buf = b"ping\n"; + let mut write_buf_cursor = 0; + + loop { + let mut interest = Interest::READABLE; + if write_buf_cursor < write_buf.len() { + interest |= Interest::WRITABLE; + } + + let ready = client.ready(interest).await?; + + if ready.is_readable() { + let buf = &mut read_buf[read_buf_cursor..]; + + match client.try_read(buf) { + Ok(n) => { + read_buf_cursor += n; + + if read_buf_cursor == read_buf.len() { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + + if ready.is_writable() { + let buf = &write_buf[write_buf_cursor..]; + + if buf.is_empty() { + continue; + } + + match client.try_write(buf) { + Ok(n) => { + write_buf_cursor += n; + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + } + + let buf = String::from_utf8_lossy(&read_buf).into_owned(); + + Ok::<_, io::Error>(buf) + })); + } + + for client in clients { + let result = client.await?; + assert_eq!(result?, "pong\n"); + } + + server.await??; + Ok(()) +} + // This tests what happens when a client tries to disconnect. #[tokio::test] async fn test_named_pipe_mode_message() -> io::Result<()> { diff --git a/tests/no_rt.rs b/tests/no_rt.rs index 8437b80..6845850 100644 --- a/tests/no_rt.rs +++ b/tests/no_rt.rs @@ -26,7 +26,7 @@ fn panics_when_no_reactor() { async fn timeout_value() { let (_tx, rx) = oneshot::channel::<()>(); - let dur = Duration::from_millis(20); + let dur = Duration::from_millis(10); let _ = timeout(dur, rx).await; } diff --git a/tests/process_arg0.rs b/tests/process_arg0.rs new file mode 100644 index 0000000..4fabea0 --- /dev/null +++ b/tests/process_arg0.rs @@ -0,0 +1,13 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full", unix))] + +use tokio::process::Command; + +#[tokio::test] +async fn arg0() { + let mut cmd = Command::new("sh"); + cmd.arg0("test_string").arg("-c").arg("echo $0"); + + let output = cmd.output().await.unwrap(); + assert_eq!(output.stdout, b"test_string\n"); +} diff --git a/tests/process_raw_handle.rs b/tests/process_raw_handle.rs new file mode 100644 index 0000000..727e66d --- /dev/null +++ b/tests/process_raw_handle.rs @@ -0,0 +1,23 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] +#![cfg(windows)] + +use tokio::process::Command; +use winapi::um::processthreadsapi::GetProcessId; + +#[tokio::test] +async fn obtain_raw_handle() { + let mut cmd = Command::new("cmd"); + cmd.kill_on_drop(true); + cmd.arg("/c"); + cmd.arg("pause"); + + let child = cmd.spawn().unwrap(); + + let orig_id = child.id().expect("missing id"); + assert!(orig_id > 0); + + let handle = child.raw_handle().expect("process stopped"); + let handled_id = unsafe { GetProcessId(handle as _) }; + assert_eq!(handled_id, orig_id); +} diff --git a/tests/support/io_vec.rs b/tests/support/io_vec.rs new file mode 100644 index 0000000..4ea47c7 --- /dev/null +++ b/tests/support/io_vec.rs @@ -0,0 +1,45 @@ +use std::io::IoSlice; +use std::ops::Deref; +use std::slice; + +pub struct IoBufs<'a, 'b>(&'b mut [IoSlice<'a>]); + +impl<'a, 'b> IoBufs<'a, 'b> { + pub fn new(slices: &'b mut [IoSlice<'a>]) -> Self { + IoBufs(slices) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn advance(mut self, n: usize) -> IoBufs<'a, 'b> { + let mut to_remove = 0; + let mut remaining_len = n; + for slice in self.0.iter() { + if remaining_len < slice.len() { + break; + } else { + remaining_len -= slice.len(); + to_remove += 1; + } + } + self.0 = self.0.split_at_mut(to_remove).1; + if let Some(slice) = self.0.first_mut() { + let tail = &slice[remaining_len..]; + // Safety: recasts slice to the original lifetime + let tail = unsafe { slice::from_raw_parts(tail.as_ptr(), tail.len()) }; + *slice = IoSlice::new(tail); + } else if remaining_len != 0 { + panic!("advance past the end of the slice vector"); + } + self + } +} + +impl<'a, 'b> Deref for IoBufs<'a, 'b> { + type Target = [IoSlice<'a>]; + fn deref(&self) -> &[IoSlice<'a>] { + self.0 + } +} diff --git a/tests/support/mock_file.rs b/tests/support/mock_file.rs deleted file mode 100644 index 1ce326b..0000000 --- a/tests/support/mock_file.rs +++ /dev/null @@ -1,295 +0,0 @@ -#![allow(clippy::unnecessary_operation)] - -use std::collections::VecDeque; -use std::fmt; -use std::fs::{Metadata, Permissions}; -use std::io; -use std::io::prelude::*; -use std::io::SeekFrom; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; - -pub struct File { - shared: Arc<Mutex<Shared>>, -} - -pub struct Handle { - shared: Arc<Mutex<Shared>>, -} - -struct Shared { - calls: VecDeque<Call>, -} - -#[derive(Debug)] -enum Call { - Read(io::Result<Vec<u8>>), - Write(io::Result<Vec<u8>>), - Seek(SeekFrom, io::Result<u64>), - SyncAll(io::Result<()>), - SyncData(io::Result<()>), - SetLen(u64, io::Result<()>), -} - -impl Handle { - pub fn read(&self, data: &[u8]) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::Read(Ok(data.to_owned()))); - self - } - - pub fn read_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Read(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn write(&self, data: &[u8]) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::Write(Ok(data.to_owned()))); - self - } - - pub fn write_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Write(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn seek_start_ok(&self, offset: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Seek(SeekFrom::Start(offset), Ok(offset))); - self - } - - pub fn seek_current_ok(&self, offset: i64, ret: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Seek(SeekFrom::Current(offset), Ok(ret))); - self - } - - pub fn sync_all(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SyncAll(Ok(()))); - self - } - - pub fn sync_all_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SyncAll(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn sync_data(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SyncData(Ok(()))); - self - } - - pub fn sync_data_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SyncData(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn set_len(&self, size: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SetLen(size, Ok(()))); - self - } - - pub fn set_len_err(&self, size: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SetLen(size, Err(io::ErrorKind::Other.into()))); - self - } - - pub fn remaining(&self) -> usize { - let s = self.shared.lock().unwrap(); - s.calls.len() - } -} - -impl Drop for Handle { - fn drop(&mut self) { - if !std::thread::panicking() { - let s = self.shared.lock().unwrap(); - assert_eq!(0, s.calls.len()); - } - } -} - -impl File { - pub fn open(_: PathBuf) -> io::Result<File> { - unimplemented!(); - } - - pub fn create(_: PathBuf) -> io::Result<File> { - unimplemented!(); - } - - pub fn mock() -> (Handle, File) { - let shared = Arc::new(Mutex::new(Shared { - calls: VecDeque::new(), - })); - - let handle = Handle { - shared: shared.clone(), - }; - let file = File { shared }; - - (handle, file) - } - - pub fn sync_all(&self) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SyncAll(ret)) => ret, - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn sync_data(&self) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SyncData(ret)) => ret, - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn set_len(&self, size: u64) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SetLen(arg, ret)) => { - assert_eq!(arg, size); - ret - } - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn metadata(&self) -> io::Result<Metadata> { - unimplemented!(); - } - - pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()> { - unimplemented!(); - } - - pub fn try_clone(&self) -> io::Result<Self> { - unimplemented!(); - } -} - -impl Read for &'_ File { - fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Read(Ok(data))) => { - assert!(dst.len() >= data.len()); - assert!(dst.len() <= 16 * 1024, "actual = {}", dst.len()); // max buffer - - &mut dst[..data.len()].copy_from_slice(&data); - Ok(data.len()) - } - Some(Read(Err(e))) => Err(e), - Some(op) => panic!("expected next call to be {:?}; was a read", op), - None => panic!("did not expect call"), - } - } -} - -impl Write for &'_ File { - fn write(&mut self, src: &[u8]) -> io::Result<usize> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Write(Ok(data))) => { - assert_eq!(src, &data[..]); - Ok(src.len()) - } - Some(Write(Err(e))) => Err(e), - Some(op) => panic!("expected next call to be {:?}; was write", op), - None => panic!("did not expect call"), - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -impl Seek for &'_ File { - fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Seek(expect, res)) => { - assert_eq!(expect, pos); - res - } - Some(op) => panic!("expected call {:?}; was `seek`", op), - None => panic!("did not expect call; was `seek`"), - } - } -} - -impl fmt::Debug for File { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("mock::File").finish() - } -} - -#[cfg(unix)] -impl std::os::unix::io::AsRawFd for File { - fn as_raw_fd(&self) -> std::os::unix::io::RawFd { - unimplemented!(); - } -} - -#[cfg(unix)] -impl std::os::unix::io::FromRawFd for File { - unsafe fn from_raw_fd(_: std::os::unix::io::RawFd) -> Self { - unimplemented!(); - } -} - -#[cfg(windows)] -impl std::os::windows::io::AsRawHandle for File { - fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { - unimplemented!(); - } -} - -#[cfg(windows)] -impl std::os::windows::io::FromRawHandle for File { - unsafe fn from_raw_handle(_: std::os::windows::io::RawHandle) -> Self { - unimplemented!(); - } -} diff --git a/tests/support/mock_pool.rs b/tests/support/mock_pool.rs deleted file mode 100644 index e1fdb42..0000000 --- a/tests/support/mock_pool.rs +++ /dev/null @@ -1,66 +0,0 @@ -use tokio::sync::oneshot; - -use std::cell::RefCell; -use std::collections::VecDeque; -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -thread_local! { - static QUEUE: RefCell<VecDeque<Box<dyn FnOnce() + Send>>> = RefCell::new(VecDeque::new()) -} - -#[derive(Debug)] -pub(crate) struct Blocking<T> { - rx: oneshot::Receiver<T>, -} - -pub(crate) fn run<F, R>(f: F) -> Blocking<R> -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - let (tx, rx) = oneshot::channel(); - let task = Box::new(move || { - let _ = tx.send(f()); - }); - - QUEUE.with(|cell| cell.borrow_mut().push_back(task)); - - Blocking { rx } -} - -impl<T> Future for Blocking<T> { - type Output = Result<T, io::Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - use std::task::Poll::*; - - match Pin::new(&mut self.rx).poll(cx) { - Ready(Ok(v)) => Ready(Ok(v)), - Ready(Err(e)) => panic!("error = {:?}", e), - Pending => Pending, - } - } -} - -pub(crate) async fn asyncify<F, T>(f: F) -> io::Result<T> -where - F: FnOnce() -> io::Result<T> + Send + 'static, - T: Send + 'static, -{ - run(f).await? -} - -pub(crate) fn len() -> usize { - QUEUE.with(|cell| cell.borrow().len()) -} - -pub(crate) fn run_one() { - let task = QUEUE - .with(|cell| cell.borrow_mut().pop_front()) - .expect("expected task to run, but none ready"); - - task(); -} diff --git a/tests/sync_mutex.rs b/tests/sync_mutex.rs index 0ddb203..090db94 100644 --- a/tests/sync_mutex.rs +++ b/tests/sync_mutex.rs @@ -139,12 +139,12 @@ fn try_lock() { let m: Mutex<usize> = Mutex::new(0); { let g1 = m.try_lock(); - assert_eq!(g1.is_ok(), true); + assert!(g1.is_ok()); let g2 = m.try_lock(); - assert_eq!(g2.is_ok(), false); + assert!(!g2.is_ok()); } let g3 = m.try_lock(); - assert_eq!(g3.is_ok(), true); + assert!(g3.is_ok()); } #[tokio::test] diff --git a/tests/sync_mutex_owned.rs b/tests/sync_mutex_owned.rs index 0f1399c..898bf35 100644 --- a/tests/sync_mutex_owned.rs +++ b/tests/sync_mutex_owned.rs @@ -106,12 +106,12 @@ fn try_lock_owned() { let m: Arc<Mutex<usize>> = Arc::new(Mutex::new(0)); { let g1 = m.clone().try_lock_owned(); - assert_eq!(g1.is_ok(), true); + assert!(g1.is_ok()); let g2 = m.clone().try_lock_owned(); - assert_eq!(g2.is_ok(), false); + assert!(!g2.is_ok()); } let g3 = m.try_lock_owned(); - assert_eq!(g3.is_ok(), true); + assert!(g3.is_ok()); } #[tokio::test] diff --git a/tests/sync_once_cell.rs b/tests/sync_once_cell.rs index 60f50d2..18eaf93 100644 --- a/tests/sync_once_cell.rs +++ b/tests/sync_once_cell.rs @@ -266,3 +266,9 @@ fn drop_into_inner_new_with() { let count = NUM_DROPS.load(Ordering::Acquire); assert!(count == 1); } + +#[test] +fn from() { + let cell = OnceCell::from(2); + assert_eq!(*cell.get().unwrap(), 2); +} diff --git a/tests/sync_rwlock.rs b/tests/sync_rwlock.rs index e12052b..7d05086 100644 --- a/tests/sync_rwlock.rs +++ b/tests/sync_rwlock.rs @@ -50,8 +50,8 @@ fn read_exclusive_pending() { assert_pending!(t2.poll()); } -// If the max shared access is reached and subsquent shared access is pending -// should be made available when one of the shared acesses is dropped +// If the max shared access is reached and subsequent shared access is pending +// should be made available when one of the shared accesses is dropped #[test] fn exhaust_reading() { let rwlock = RwLock::with_max_readers(100, 1024); diff --git a/tests/sync_watch.rs b/tests/sync_watch.rs index 9dcb0c5..a2a276d 100644 --- a/tests/sync_watch.rs +++ b/tests/sync_watch.rs @@ -169,3 +169,20 @@ fn poll_close() { assert!(tx.send("two").is_err()); } + +#[test] +fn borrow_and_update() { + let (tx, mut rx) = watch::channel("one"); + + tx.send("two").unwrap(); + assert_ready!(spawn(rx.changed()).poll()).unwrap(); + assert_pending!(spawn(rx.changed()).poll()); + + tx.send("three").unwrap(); + assert_eq!(*rx.borrow_and_update(), "three"); + assert_pending!(spawn(rx.changed()).poll()); + + drop(tx); + assert_eq!(*rx.borrow_and_update(), "three"); + assert_ready!(spawn(rx.changed()).poll()).unwrap_err(); +} diff --git a/tests/task_abort.rs b/tests/task_abort.rs index 1d72ac3..06c61dc 100644 --- a/tests/task_abort.rs +++ b/tests/task_abort.rs @@ -1,11 +1,25 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use std::sync::Arc; +use std::thread::sleep; +use tokio::time::Duration; + +use tokio::runtime::Builder; + +struct PanicOnDrop; + +impl Drop for PanicOnDrop { + fn drop(&mut self) { + panic!("Well what did you expect would happen..."); + } +} + /// Checks that a suspended task can be aborted without panicking as reported in /// issue #3157: <https://github.com/tokio-rs/tokio/issues/3157>. #[test] fn test_abort_without_panic_3157() { - let rt = tokio::runtime::Builder::new_multi_thread() + let rt = Builder::new_multi_thread() .enable_time() .worker_threads(1) .build() @@ -14,11 +28,11 @@ fn test_abort_without_panic_3157() { rt.block_on(async move { let handle = tokio::spawn(async move { println!("task started"); - tokio::time::sleep(std::time::Duration::new(100, 0)).await + tokio::time::sleep(Duration::new(100, 0)).await }); // wait for task to sleep. - tokio::time::sleep(std::time::Duration::new(1, 0)).await; + tokio::time::sleep(Duration::from_millis(10)).await; handle.abort(); let _ = handle.await; @@ -41,9 +55,7 @@ fn test_abort_without_panic_3662() { } } - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .unwrap(); + let rt = Builder::new_current_thread().build().unwrap(); rt.block_on(async move { let drop_flag = Arc::new(AtomicBool::new(false)); @@ -62,18 +74,16 @@ fn test_abort_without_panic_3662() { // This runs in a separate thread so it doesn't have immediate // thread-local access to the executor. It does however transition // the underlying task to be completed, which will cause it to be - // dropped (in this thread no less). + // dropped (but not in this thread). assert!(!drop_flag2.load(Ordering::SeqCst)); j.abort(); - // TODO: is this guaranteed at this point? - // assert!(drop_flag2.load(Ordering::SeqCst)); j }) .join() .unwrap(); - assert!(drop_flag.load(Ordering::SeqCst)); let result = task.await; + assert!(drop_flag.load(Ordering::SeqCst)); assert!(result.unwrap_err().is_cancelled()); // Note: We do the following to trigger a deferred task cleanup. @@ -82,7 +92,7 @@ fn test_abort_without_panic_3662() { // `Inner::block_on` of `basic_scheduler.rs`. // // We cause the cleanup to happen by having a poll return Pending once - // so that the scheduler can go into the "auxilliary tasks" mode, at + // so that the scheduler can go into the "auxiliary tasks" mode, at // which point the task is removed from the scheduler. let i = tokio::spawn(async move { tokio::task::yield_now().await; @@ -91,3 +101,126 @@ fn test_abort_without_panic_3662() { i.await.unwrap(); }); } + +/// Checks that a suspended LocalSet task can be aborted from a remote thread +/// without panicking and without running the tasks destructor on the wrong thread. +/// <https://github.com/tokio-rs/tokio/issues/3929> +#[test] +fn remote_abort_local_set_3929() { + struct DropCheck { + created_on: std::thread::ThreadId, + not_send: std::marker::PhantomData<*const ()>, + } + + impl DropCheck { + fn new() -> Self { + Self { + created_on: std::thread::current().id(), + not_send: std::marker::PhantomData, + } + } + } + impl Drop for DropCheck { + fn drop(&mut self) { + if std::thread::current().id() != self.created_on { + panic!("non-Send value dropped in another thread!"); + } + } + } + + let rt = Builder::new_current_thread().build().unwrap(); + let local = tokio::task::LocalSet::new(); + + let check = DropCheck::new(); + let jh = local.spawn_local(async move { + futures::future::pending::<()>().await; + drop(check); + }); + + let jh2 = std::thread::spawn(move || { + sleep(Duration::from_millis(10)); + jh.abort(); + }); + + rt.block_on(local); + jh2.join().unwrap(); +} + +/// Checks that a suspended task can be aborted even if the `JoinHandle` is immediately dropped. +/// issue #3964: <https://github.com/tokio-rs/tokio/issues/3964>. +#[test] +fn test_abort_wakes_task_3964() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let notify_dropped = Arc::new(()); + let weak_notify_dropped = Arc::downgrade(¬ify_dropped); + + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _notify_dropped = notify_dropped; + println!("task started"); + tokio::time::sleep(Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(Duration::from_millis(10)).await; + + handle.abort(); + drop(handle); + + // wait for task to abort. + tokio::time::sleep(Duration::from_millis(10)).await; + + // Check that the Arc has been dropped. + assert!(weak_notify_dropped.upgrade().is_none()); + }); +} + +/// Checks that aborting a task whose destructor panics does not allow the +/// panic to escape the task. +#[test] +#[cfg(not(target_os = "android"))] +fn test_abort_task_that_panics_on_drop_contained() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _panic_dropped = PanicOnDrop; + println!("task started"); + tokio::time::sleep(Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(Duration::from_millis(10)).await; + + handle.abort(); + drop(handle); + + // wait for task to abort. + tokio::time::sleep(Duration::from_millis(10)).await; + }); +} + +/// Checks that aborting a task whose destructor panics has the expected result. +#[test] +#[cfg(not(target_os = "android"))] +fn test_abort_task_that_panics_on_drop_returned() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _panic_dropped = PanicOnDrop; + println!("task started"); + tokio::time::sleep(Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(Duration::from_millis(10)).await; + + handle.abort(); + assert!(handle.await.unwrap_err().is_panic()); + }); +} diff --git a/tests/task_builder.rs b/tests/task_builder.rs new file mode 100644 index 0000000..1499abf --- /dev/null +++ b/tests/task_builder.rs @@ -0,0 +1,67 @@ +#[cfg(all(tokio_unstable, feature = "tracing"))] +mod tests { + use std::rc::Rc; + use tokio::{ + task::{Builder, LocalSet}, + test, + }; + + #[test] + async fn spawn_with_name() { + let result = Builder::new() + .name("name") + .spawn(async { "task executed" }) + .await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_blocking_with_name() { + let result = Builder::new() + .name("name") + .spawn_blocking(|| "task executed") + .await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_local_with_name() { + let unsend_data = Rc::new("task executed"); + let result = LocalSet::new() + .run_until(async move { + Builder::new() + .name("name") + .spawn_local(async move { unsend_data }) + .await + }) + .await; + + assert_eq!(*result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_without_name() { + let result = Builder::new().spawn(async { "task executed" }).await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_blocking_without_name() { + let result = Builder::new().spawn_blocking(|| "task executed").await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_local_without_name() { + let unsend_data = Rc::new("task executed"); + let result = LocalSet::new() + .run_until(async move { Builder::new().spawn_local(async move { unsend_data }).await }) + .await; + + assert_eq!(*result.unwrap(), "task executed"); + } +} diff --git a/tests/task_local_set.rs b/tests/task_local_set.rs index 8513609..f8a35d0 100644 --- a/tests/task_local_set.rs +++ b/tests/task_local_set.rs @@ -67,11 +67,11 @@ async fn localset_future_timers() { let local = LocalSet::new(); local.spawn_local(async move { - time::sleep(Duration::from_millis(10)).await; + time::sleep(Duration::from_millis(5)).await; RAN1.store(true, Ordering::SeqCst); }); local.spawn_local(async move { - time::sleep(Duration::from_millis(20)).await; + time::sleep(Duration::from_millis(10)).await; RAN2.store(true, Ordering::SeqCst); }); local.await; @@ -299,9 +299,7 @@ fn drop_cancels_tasks() { let _rc2 = rc2; started_tx.send(()).unwrap(); - loop { - time::sleep(Duration::from_secs(3600)).await; - } + futures::future::pending::<()>().await; }); local.block_on(&rt, async { @@ -334,7 +332,7 @@ fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) { // something we can easily make assertions about, we'll run it in a // thread. When the test thread finishes, it will send a message on a // channel to this thread. We'll wait for that message with a fairly - // generous timeout, and if we don't recieve it, we assume the test + // generous timeout, and if we don't receive it, we assume the test // thread has hung. // // Note that it should definitely complete in under a minute, but just @@ -400,13 +398,32 @@ fn local_tasks_wake_join_all() { }); } -#[tokio::test] -async fn local_tasks_are_polled_after_tick() { +#[test] +fn local_tasks_are_polled_after_tick() { + // This test depends on timing, so we run it up to five times. + for _ in 0..4 { + let res = std::panic::catch_unwind(local_tasks_are_polled_after_tick_inner); + if res.is_ok() { + // success + return; + } + } + + // Test failed 4 times. Try one more time without catching panics. If it + // fails again, the test fails. + local_tasks_are_polled_after_tick_inner(); +} + +#[tokio::main(flavor = "current_thread")] +async fn local_tasks_are_polled_after_tick_inner() { // Reproduces issues #1899 and #1900 static RX1: AtomicUsize = AtomicUsize::new(0); static RX2: AtomicUsize = AtomicUsize::new(0); - static EXPECTED: usize = 500; + const EXPECTED: usize = 500; + + RX1.store(0, SeqCst); + RX2.store(0, SeqCst); let (tx, mut rx) = mpsc::unbounded_channel(); @@ -416,7 +433,7 @@ async fn local_tasks_are_polled_after_tick() { .run_until(async { let task2 = task::spawn(async move { // Wait a bit - time::sleep(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(10)).await; let mut oneshots = Vec::with_capacity(EXPECTED); @@ -427,13 +444,13 @@ async fn local_tasks_are_polled_after_tick() { tx.send(oneshot_rx).unwrap(); } - time::sleep(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(10)).await; for tx in oneshots.drain(..) { tx.send(()).unwrap(); } - time::sleep(Duration::from_millis(300)).await; + time::sleep(Duration::from_millis(20)).await; let rx1 = RX1.load(SeqCst); let rx2 = RX2.load(SeqCst); println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2); diff --git a/tests/tcp_into_split.rs b/tests/tcp_into_split.rs index b4bb2ee..2e06643 100644 --- a/tests/tcp_into_split.rs +++ b/tests/tcp_into_split.rs @@ -116,7 +116,7 @@ async fn drop_write() -> Result<()> { // drop it while the read is in progress std::thread::spawn(move || { - thread::sleep(std::time::Duration::from_millis(50)); + thread::sleep(std::time::Duration::from_millis(10)); drop(write_half); }); diff --git a/tests/time_interval.rs b/tests/time_interval.rs index a3c7f08..5f7bf55 100644 --- a/tests/time_interval.rs +++ b/tests/time_interval.rs @@ -1,56 +1,173 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::time::{self, Duration, Instant}; +use tokio::time::{self, Duration, Instant, MissedTickBehavior}; use tokio_test::{assert_pending, assert_ready_eq, task}; -use std::future::Future; use std::task::Poll; +// Takes the `Interval` task, `start` variable, and optional time deltas +// For each time delta, it polls the `Interval` and asserts that the result is +// equal to `start` + the specific time delta. Then it asserts that the +// `Interval` is pending. +macro_rules! check_interval_poll { + ($i:ident, $start:ident, $($delta:expr),*$(,)?) => { + $( + assert_ready_eq!(poll_next(&mut $i), $start + ms($delta)); + )* + assert_pending!(poll_next(&mut $i)); + }; + ($i:ident, $start:ident) => { + check_interval_poll!($i, $start,); + }; +} + #[tokio::test] #[should_panic] async fn interval_zero_duration() { let _ = time::interval_at(Instant::now(), ms(0)); } -#[tokio::test] -async fn usage() { - time::pause(); +// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | +// Actual ticks: | work -----| delay | work | work | work -| work -----| +// Poll behavior: | | | | | | | | +// | | | | | | | | +// Ready(s) | | Ready(s + 2p) | | | | +// Pending | Ready(s + 3p) | | | +// Ready(s + p) Ready(s + 4p) | | +// Ready(s + 5p) | +// Ready(s + 6p) +#[tokio::test(start_paused = true)] +async fn burst() { + let start = Instant::now(); + + // This is necessary because the timer is only so granular, and in order for + // all our ticks to resolve, the time needs to be 1ms ahead of what we + // expect, so that the runtime will see that it is time to resolve the timer + time::advance(ms(1)).await; + + let mut i = task::spawn(time::interval_at(start, ms(300))); + + check_interval_poll!(i, start, 0); + + time::advance(ms(100)).await; + check_interval_poll!(i, start); + + time::advance(ms(200)).await; + check_interval_poll!(i, start, 300); + + time::advance(ms(650)).await; + check_interval_poll!(i, start, 600, 900); + + time::advance(ms(200)).await; + check_interval_poll!(i, start); + + time::advance(ms(100)).await; + check_interval_poll!(i, start, 1200); + + time::advance(ms(250)).await; + check_interval_poll!(i, start, 1500); + + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1800); +} +// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | +// Actual ticks: | work -----| delay | work -----| work -----| work -----| +// Poll behavior: | | | | | | | | +// | | | | | | | | +// Ready(s) | | Ready(s + 2p) | | | | +// Pending | Pending | | | +// Ready(s + p) Ready(s + 2p + d) | | +// Ready(s + 3p + d) | +// Ready(s + 4p + d) +#[tokio::test(start_paused = true)] +async fn delay() { let start = Instant::now(); - // TODO: Skip this + // This is necessary because the timer is only so granular, and in order for + // all our ticks to resolve, the time needs to be 1ms ahead of what we + // expect, so that the runtime will see that it is time to resolve the timer time::advance(ms(1)).await; let mut i = task::spawn(time::interval_at(start, ms(300))); + i.set_missed_tick_behavior(MissedTickBehavior::Delay); - assert_ready_eq!(poll_next(&mut i), start); - assert_pending!(poll_next(&mut i)); + check_interval_poll!(i, start, 0); time::advance(ms(100)).await; - assert_pending!(poll_next(&mut i)); + check_interval_poll!(i, start); time::advance(ms(200)).await; - assert_ready_eq!(poll_next(&mut i), start + ms(300)); - assert_pending!(poll_next(&mut i)); + check_interval_poll!(i, start, 300); + + time::advance(ms(650)).await; + check_interval_poll!(i, start, 600); + + time::advance(ms(100)).await; + check_interval_poll!(i, start); + + // We have to add one here for the same reason as is above. + // Because `Interval` has reset its timer according to `Instant::now()`, + // we have to go forward 1 more millisecond than is expected so that the + // runtime realizes that it's time to resolve the timer. + time::advance(ms(201)).await; + // We add one because when using the `Delay` behavior, `Interval` + // adds the `period` from `Instant::now()`, which will always be off by one + // because we have to advance time by 1 (see above). + check_interval_poll!(i, start, 1251); + + time::advance(ms(300)).await; + // Again, we add one. + check_interval_poll!(i, start, 1551); + + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1851); +} + +// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | +// Actual ticks: | work -----| delay | work ---| work -----| work -----| +// Poll behavior: | | | | | | | +// | | | | | | | +// Ready(s) | | Ready(s + 2p) | | | +// Pending | Ready(s + 4p) | | +// Ready(s + p) Ready(s + 5p) | +// Ready(s + 6p) +#[tokio::test(start_paused = true)] +async fn skip() { + let start = Instant::now(); + + // This is necessary because the timer is only so granular, and in order for + // all our ticks to resolve, the time needs to be 1ms ahead of what we + // expect, so that the runtime will see that it is time to resolve the timer + time::advance(ms(1)).await; + + let mut i = task::spawn(time::interval_at(start, ms(300))); + i.set_missed_tick_behavior(MissedTickBehavior::Skip); + + check_interval_poll!(i, start, 0); + + time::advance(ms(100)).await; + check_interval_poll!(i, start); + + time::advance(ms(200)).await; + check_interval_poll!(i, start, 300); + + time::advance(ms(650)).await; + check_interval_poll!(i, start, 600); + + time::advance(ms(250)).await; + check_interval_poll!(i, start, 1200); - time::advance(ms(400)).await; - assert_ready_eq!(poll_next(&mut i), start + ms(600)); - assert_pending!(poll_next(&mut i)); + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1500); - time::advance(ms(500)).await; - assert_ready_eq!(poll_next(&mut i), start + ms(900)); - assert_ready_eq!(poll_next(&mut i), start + ms(1200)); - assert_pending!(poll_next(&mut i)); + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1800); } fn poll_next(interval: &mut task::Spawn<time::Interval>) -> Poll<Instant> { - interval.enter(|cx, mut interval| { - tokio::pin! { - let fut = interval.tick(); - } - fut.poll(cx) - }) + interval.enter(|cx, mut interval| interval.poll_tick(cx)) } fn ms(n: u64) -> Duration { diff --git a/tests/time_rt.rs b/tests/time_rt.rs index 0775343..23367be 100644 --- a/tests/time_rt.rs +++ b/tests/time_rt.rs @@ -13,7 +13,7 @@ fn timer_with_threaded_runtime() { let (tx, rx) = mpsc::channel(); rt.spawn(async move { - let when = Instant::now() + Duration::from_millis(100); + let when = Instant::now() + Duration::from_millis(10); sleep_until(when).await; assert!(Instant::now() >= when); @@ -32,7 +32,7 @@ fn timer_with_basic_scheduler() { let (tx, rx) = mpsc::channel(); rt.block_on(async move { - let when = Instant::now() + Duration::from_millis(100); + let when = Instant::now() + Duration::from_millis(10); sleep_until(when).await; assert!(Instant::now() >= when); @@ -67,7 +67,7 @@ async fn starving() { } } - let when = Instant::now() + Duration::from_millis(20); + let when = Instant::now() + Duration::from_millis(10); let starve = Starve(Box::pin(sleep_until(when)), 0); starve.await; @@ -81,7 +81,7 @@ async fn timeout_value() { let (_tx, rx) = oneshot::channel::<()>(); let now = Instant::now(); - let dur = Duration::from_millis(20); + let dur = Duration::from_millis(10); let res = timeout(dur, rx).await; assert!(res.is_err()); diff --git a/tests/time_sleep.rs b/tests/time_sleep.rs index 9c04d22..20477d2 100644 --- a/tests/time_sleep.rs +++ b/tests/time_sleep.rs @@ -24,7 +24,7 @@ async fn immediate_sleep() { async fn is_elapsed() { time::pause(); - let sleep = time::sleep(Duration::from_millis(50)); + let sleep = time::sleep(Duration::from_millis(10)); tokio::pin!(sleep); @@ -349,7 +349,7 @@ async fn drop_from_wake() { assert!( !panicked.load(Ordering::SeqCst), - "paniced when dropping timers" + "panicked when dropping timers" ); #[derive(Clone)] diff --git a/tests/uds_datagram.rs b/tests/uds_datagram.rs index 10314be..4d28468 100644 --- a/tests/uds_datagram.rs +++ b/tests/uds_datagram.rs @@ -87,9 +87,12 @@ async fn try_send_recv_never_block() -> io::Result<()> { dgram1.writable().await.unwrap(); match dgram1.try_send(payload) { - Err(err) => match err.kind() { - io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, - _ => unreachable!("unexpected error {:?}", err), + Err(err) => match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::WouldBlock, _) => break, + (_, Some(libc::ENOBUFS)) => break, + _ => { + panic!("unexpected error {:?}", err); + } }, Ok(len) => { assert_eq!(len, payload.len()); @@ -291,9 +294,12 @@ async fn try_recv_buf_never_block() -> io::Result<()> { dgram1.writable().await.unwrap(); match dgram1.try_send(payload) { - Err(err) => match err.kind() { - io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, - _ => unreachable!("unexpected error {:?}", err), + Err(err) => match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::WouldBlock, _) => break, + (_, Some(libc::ENOBUFS)) => break, + _ => { + panic!("unexpected error {:?}", err); + } }, Ok(len) => { assert_eq!(len, payload.len()); diff --git a/tests/uds_stream.rs b/tests/uds_stream.rs index 2754e84..5f1b4cf 100644 --- a/tests/uds_stream.rs +++ b/tests/uds_stream.rs @@ -379,3 +379,33 @@ async fn try_read_buf() -> std::io::Result<()> { Ok(()) } + +// https://github.com/tokio-rs/tokio/issues/3879 +#[tokio::test] +#[cfg(not(target_os = "macos"))] +async fn epollhup() -> io::Result<()> { + let dir = tempfile::Builder::new() + .prefix("tokio-uds-tests") + .tempdir() + .unwrap(); + let sock_path = dir.path().join("connect.sock"); + + let listener = UnixListener::bind(&sock_path)?; + let connect = UnixStream::connect(&sock_path); + tokio::pin!(connect); + + // Poll `connect` once. + poll_fn(|cx| { + use std::future::Future; + + assert_pending!(connect.as_mut().poll(cx)); + Poll::Ready(()) + }) + .await; + + drop(listener); + + let err = connect.await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::ConnectionReset); + Ok(()) +} |