diff options
author | Haibo Huang <hhb@google.com> | 2021-01-14 17:23:22 -0800 |
---|---|---|
committer | Jeff Vander Stoep <jeffv@google.com> | 2021-01-15 20:44:08 +0100 |
commit | 290fc4903cd00fc31d93e0ecd49c402e6833c569 (patch) | |
tree | 4a9646d2ab712bae1ead875992160c7248588daf | |
parent | 84cad6596f48e471881980dcba7df9cb5b4b0139 (diff) | |
download | tokio-290fc4903cd00fc31d93e0ecd49c402e6833c569.tar.gz |
Upgrade rust/crates/tokio to 1.0.2platform-tools-31.0.0
Test: make
Change-Id: Ic48ff709bade266749eac8c146856901ce78da7f
168 files changed, 4906 insertions, 6476 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index 8445b12..0ccdcd3 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,5 +1,5 @@ { "git": { - "sha1": "479c545c20b2cb44a8f09600733adc8c8dcb5aa0" + "sha1": "5d35c907f693e25ba20c3cfb47e0cb1957679019" } } @@ -1,4 +1,4 @@ -// This file is generated by cargo2android.py --device --run --dependencies --features io-util,macros,rt-multi-thread,sync,stream,net,fs,time. +// This file is generated by cargo2android.py --device --run --dependencies --features io-util,macros,rt-multi-thread,sync,net,fs,time. rust_library { name: "libtokio", @@ -9,9 +9,7 @@ rust_library { features: [ "bytes", "fs", - "futures-core", "io-util", - "lazy_static", "libc", "macros", "memchr", @@ -20,8 +18,6 @@ rust_library { "num_cpus", "rt", "rt-multi-thread", - "slab", - "stream", "sync", "time", "tokio-macros", @@ -31,33 +27,27 @@ rust_library { ], rustlibs: [ "libbytes", - "libfutures_core", - "liblazy_static", "liblibc", "libmemchr", "libmio", "libnum_cpus", "libpin_project_lite", - "libslab", ], proc_macros: ["libtokio_macros"], } // dependent_library ["feature_list"] // autocfg-1.0.1 -// bytes-0.6.0 "default,std" +// bytes-1.0.1 "default,std" // cfg-if-0.1.10 -// futures-core-0.3.8 "alloc,default,std" -// lazy_static-1.4.0 -// libc-0.2.81 "align,default,extra_traits,std" -// log-0.4.11 +// libc-0.2.82 "align,default,extra_traits,std" +// log-0.4.13 // memchr-2.3.4 "default,std" -// mio-0.7.6 "default,net,os-ext,os-poll,os-util,tcp,udp,uds" +// mio-0.7.7 "default,net,os-ext,os-poll,os-util,tcp,udp,uds" // num_cpus-1.13.0 -// pin-project-lite-0.2.0 +// pin-project-lite-0.2.4 // proc-macro2-1.0.24 "default,proc-macro" // quote-1.0.8 "default,proc-macro" -// slab-0.4.2 -// syn-1.0.55 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit,visit-mut" -// tokio-macros-0.3.2 +// syn-1.0.58 "clone-impls,default,derive,extra-traits,full,parsing,printing,proc-macro,quote,visit,visit-mut" +// tokio-macros-1.0.0 // unicode-xid-0.2.1 "default" diff --git a/CHANGELOG.md b/CHANGELOG.md index 994dda4..a36212d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,90 @@ +# 1.0.2 (January 14, 2020) + +### Fixed +- io: soundness in `read_to_end` (#3428). + +# 1.0.1 (December 25, 2020) + +This release fixes a soundness hole caused by the combination of `RwLockWriteGuard::map` +and `RwLockWriteGuard::downgrade` by removing the `map` function. This is a breaking +change, but breaking changes are allowed under our semver policy when they are required +to fix a soundness hole. (See [this RFC][semver] for more.) + +Note that we have chosen not to do a deprecation cycle or similar because Tokio 1.0.0 was +released two days ago, and therefore the impact should be minimal. + +Due to the soundness hole, we have also yanked Tokio version 1.0.0. + +### Removed + + - sync: remove `RwLockWriteGuard::map` and `RwLockWriteGuard::try_map` (#3345) + +### Fixed + + - docs: remove stream feature from docs (#3335) + +[semver]: https://github.com/rust-lang/rfcs/blob/master/text/1122-language-semver.md#soundness-changes + +# 1.0.0 (December 23, 2020) + +Commit to the API and long-term support. + +### Fixed +- sync: spurious wakeup in `watch` (#3234). + +### Changed +- io: rename `AsyncFd::with_io()` to `try_io()` (#3306) +- fs: avoid OS specific `*Ext` traits in favor of conditionally defining the fn (#3264). +- fs: `Sleep` is `!Unpin` (#3278). +- net: pass `SocketAddr` by value (#3125). +- net: `TcpStream::poll_peek` takes `ReadBuf` (#3259). +- rt: rename `runtime::Builder::max_threads()` to `max_blocking_threads()` (#3287). +- time: require `current_thread` runtime when calling `time::pause()` (#3289). + +### Removed +- remove `tokio::prelude` (#3299). +- io: remove `AsyncFd::with_poll()` (#3306). +- net: remove `{Tcp,Unix}Stream::shutdown()` in favor of `AsyncWrite::shutdown()` (#3298). +- stream: move all stream utilities to `tokio-stream` until `Stream` is added to + `std` (#3277). +- sync: mpsc `try_recv()` due to unexpected behavior (#3263). +- tracing: make unstable as `tracing-core` is not 1.0 yet (#3266). + +### Added +- fs: `poll_*` fns to `DirEntry` (#3308). +- io: `poll_*` fns to `io::Lines`, `io::Split` (#3308). +- io: `_mut` method variants to `AsyncFd` (#3304). +- net: `poll_*` fns to `UnixDatagram` (#3223). +- net: `UnixStream` readiness and non-blocking ops (#3246). +- sync: `UnboundedReceiver::blocking_recv()` (#3262). +- sync: `watch::Sender::borrow()` (#3269). +- sync: `Semaphore::close()` (#3065). +- sync: `poll_recv` fns to `mpsc::Receiver`, `mpsc::UnboundedReceiver` (#3308). +- time: `poll_tick` fn to `time::Interval` (#3316). + +# 0.3.6 (December 14, 2020) + +### Fixed +- rt: fix deadlock in shutdown (#3228) +- rt: fix panic in task abort when off rt (#3159) +- sync: make `add_permits` panic with usize::MAX >> 3 permits (#3188) +- time: Fix race condition in timer drop (#3229) +- watch: fix spurious wakeup (#3244) + +### Added +- example: add back udp-codec example (#3205) +- net: add `TcpStream::into_std` (#3189) + +# 0.3.5 (November 30, 2020) + +### Fixed +- rt: fix `shutdown_timeout(0)` (#3196). +- time: fixed race condition with small sleeps (#3069). + +### Added +- io: `AsyncFd::with_interest()` (#3167). +- signal: `CtrlC` stream on windows (#3186). + # 0.3.4 (November 18, 2020) ### Fixed @@ -13,11 +13,11 @@ [package] edition = "2018" name = "tokio" -version = "0.3.4" +version = "1.0.2" authors = ["Tokio Contributors <team@tokio.rs>"] description = "An event-driven, non-blocking I/O platform for writing asynchronous I/O\nbacked applications.\n" homepage = "https://tokio.rs" -documentation = "https://docs.rs/tokio/0.3.4/tokio/" +documentation = "https://docs.rs/tokio/1.0.2/tokio/" readme = "README.md" keywords = ["io", "async", "non-blocking", "futures"] categories = ["asynchronous", "network-programming"] @@ -28,17 +28,9 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [package.metadata.playground] -features = ["full"] +features = ["full", "test-util"] [dependencies.bytes] -version = "0.6.0" -optional = true - -[dependencies.futures-core] -version = "0.3.0" -optional = true - -[dependencies.lazy_static] -version = "1.4.0" +version = "1.0.0" optional = true [dependencies.memchr] @@ -53,6 +45,10 @@ optional = true version = "1.8.0" optional = true +[dependencies.once_cell] +version = "1.5.2" +optional = true + [dependencies.parking_lot] version = "0.11.0" optional = true @@ -60,19 +56,12 @@ optional = true [dependencies.pin-project-lite] version = "0.2.0" -[dependencies.slab] -version = "0.4.1" -optional = true - [dependencies.tokio-macros] -version = "0.3.0" +version = "1.0.0" optional = true +[dev-dependencies.async-stream] +version = "0.3" -[dependencies.tracing] -version = "0.1.21" -features = ["std"] -optional = true -default-features = false [dev-dependencies.futures] version = "0.3.0" features = ["async-await"] @@ -83,30 +72,37 @@ version = "0.10.0" [dev-dependencies.tempfile] version = "3.1.0" +[dev-dependencies.tokio-stream] +version = "0.1" + [dev-dependencies.tokio-test] -version = "0.3.0" +version = "0.4.0" [build-dependencies.autocfg] version = "1" [features] default = [] fs = [] -full = ["fs", "io-util", "io-std", "macros", "net", "parking_lot", "process", "rt", "rt-multi-thread", "signal", "stream", "sync", "time"] +full = ["fs", "io-util", "io-std", "macros", "net", "parking_lot", "process", "rt", "rt-multi-thread", "signal", "sync", "time"] io-std = [] io-util = ["memchr", "bytes"] macros = ["tokio-macros"] -net = ["lazy_static", "libc", "mio/os-poll", "mio/os-util", "mio/tcp", "mio/udp", "mio/uds"] -process = ["bytes", "lazy_static", "libc", "mio/os-poll", "mio/os-util", "mio/uds", "signal-hook-registry", "winapi/threadpoollegacyapiset"] -rt = ["slab"] +net = ["libc", "mio/os-poll", "mio/os-util", "mio/tcp", "mio/udp", "mio/uds"] +process = ["bytes", "once_cell", "libc", "mio/os-poll", "mio/os-util", "mio/uds", "signal-hook-registry", "winapi/threadpoollegacyapiset"] +rt = [] rt-multi-thread = ["num_cpus", "rt"] -signal = ["lazy_static", "libc", "mio/os-poll", "mio/uds", "mio/os-util", "signal-hook-registry", "winapi/consoleapi"] -stream = ["futures-core"] +signal = ["once_cell", "libc", "mio/os-poll", "mio/uds", "mio/os-util", "signal-hook-registry", "winapi/consoleapi"] sync = [] test-util = [] time = [] [target."cfg(loom)".dev-dependencies.loom] -version = "0.3.5" +version = "0.4" features = ["futures", "checkpoint"] +[target."cfg(tokio_unstable)".dependencies.tracing] +version = "0.1.21" +features = ["std"] +optional = true +default-features = false [target."cfg(unix)".dependencies.libc] version = "0.2.42" optional = true diff --git a/Cargo.toml.orig b/Cargo.toml.orig index c658fc4..f950a28 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -7,13 +7,13 @@ name = "tokio" # - Cargo.toml # - README.md # - Update CHANGELOG.md. -# - Create "v0.3.x" git tag. -version = "0.3.4" +# - Create "v1.0.x" git tag. +version = "1.0.2" edition = "2018" authors = ["Tokio Contributors <team@tokio.rs>"] license = "MIT" readme = "README.md" -documentation = "https://docs.rs/tokio/0.3.4/tokio/" +documentation = "https://docs.rs/tokio/1.0.2/tokio/" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" description = """ @@ -39,7 +39,6 @@ full = [ "rt", "rt-multi-thread", "signal", - "stream", "sync", "time", ] @@ -50,7 +49,6 @@ io-util = ["memchr", "bytes"] io-std = [] macros = ["tokio-macros"] net = [ - "lazy_static", "libc", "mio/os-poll", "mio/os-util", @@ -60,7 +58,7 @@ net = [ ] process = [ "bytes", - "lazy_static", + "once_cell", "libc", "mio/os-poll", "mio/os-util", @@ -69,13 +67,13 @@ process = [ "winapi/threadpoollegacyapiset", ] # Includes basic task execution capabilities -rt = ["slab"] +rt = [] rt-multi-thread = [ "num_cpus", "rt", ] signal = [ - "lazy_static", + "once_cell", "libc", "mio/os-poll", "mio/uds", @@ -83,25 +81,26 @@ signal = [ "signal-hook-registry", "winapi/consoleapi", ] -stream = ["futures-core"] sync = [] test-util = [] time = [] [dependencies] -tokio-macros = { version = "0.3.0", path = "../tokio-macros", optional = true } +tokio-macros = { version = "1.0.0", path = "../tokio-macros", optional = true } pin-project-lite = "0.2.0" # Everything else is optional... -bytes = { version = "0.6.0", optional = true } -futures-core = { version = "0.3.0", optional = true } -lazy_static = { version = "1.4.0", optional = true } +bytes = { version = "1.0.0", optional = true } +once_cell = { version = "1.5.2", optional = true } memchr = { version = "2.2", optional = true } mio = { version = "0.7.6", optional = true } num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.11.0", optional = true } -slab = { version = "0.4.1", optional = true } + +# Currently unstable. The API exposed by these features may be broken at any time. +# Requires `--cfg tokio_unstable` to enable. +[target.'cfg(tokio_unstable)'.dependencies] tracing = { version = "0.1.21", default-features = false, features = ["std"], optional = true } # Not in full [target.'cfg(unix)'.dependencies] @@ -118,13 +117,15 @@ default-features = false optional = true [dev-dependencies] -tokio-test = { version = "0.3.0", path = "../tokio-test" } +tokio-test = { version = "0.4.0", path = "../tokio-test" } +tokio-stream = { version = "0.1", path = "../tokio-stream" } futures = { version = "0.3.0", features = ["async-await"] } proptest = "0.10.0" tempfile = "3.1.0" +async-stream = "0.3" [target.'cfg(loom)'.dev-dependencies] -loom = { version = "0.3.5", features = ["futures", "checkpoint"] } +loom = { version = "0.4", features = ["futures", "checkpoint"] } [build-dependencies] autocfg = "1" # Needed for conditionally enabling `track-caller` @@ -134,4 +135,4 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [package.metadata.playground] -features = ["full"] +features = ["full", "test-util"] @@ -1,4 +1,4 @@ -Copyright (c) 2019 Tokio Contributors +Copyright (c) 2020 Tokio Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated @@ -7,13 +7,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/tokio/tokio-0.3.4.crate" + value: "https://static.crates.io/crates/tokio/tokio-1.0.2.crate" } - version: "0.3.4" + version: "1.0.2" license_type: NOTICE last_upgrade_date { - year: 2020 - month: 11 - day: 18 + year: 2021 + month: 1 + day: 14 } } @@ -29,7 +29,6 @@ the Rust programming language. It is: [Website](https://tokio.rs) | [Guides](https://tokio.rs/tokio/tutorial) | [API Docs](https://docs.rs/tokio/latest/tokio) | -[Roadmap](https://github.com/tokio-rs/tokio/blob/master/ROADMAP.md) | [Chat](https://discord.gg/tokio) ## Overview @@ -55,7 +54,7 @@ A basic TCP echo server with Tokio: ```rust,no_run use tokio::net::TcpListener; -use tokio::prelude::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { @@ -133,8 +132,7 @@ several other libraries, including: * [`tower`]: A library of modular and reusable components for building robust networking clients and servers. -* [`tracing`] (formerly `tokio-trace`): A framework for application-level - tracing and async-aware diagnostics. +* [`tracing`]: A framework for application-level tracing and async-aware diagnostics. * [`rdbc`]: A Rust database connectivity library for MySQL, Postgres and SQLite. diff --git a/src/coop.rs b/src/coop.rs index 980cdf8..05b2ae8 100644 --- a/src/coop.rs +++ b/src/coop.rs @@ -13,7 +13,7 @@ //! Consider a future like this one: //! //! ``` -//! # use tokio::stream::{Stream, StreamExt}; +//! # use tokio_stream::{Stream, StreamExt}; //! async fn drop_all<I: Stream + Unpin>(mut input: I) { //! while let Some(_) = input.next().await {} //! } @@ -25,7 +25,7 @@ //! opt-in yield points, this problem is alleviated: //! //! ```ignore -//! # use tokio::stream::{Stream, StreamExt}; +//! # use tokio_stream::{Stream, StreamExt}; //! async fn drop_all<I: Stream + Unpin>(mut input: I) { //! while let Some(_) = input.next().await { //! tokio::coop::proceed().await; diff --git a/src/fs/copy.rs b/src/fs/copy.rs index 2d4556f..b47f287 100644 --- a/src/fs/copy.rs +++ b/src/fs/copy.rs @@ -20,7 +20,7 @@ use std::path::Path; /// # } /// ``` -pub async fn copy<P: AsRef<Path>, Q: AsRef<Path>>(from: P, to: Q) -> Result<u64, std::io::Error> { +pub async fn copy(from: impl AsRef<Path>, to: impl AsRef<Path>) -> Result<u64, std::io::Error> { let from = from.as_ref().to_owned(); let to = to.as_ref().to_owned(); asyncify(|| std::fs::copy(from, to)).await diff --git a/src/fs/dir_builder.rs b/src/fs/dir_builder.rs index 8752a37..b184934 100644 --- a/src/fs/dir_builder.rs +++ b/src/fs/dir_builder.rs @@ -5,14 +5,10 @@ use std::path::Path; /// A builder for creating directories in various manners. /// -/// Additional Unix-specific options are available via importing the -/// [`DirBuilderExt`] trait. -/// /// This is a specialized version of [`std::fs::DirBuilder`] for usage on /// the Tokio runtime. /// /// [std::fs::DirBuilder]: std::fs::DirBuilder -/// [`DirBuilderExt`]: crate::fs::os::unix::DirBuilderExt #[derive(Debug, Default)] pub struct DirBuilder { /// Indicates whether to create parent directories if they are missing. @@ -100,7 +96,7 @@ impl DirBuilder { /// Ok(()) /// } /// ``` - pub async fn create<P: AsRef<Path>>(&self, path: P) -> io::Result<()> { + pub async fn create(&self, path: impl AsRef<Path>) -> io::Result<()> { let path = path.as_ref().to_owned(); let mut builder = std::fs::DirBuilder::new(); builder.recursive(self.recursive); @@ -115,3 +111,27 @@ impl DirBuilder { asyncify(move || builder.create(path)).await } } + +feature! { + #![unix] + + impl DirBuilder { + /// Sets the mode to create new directories with. + /// + /// This option defaults to 0o777. + /// + /// # Examples + /// + /// + /// ```no_run + /// use tokio::fs::DirBuilder; + /// + /// let mut builder = DirBuilder::new(); + /// builder.mode(0o775); + /// ``` + pub fn mode(&mut self, mode: u32) -> &mut Self { + self.mode = Some(mode); + self + } + } +} diff --git a/src/fs/file.rs b/src/fs/file.rs index 7c71f48..5c06e73 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -37,8 +37,7 @@ use std::task::Poll::*; /// the data to disk. /// /// Reading and writing to a `File` is usually done using the convenience -/// methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] traits. Examples -/// import these traits through [the prelude]. +/// methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] traits. /// /// [std]: struct@std::fs::File /// [`AsyncSeek`]: trait@crate::io::AsyncSeek @@ -46,7 +45,6 @@ use std::task::Poll::*; /// [`sync_all`]: fn@crate::fs::File::sync_all /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt -/// [the prelude]: crate::prelude /// /// # Examples /// @@ -54,7 +52,7 @@ use std::task::Poll::*; /// /// ```no_run /// use tokio::fs::File; -/// use tokio::prelude::*; // for write_all() +/// use tokio::io::AsyncWriteExt; // for write_all() /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::create("foo.txt").await?; @@ -67,7 +65,7 @@ use std::task::Poll::*; /// /// ```no_run /// use tokio::fs::File; -/// use tokio::prelude::*; // for read_to_end() +/// use tokio::io::AsyncReadExt; // for read_to_end() /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::open("foo.txt").await?; @@ -125,7 +123,7 @@ impl File { /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::AsyncReadExt; /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::open("foo.txt").await?; @@ -169,7 +167,7 @@ impl File { /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::AsyncWriteExt; /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::create("foo.txt").await?; @@ -221,7 +219,7 @@ impl File { /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::AsyncWriteExt; /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::create("foo.txt").await?; @@ -256,7 +254,7 @@ impl File { /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::AsyncWriteExt; /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::create("foo.txt").await?; @@ -294,7 +292,7 @@ impl File { /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::AsyncWriteExt; /// /// # async fn dox() -> std::io::Result<()> { /// let mut file = File::create("foo.txt").await?; diff --git a/src/fs/mod.rs b/src/fs/mod.rs index b9b0cd7..d4f0074 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -48,8 +48,6 @@ pub use self::metadata::metadata; mod open_options; pub use self::open_options::OpenOptions; -pub mod os; - mod read; pub use self::read::read; @@ -86,6 +84,23 @@ pub use self::write::write; mod copy; pub use self::copy::copy; +feature! { + #![unix] + + mod symlink; + pub use self::symlink::symlink; +} + +feature! { + #![windows] + + mod symlink_dir; + pub use self::symlink_dir::symlink_dir; + + mod symlink_file; + pub use self::symlink_file::symlink_file; +} + use std::io; pub(crate) async fn asyncify<F, T>(f: F) -> io::Result<T> diff --git a/src/fs/open_options.rs b/src/fs/open_options.rs index acd99a1..fa37a60 100644 --- a/src/fs/open_options.rs +++ b/src/fs/open_options.rs @@ -389,6 +389,262 @@ impl OpenOptions { } } +feature! { + #![unix] + + use std::os::unix::fs::OpenOptionsExt; + + impl OpenOptions { + /// Sets the mode bits that a new file will be created with. + /// + /// If a new file is created as part of an `OpenOptions::open` call then this + /// specified `mode` will be used as the permission bits for the new file. + /// If no `mode` is set, the default of `0o666` will be used. + /// The operating system masks out bits with the system's `umask`, to produce + /// the final permissions. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.mode(0o644); // Give read/write for owner and read for others. + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn mode(&mut self, mode: u32) -> &mut OpenOptions { + self.as_inner_mut().mode(mode); + self + } + + /// Pass custom flags to the `flags` argument of `open`. + /// + /// The bits that define the access mode are masked out with `O_ACCMODE`, to + /// ensure they do not interfere with the access mode set by Rusts options. + /// + /// Custom flags can only set flags, not remove flags set by Rusts options. + /// This options overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use libc; + /// use tokio::fs::OpenOptions; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut options = OpenOptions::new(); + /// options.write(true); + /// if cfg!(unix) { + /// options.custom_flags(libc::O_NOFOLLOW); + /// } + /// let file = options.open("foo.txt").await?; + /// + /// Ok(()) + /// } + /// ``` + pub fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } + } +} + +feature! { + #![windows] + + use std::os::windows::fs::OpenOptionsExt; + + impl OpenOptions { + /// Overrides the `dwDesiredAccess` argument to the call to [`CreateFile`] + /// with the specified value. + /// + /// This will override the `read`, `write`, and `append` flags on the + /// `OpenOptions` structure. This method provides fine-grained control over + /// the permissions to read, write and append data, attributes (like hidden + /// and system), and extended attributes. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// // Open without read and write permission, for example if you only need + /// // to call `stat` on the file + /// let file = OpenOptions::new().access_mode(0).open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + pub fn access_mode(&mut self, access: u32) -> &mut OpenOptions { + self.as_inner_mut().access_mode(access); + self + } + + /// Overrides the `dwShareMode` argument to the call to [`CreateFile`] with + /// the specified value. + /// + /// By default `share_mode` is set to + /// `FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE`. This allows + /// other processes to read, write, and delete/rename the same file + /// while it is open. Removing any of the flags will prevent other + /// processes from performing the corresponding operation until the file + /// handle is closed. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// // Do not allow others to read or modify this file while we have it open + /// // for writing. + /// let file = OpenOptions::new() + /// .write(true) + /// .share_mode(0) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + pub fn share_mode(&mut self, share: u32) -> &mut OpenOptions { + self.as_inner_mut().share_mode(share); + self + } + + /// Sets extra flags for the `dwFileFlags` argument to the call to + /// [`CreateFile2`] to the specified value (or combines it with + /// `attributes` and `security_qos_flags` to set the `dwFlagsAndAttributes` + /// for [`CreateFile`]). + /// + /// Custom flags can only set flags, not remove flags set by Rust's options. + /// This option overwrites any previously set custom flags. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winbase::FILE_FLAG_DELETE_ON_CLOSE; + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .create(true) + /// .write(true) + /// .custom_flags(FILE_FLAG_DELETE_ON_CLOSE) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + pub fn custom_flags(&mut self, flags: u32) -> &mut OpenOptions { + self.as_inner_mut().custom_flags(flags); + self + } + + /// Sets the `dwFileAttributes` argument to the call to [`CreateFile2`] to + /// the specified value (or combines it with `custom_flags` and + /// `security_qos_flags` to set the `dwFlagsAndAttributes` for + /// [`CreateFile`]). + /// + /// If a _new_ file is created because it does not yet exist and + /// `.create(true)` or `.create_new(true)` are specified, the new file is + /// given the attributes declared with `.attributes()`. + /// + /// If an _existing_ file is opened with `.create(true).truncate(true)`, its + /// existing attributes are preserved and combined with the ones declared + /// with `.attributes()`. + /// + /// In all other cases the attributes get ignored. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winnt::FILE_ATTRIBUTE_HIDDEN; + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// .attributes(FILE_ATTRIBUTE_HIDDEN) + /// .open("foo.txt").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + pub fn attributes(&mut self, attributes: u32) -> &mut OpenOptions { + self.as_inner_mut().attributes(attributes); + self + } + + /// Sets the `dwSecurityQosFlags` argument to the call to [`CreateFile2`] to + /// the specified value (or combines it with `custom_flags` and `attributes` + /// to set the `dwFlagsAndAttributes` for [`CreateFile`]). + /// + /// By default `security_qos_flags` is not set. It should be specified when + /// opening a named pipe, to control to which degree a server process can + /// act on behalf of a client process (security impersonation level). + /// + /// When `security_qos_flags` is not set, a malicious program can gain the + /// elevated privileges of a privileged Rust process when it allows opening + /// user-specified paths, by tricking it into opening a named pipe. So + /// arguably `security_qos_flags` should also be set when opening arbitrary + /// paths. However the bits can then conflict with other flags, specifically + /// `FILE_FLAG_OPEN_NO_RECALL`. + /// + /// For information about possible values, see [Impersonation Levels] on the + /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set + /// automatically when using this method. + /// + /// # Examples + /// + /// ```no_run + /// use winapi::um::winbase::SECURITY_IDENTIFICATION; + /// use tokio::fs::OpenOptions; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// + /// // Sets the flag value to `SecurityIdentification`. + /// .security_qos_flags(SECURITY_IDENTIFICATION) + /// + /// .open(r"\\.\pipe\MyPipe").await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 + /// [Impersonation Levels]: + /// https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level + pub fn security_qos_flags(&mut self, flags: u32) -> &mut OpenOptions { + self.as_inner_mut().security_qos_flags(flags); + self + } + } +} + impl From<std::fs::OpenOptions> for OpenOptions { fn from(options: std::fs::OpenOptions) -> OpenOptions { OpenOptions(options) diff --git a/src/fs/os/mod.rs b/src/fs/os/mod.rs deleted file mode 100644 index f4b8bfb..0000000 --- a/src/fs/os/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -//! OS-specific functionality. - -#[cfg(unix)] -pub mod unix; - -#[cfg(windows)] -pub mod windows; diff --git a/src/fs/os/unix/dir_builder_ext.rs b/src/fs/os/unix/dir_builder_ext.rs deleted file mode 100644 index ccdc552..0000000 --- a/src/fs/os/unix/dir_builder_ext.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::fs::dir_builder::DirBuilder; - -/// Unix-specific extensions to [`DirBuilder`]. -/// -/// [`DirBuilder`]: crate::fs::DirBuilder -pub trait DirBuilderExt: sealed::Sealed { - /// Sets the mode to create new directories with. - /// - /// This option defaults to 0o777. - /// - /// # Examples - /// - /// - /// ```no_run - /// use tokio::fs::DirBuilder; - /// use tokio::fs::os::unix::DirBuilderExt; - /// - /// let mut builder = DirBuilder::new(); - /// builder.mode(0o775); - /// ``` - fn mode(&mut self, mode: u32) -> &mut Self; -} - -impl DirBuilderExt for DirBuilder { - fn mode(&mut self, mode: u32) -> &mut Self { - self.mode = Some(mode); - self - } -} - -impl sealed::Sealed for DirBuilder {} - -pub(crate) mod sealed { - #[doc(hidden)] - pub trait Sealed {} -} diff --git a/src/fs/os/unix/dir_entry_ext.rs b/src/fs/os/unix/dir_entry_ext.rs deleted file mode 100644 index 2ac56da..0000000 --- a/src/fs/os/unix/dir_entry_ext.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::fs::DirEntry; -use std::os::unix::fs::DirEntryExt as _; - -/// Unix-specific extension methods for [`fs::DirEntry`]. -/// -/// This mirrors the definition of [`std::os::unix::fs::DirEntryExt`]. -/// -/// [`fs::DirEntry`]: crate::fs::DirEntry -/// [`std::os::unix::fs::DirEntryExt`]: std::os::unix::fs::DirEntryExt -pub trait DirEntryExt: sealed::Sealed { - /// Returns the underlying `d_ino` field in the contained `dirent` - /// structure. - /// - /// # Examples - /// - /// ``` - /// use tokio::fs; - /// use tokio::fs::os::unix::DirEntryExt; - /// - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// let mut entries = fs::read_dir(".").await?; - /// while let Some(entry) = entries.next_entry().await? { - /// // Here, `entry` is a `DirEntry`. - /// println!("{:?}: {}", entry.file_name(), entry.ino()); - /// } - /// # Ok(()) - /// # } - /// ``` - fn ino(&self) -> u64; -} - -impl DirEntryExt for DirEntry { - fn ino(&self) -> u64 { - self.as_inner().ino() - } -} - -impl sealed::Sealed for DirEntry {} - -pub(crate) mod sealed { - #[doc(hidden)] - pub trait Sealed {} -} diff --git a/src/fs/os/unix/mod.rs b/src/fs/os/unix/mod.rs deleted file mode 100644 index a0ae751..0000000 --- a/src/fs/os/unix/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! Unix-specific extensions to primitives in the `tokio_fs` module. - -mod symlink; -pub use self::symlink::symlink; - -mod open_options_ext; -pub use self::open_options_ext::OpenOptionsExt; - -mod dir_builder_ext; -pub use self::dir_builder_ext::DirBuilderExt; - -mod dir_entry_ext; -pub use self::dir_entry_ext::DirEntryExt; diff --git a/src/fs/os/unix/open_options_ext.rs b/src/fs/os/unix/open_options_ext.rs deleted file mode 100644 index 6e0fd2b..0000000 --- a/src/fs/os/unix/open_options_ext.rs +++ /dev/null @@ -1,85 +0,0 @@ -use crate::fs::open_options::OpenOptions; -use std::os::unix::fs::OpenOptionsExt as _; - -/// Unix-specific extensions to [`fs::OpenOptions`]. -/// -/// This mirrors the definition of [`std::os::unix::fs::OpenOptionsExt`]. -/// -/// [`fs::OpenOptions`]: crate::fs::OpenOptions -/// [`std::os::unix::fs::OpenOptionsExt`]: std::os::unix::fs::OpenOptionsExt -pub trait OpenOptionsExt: sealed::Sealed { - /// Sets the mode bits that a new file will be created with. - /// - /// If a new file is created as part of an `OpenOptions::open` call then this - /// specified `mode` will be used as the permission bits for the new file. - /// If no `mode` is set, the default of `0o666` will be used. - /// The operating system masks out bits with the system's `umask`, to produce - /// the final permissions. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::unix::OpenOptionsExt; - /// use std::io; - /// - /// #[tokio::main] - /// async fn main() -> io::Result<()> { - /// let mut options = OpenOptions::new(); - /// options.mode(0o644); // Give read/write for owner and read for others. - /// let file = options.open("foo.txt").await?; - /// - /// Ok(()) - /// } - /// ``` - fn mode(&mut self, mode: u32) -> &mut Self; - - /// Pass custom flags to the `flags` argument of `open`. - /// - /// The bits that define the access mode are masked out with `O_ACCMODE`, to - /// ensure they do not interfere with the access mode set by Rusts options. - /// - /// Custom flags can only set flags, not remove flags set by Rusts options. - /// This options overwrites any previously set custom flags. - /// - /// # Examples - /// - /// ```no_run - /// use libc; - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::unix::OpenOptionsExt; - /// use std::io; - /// - /// #[tokio::main] - /// async fn main() -> io::Result<()> { - /// let mut options = OpenOptions::new(); - /// options.write(true); - /// if cfg!(unix) { - /// options.custom_flags(libc::O_NOFOLLOW); - /// } - /// let file = options.open("foo.txt").await?; - /// - /// Ok(()) - /// } - /// ``` - fn custom_flags(&mut self, flags: i32) -> &mut Self; -} - -impl OpenOptionsExt for OpenOptions { - fn mode(&mut self, mode: u32) -> &mut OpenOptions { - self.as_inner_mut().mode(mode); - self - } - - fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions { - self.as_inner_mut().custom_flags(flags); - self - } -} - -impl sealed::Sealed for OpenOptions {} - -pub(crate) mod sealed { - #[doc(hidden)] - pub trait Sealed {} -} diff --git a/src/fs/os/windows/mod.rs b/src/fs/os/windows/mod.rs deleted file mode 100644 index ab98c13..0000000 --- a/src/fs/os/windows/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! Windows-specific extensions for the primitives in the `tokio_fs` module. - -mod symlink_dir; -pub use self::symlink_dir::symlink_dir; - -mod symlink_file; -pub use self::symlink_file::symlink_file; - -mod open_options_ext; -pub use self::open_options_ext::OpenOptionsExt; diff --git a/src/fs/os/windows/open_options_ext.rs b/src/fs/os/windows/open_options_ext.rs deleted file mode 100644 index ce86fba..0000000 --- a/src/fs/os/windows/open_options_ext.rs +++ /dev/null @@ -1,214 +0,0 @@ -use crate::fs::open_options::OpenOptions; -use std::os::windows::fs::OpenOptionsExt as _; - -/// Unix-specific extensions to [`fs::OpenOptions`]. -/// -/// This mirrors the definition of [`std::os::windows::fs::OpenOptionsExt`]. -/// -/// [`fs::OpenOptions`]: crate::fs::OpenOptions -/// [`std::os::windows::fs::OpenOptionsExt`]: std::os::windows::fs::OpenOptionsExt -pub trait OpenOptionsExt: sealed::Sealed { - /// Overrides the `dwDesiredAccess` argument to the call to [`CreateFile`] - /// with the specified value. - /// - /// This will override the `read`, `write`, and `append` flags on the - /// `OpenOptions` structure. This method provides fine-grained control over - /// the permissions to read, write and append data, attributes (like hidden - /// and system), and extended attributes. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::windows::OpenOptionsExt; - /// - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// // Open without read and write permission, for example if you only need - /// // to call `stat` on the file - /// let file = OpenOptions::new().access_mode(0).open("foo.txt").await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea - fn access_mode(&mut self, access: u32) -> &mut Self; - - /// Overrides the `dwShareMode` argument to the call to [`CreateFile`] with - /// the specified value. - /// - /// By default `share_mode` is set to - /// `FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE`. This allows - /// other processes to read, write, and delete/rename the same file - /// while it is open. Removing any of the flags will prevent other - /// processes from performing the corresponding operation until the file - /// handle is closed. - /// - /// # Examples - /// - /// ```no_run - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::windows::OpenOptionsExt; - /// - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// // Do not allow others to read or modify this file while we have it open - /// // for writing. - /// let file = OpenOptions::new() - /// .write(true) - /// .share_mode(0) - /// .open("foo.txt").await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea - fn share_mode(&mut self, val: u32) -> &mut Self; - - /// Sets extra flags for the `dwFileFlags` argument to the call to - /// [`CreateFile2`] to the specified value (or combines it with - /// `attributes` and `security_qos_flags` to set the `dwFlagsAndAttributes` - /// for [`CreateFile`]). - /// - /// Custom flags can only set flags, not remove flags set by Rust's options. - /// This option overwrites any previously set custom flags. - /// - /// # Examples - /// - /// ```no_run - /// use winapi::um::winbase::FILE_FLAG_DELETE_ON_CLOSE; - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::windows::OpenOptionsExt; - /// - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// let file = OpenOptions::new() - /// .create(true) - /// .write(true) - /// .custom_flags(FILE_FLAG_DELETE_ON_CLOSE) - /// .open("foo.txt").await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea - /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 - fn custom_flags(&mut self, flags: u32) -> &mut Self; - - /// Sets the `dwFileAttributes` argument to the call to [`CreateFile2`] to - /// the specified value (or combines it with `custom_flags` and - /// `security_qos_flags` to set the `dwFlagsAndAttributes` for - /// [`CreateFile`]). - /// - /// If a _new_ file is created because it does not yet exist and - /// `.create(true)` or `.create_new(true)` are specified, the new file is - /// given the attributes declared with `.attributes()`. - /// - /// If an _existing_ file is opened with `.create(true).truncate(true)`, its - /// existing attributes are preserved and combined with the ones declared - /// with `.attributes()`. - /// - /// In all other cases the attributes get ignored. - /// - /// # Examples - /// - /// ```no_run - /// use winapi::um::winnt::FILE_ATTRIBUTE_HIDDEN; - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::windows::OpenOptionsExt; - /// - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// let file = OpenOptions::new() - /// .write(true) - /// .create(true) - /// .attributes(FILE_ATTRIBUTE_HIDDEN) - /// .open("foo.txt").await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea - /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 - fn attributes(&mut self, val: u32) -> &mut Self; - - /// Sets the `dwSecurityQosFlags` argument to the call to [`CreateFile2`] to - /// the specified value (or combines it with `custom_flags` and `attributes` - /// to set the `dwFlagsAndAttributes` for [`CreateFile`]). - /// - /// By default `security_qos_flags` is not set. It should be specified when - /// opening a named pipe, to control to which degree a server process can - /// act on behalf of a client process (security impersonation level). - /// - /// When `security_qos_flags` is not set, a malicious program can gain the - /// elevated privileges of a privileged Rust process when it allows opening - /// user-specified paths, by tricking it into opening a named pipe. So - /// arguably `security_qos_flags` should also be set when opening arbitrary - /// paths. However the bits can then conflict with other flags, specifically - /// `FILE_FLAG_OPEN_NO_RECALL`. - /// - /// For information about possible values, see [Impersonation Levels] on the - /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set - /// automatically when using this method. - /// - /// # Examples - /// - /// ```no_run - /// use winapi::um::winbase::SECURITY_IDENTIFICATION; - /// use tokio::fs::OpenOptions; - /// use tokio::fs::os::windows::OpenOptionsExt; - /// - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// let file = OpenOptions::new() - /// .write(true) - /// .create(true) - /// - /// // Sets the flag value to `SecurityIdentification`. - /// .security_qos_flags(SECURITY_IDENTIFICATION) - /// - /// .open(r"\\.\pipe\MyPipe").await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea - /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2 - /// [Impersonation Levels]: - /// https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level - fn security_qos_flags(&mut self, flags: u32) -> &mut Self; -} - -impl OpenOptionsExt for OpenOptions { - fn access_mode(&mut self, access: u32) -> &mut OpenOptions { - self.as_inner_mut().access_mode(access); - self - } - - fn share_mode(&mut self, share: u32) -> &mut OpenOptions { - self.as_inner_mut().share_mode(share); - self - } - - fn custom_flags(&mut self, flags: u32) -> &mut OpenOptions { - self.as_inner_mut().custom_flags(flags); - self - } - - fn attributes(&mut self, attributes: u32) -> &mut OpenOptions { - self.as_inner_mut().attributes(attributes); - self - } - - fn security_qos_flags(&mut self, flags: u32) -> &mut OpenOptions { - self.as_inner_mut().security_qos_flags(flags); - self - } -} - -impl sealed::Sealed for OpenOptions {} - -pub(crate) mod sealed { - #[doc(hidden)] - pub trait Sealed {} -} diff --git a/src/fs/read_dir.rs b/src/fs/read_dir.rs index 8ca583b..7b21c9c 100644 --- a/src/fs/read_dir.rs +++ b/src/fs/read_dir.rs @@ -29,12 +29,11 @@ pub async fn read_dir(path: impl AsRef<Path>) -> io::Result<ReadDir> { /// /// # Errors /// -/// This [`Stream`] will return an [`Err`] if there's some sort of intermittent +/// This stream will return an [`Err`] if there's some sort of intermittent /// IO error during iteration. /// /// [`read_dir`]: read_dir /// [`DirEntry`]: DirEntry -/// [`Stream`]: crate::stream::Stream /// [`Err`]: std::result::Result::Err #[derive(Debug)] #[must_use = "streams do nothing unless polled"] @@ -53,7 +52,25 @@ impl ReadDir { poll_fn(|cx| self.poll_next_entry(cx)).await } - fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> { + /// Polls for the next directory entry in the stream. + /// + /// This method returns: + /// + /// * `Poll::Pending` if the next directory entry is not yet available. + /// * `Poll::Ready(Ok(Some(entry)))` if the next directory entry is available. + /// * `Poll::Ready(Ok(None))` if there are no more directory entries in this + /// stream. + /// * `Poll::Ready(Err(err))` if an IO error occurred while reading the next + /// directory entry. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when the next directory entry + /// becomes available on the underlying IO resource. + /// + /// Note that on multiple calls to `poll_next_entry`, only the `Waker` from + /// the `Context` passed to the most recent call is scheduled to receive a + /// wakeup. + pub fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> { loop { match self.0 { State::Idle(ref mut std) => { @@ -81,16 +98,33 @@ impl ReadDir { } } -#[cfg(feature = "stream")] -impl crate::stream::Stream for ReadDir { - type Item = io::Result<DirEntry>; +feature! { + #![unix] + + use std::os::unix::fs::DirEntryExt; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - Poll::Ready(match ready!(self.poll_next_entry(cx)) { - Ok(Some(entry)) => Some(Ok(entry)), - Ok(None) => None, - Err(err) => Some(Err(err)), - }) + impl DirEntry { + /// Returns the underlying `d_ino` field in the contained `dirent` + /// structure. + /// + /// # Examples + /// + /// ``` + /// use tokio::fs; + /// + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// let mut entries = fs::read_dir(".").await?; + /// while let Some(entry) = entries.next_entry().await? { + /// // Here, `entry` is a `DirEntry`. + /// println!("{:?}: {}", entry.file_name(), entry.ino()); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn ino(&self) -> u64 { + self.as_inner().ino() + } } } diff --git a/src/fs/os/unix/symlink.rs b/src/fs/symlink.rs index 22ece72..22ece72 100644 --- a/src/fs/os/unix/symlink.rs +++ b/src/fs/symlink.rs diff --git a/src/fs/os/windows/symlink_dir.rs b/src/fs/symlink_dir.rs index 736e762..736e762 100644 --- a/src/fs/os/windows/symlink_dir.rs +++ b/src/fs/symlink_dir.rs diff --git a/src/fs/os/windows/symlink_file.rs b/src/fs/symlink_file.rs index 07d8e60..07d8e60 100644 --- a/src/fs/os/windows/symlink_file.rs +++ b/src/fs/symlink_file.rs diff --git a/src/fs/write.rs b/src/fs/write.rs index 9b6cf53..0ed9082 100644 --- a/src/fs/write.rs +++ b/src/fs/write.rs @@ -19,7 +19,7 @@ use std::{io, path::Path}; /// # Ok(()) /// # } /// ``` -pub async fn write<C: AsRef<[u8]> + Unpin>(path: impl AsRef<Path>, contents: C) -> io::Result<()> { +pub async fn write(path: impl AsRef<Path>, contents: impl AsRef<[u8]>) -> io::Result<()> { let path = path.as_ref().to_owned(); let contents = contents.as_ref().to_owned(); diff --git a/src/io/async_fd.rs b/src/io/async_fd.rs index 99f23fd..08d7b91 100644 --- a/src/io/async_fd.rs +++ b/src/io/async_fd.rs @@ -56,6 +56,78 @@ use std::{task::Context, task::Poll}; /// the limitation that only one task can wait on each direction (read or write) /// at a time. /// +/// # Examples +/// +/// This example shows how to turn [`std::net::TcpStream`] asynchronous using +/// `AsyncFd`. It implements `read` as an async fn, and `AsyncWrite` as a trait +/// to show how to implement both approaches. +/// +/// ```no_run +/// use futures::ready; +/// use std::io::{self, Read, Write}; +/// use std::net::TcpStream; +/// use std::pin::Pin; +/// use std::task::{Context, Poll}; +/// use tokio::io::AsyncWrite; +/// use tokio::io::unix::AsyncFd; +/// +/// pub struct AsyncTcpStream { +/// inner: AsyncFd<TcpStream>, +/// } +/// +/// impl AsyncTcpStream { +/// pub fn new(tcp: TcpStream) -> io::Result<Self> { +/// Ok(Self { +/// inner: AsyncFd::new(tcp)?, +/// }) +/// } +/// +/// pub async fn read(&self, out: &mut [u8]) -> io::Result<usize> { +/// loop { +/// let mut guard = self.inner.readable().await?; +/// +/// match guard.try_io(|inner| inner.get_ref().read(out)) { +/// Ok(result) => return result, +/// Err(_would_block) => continue, +/// } +/// } +/// } +/// } +/// +/// impl AsyncWrite for AsyncTcpStream { +/// fn poll_write( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_>, +/// buf: &[u8] +/// ) -> Poll<io::Result<usize>> { +/// loop { +/// let mut guard = ready!(self.inner.poll_write_ready(cx))?; +/// +/// match guard.try_io(|inner| inner.get_ref().write(buf)) { +/// Ok(result) => return Poll::Ready(result), +/// Err(_would_block) => continue, +/// } +/// } +/// } +/// +/// fn poll_flush( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_>, +/// ) -> Poll<io::Result<()>> { +/// // tcp flush is a no-op +/// Poll::Ready(Ok(())) +/// } +/// +/// fn poll_shutdown( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_>, +/// ) -> Poll<io::Result<()>> { +/// self.inner.get_ref().shutdown(std::net::Shutdown::Write)?; +/// Poll::Ready(Ok(())) +/// } +/// } +/// ``` +/// /// [`readable`]: method@Self::readable /// [`writable`]: method@Self::writable /// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard @@ -64,35 +136,64 @@ pub struct AsyncFd<T: AsRawFd> { registration: Registration, inner: Option<T>, } -/// Represents an IO-ready event detected on a particular file descriptor, which + +/// Represents an IO-ready event detected on a particular file descriptor that /// has not yet been acknowledged. This is a `must_use` structure to help ensure /// that you do not forget to explicitly clear (or not clear) the event. +/// +/// This type exposes an immutable reference to the underlying IO object. #[must_use = "You must explicitly choose whether to clear the readiness state by calling a method on ReadyGuard"] pub struct AsyncFdReadyGuard<'a, T: AsRawFd> { async_fd: &'a AsyncFd<T>, event: Option<ReadyEvent>, } +/// Represents an IO-ready event detected on a particular file descriptor that +/// has not yet been acknowledged. This is a `must_use` structure to help ensure +/// that you do not forget to explicitly clear (or not clear) the event. +/// +/// This type exposes a mutable reference to the underlying IO object. +#[must_use = "You must explicitly choose whether to clear the readiness state by calling a method on ReadyGuard"] +pub struct AsyncFdReadyMutGuard<'a, T: AsRawFd> { + async_fd: &'a mut AsyncFd<T>, + event: Option<ReadyEvent>, +} + const ALL_INTEREST: Interest = Interest::READABLE.add(Interest::WRITABLE); impl<T: AsRawFd> AsyncFd<T> { + #[inline] /// Creates an AsyncFd backed by (and taking ownership of) an object /// implementing [`AsRawFd`]. The backing file descriptor is cached at the /// time of creation. /// - /// This function must be called in the context of a tokio runtime. + /// This method must be called in the context of a tokio runtime. pub fn new(inner: T) -> io::Result<Self> where T: AsRawFd, { - Self::new_with_handle(inner, Handle::current()) + Self::with_interest(inner, ALL_INTEREST) + } + + #[inline] + /// Creates new instance as `new` with additional ability to customize interest, + /// allowing to specify whether file descriptor will be polled for read, write or both. + pub fn with_interest(inner: T, interest: Interest) -> io::Result<Self> + where + T: AsRawFd, + { + Self::new_with_handle_and_interest(inner, Handle::current(), interest) } - pub(crate) fn new_with_handle(inner: T, handle: Handle) -> io::Result<Self> { + pub(crate) fn new_with_handle_and_interest( + inner: T, + handle: Handle, + interest: Interest, + ) -> io::Result<Self> { let fd = inner.as_raw_fd(); let registration = - Registration::new_with_interest_and_handle(&mut SourceFd(&fd), ALL_INTEREST, handle)?; + Registration::new_with_interest_and_handle(&mut SourceFd(&fd), interest, handle)?; Ok(AsyncFd { registration, @@ -122,24 +223,39 @@ impl<T: AsRawFd> AsyncFd<T> { self.inner.take() } - /// Deregisters this file descriptor, and returns ownership of the backing + /// Deregisters this file descriptor and returns ownership of the backing /// object. pub fn into_inner(mut self) -> T { self.take_inner().unwrap() } - /// Polls for read readiness. This function retains the waker for the last - /// context that called [`poll_read_ready`]; it therefore can only be used - /// by a single task at a time (however, [`poll_write_ready`] retains a - /// second, independent waker). + /// Polls for read readiness. + /// + /// If the file descriptor is not currently ready for reading, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for reading, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_read_ready`] or + /// [`poll_read_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_write_ready`] retains a second, independent waker). /// - /// This function is intended for cases where creating and pinning a future + /// This method is intended for cases where creating and pinning a future /// via [`readable`] is not feasible. Where possible, using [`readable`] is /// preferred, as this supports polling from multiple tasks at once. /// + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + /// /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_read_ready_mut`]: method@Self::poll_read_ready_mut /// [`poll_write_ready`]: method@Self::poll_write_ready /// [`readable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake pub fn poll_read_ready<'a>( &'a self, cx: &mut Context<'_>, @@ -153,18 +269,71 @@ impl<T: AsRawFd> AsyncFd<T> { .into() } - /// Polls for write readiness. This function retains the waker for the last - /// context that called [`poll_write_ready`]; it therefore can only be used - /// by a single task at a time (however, [`poll_read_ready`] retains a - /// second, independent waker). + /// Polls for read readiness. + /// + /// If the file descriptor is not currently ready for reading, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for reading, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_read_ready`] or + /// [`poll_read_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_write_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + /// + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_read_ready_mut`]: method@Self::poll_read_ready_mut + /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`readable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake + pub fn poll_read_ready_mut<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll<io::Result<AsyncFdReadyMutGuard<'a, T>>> { + let event = ready!(self.registration.poll_read_ready(cx))?; + + Ok(AsyncFdReadyMutGuard { + async_fd: self, + event: Some(event), + }) + .into() + } + + /// Polls for write readiness. + /// + /// If the file descriptor is not currently ready for writing, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for writing, [`Waker::wake`] will be called. /// - /// This function is intended for cases where creating and pinning a future + /// Note that on multiple calls to [`poll_write_ready`] or + /// [`poll_write_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_read_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future /// via [`writable`] is not feasible. Where possible, using [`writable`] is /// preferred, as this supports polling from multiple tasks at once. /// + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + /// /// [`poll_read_ready`]: method@Self::poll_read_ready /// [`poll_write_ready`]: method@Self::poll_write_ready - /// [`writable`]: method@Self::writable + /// [`poll_write_ready_mut`]: method@Self::poll_write_ready_mut + /// [`writable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake pub fn poll_write_ready<'a>( &'a self, cx: &mut Context<'_>, @@ -178,6 +347,44 @@ impl<T: AsRawFd> AsyncFd<T> { .into() } + /// Polls for write readiness. + /// + /// If the file descriptor is not currently ready for writing, this method + /// will store a clone of the [`Waker`] from the provided [`Context`]. When the + /// file descriptor becomes ready for writing, [`Waker::wake`] will be called. + /// + /// Note that on multiple calls to [`poll_write_ready`] or + /// [`poll_write_ready_mut`], only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// [`poll_read_ready`] retains a second, independent waker). + /// + /// This method is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + /// + /// [`poll_read_ready`]: method@Self::poll_read_ready + /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`poll_write_ready_mut`]: method@Self::poll_write_ready_mut + /// [`writable`]: method@Self::readable + /// [`Context`]: struct@std::task::Context + /// [`Waker`]: struct@std::task::Waker + /// [`Waker::wake`]: method@std::task::Waker::wake + pub fn poll_write_ready_mut<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll<io::Result<AsyncFdReadyMutGuard<'a, T>>> { + let event = ready!(self.registration.poll_write_ready(cx))?; + + Ok(AsyncFdReadyMutGuard { + async_fd: self, + event: Some(event), + }) + .into() + } + async fn readiness(&self, interest: Interest) -> io::Result<AsyncFdReadyGuard<'_, T>> { let event = self.registration.readiness(interest).await?; @@ -187,21 +394,65 @@ impl<T: AsRawFd> AsyncFd<T> { }) } + async fn readiness_mut( + &mut self, + interest: Interest, + ) -> io::Result<AsyncFdReadyMutGuard<'_, T>> { + let event = self.registration.readiness(interest).await?; + + Ok(AsyncFdReadyMutGuard { + async_fd: self, + event: Some(event), + }) + } + /// Waits for the file descriptor to become readable, returning a - /// [`AsyncFdReadyGuard`] that must be dropped to resume read-readiness polling. + /// [`AsyncFdReadyGuard`] that must be dropped to resume read-readiness + /// polling. /// - /// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard - pub async fn readable(&self) -> io::Result<AsyncFdReadyGuard<'_, T>> { + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn readable<'a>(&'a self) -> io::Result<AsyncFdReadyGuard<'a, T>> { self.readiness(Interest::READABLE).await } + /// Waits for the file descriptor to become readable, returning a + /// [`AsyncFdReadyMutGuard`] that must be dropped to resume read-readiness + /// polling. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn readable_mut<'a>(&'a mut self) -> io::Result<AsyncFdReadyMutGuard<'a, T>> { + self.readiness_mut(Interest::READABLE).await + } + /// Waits for the file descriptor to become writable, returning a - /// [`AsyncFdReadyGuard`] that must be dropped to resume write-readiness polling. + /// [`AsyncFdReadyGuard`] that must be dropped to resume write-readiness + /// polling. /// - /// [`AsyncFdReadyGuard`]: struct@self::AsyncFdReadyGuard - pub async fn writable(&self) -> io::Result<AsyncFdReadyGuard<'_, T>> { + /// This method takes `&self`, so it is possible to call this method + /// concurrently with other methods on this struct. This method only + /// provides shared access to the inner IO resource when handling the + /// [`AsyncFdReadyGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn writable<'a>(&'a self) -> io::Result<AsyncFdReadyGuard<'a, T>> { self.readiness(Interest::WRITABLE).await } + + /// Waits for the file descriptor to become writable, returning a + /// [`AsyncFdReadyMutGuard`] that must be dropped to resume write-readiness + /// polling. + /// + /// This method takes `&mut self`, so it is possible to access the inner IO + /// resource mutably when handling the [`AsyncFdReadyMutGuard`]. + #[allow(clippy::needless_lifetimes)] // The lifetime improves rustdoc rendering. + pub async fn writable_mut<'a>(&'a mut self) -> io::Result<AsyncFdReadyMutGuard<'a, T>> { + self.readiness_mut(Interest::WRITABLE).await + } } impl<T: AsRawFd> AsRawFd for AsyncFd<T> { @@ -241,7 +492,7 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { } } - /// This function should be invoked when you intentionally want to keep the + /// This method should be invoked when you intentionally want to keep the /// ready flag asserted. /// /// While this function is itself a no-op, it satisfies the `#[must_use]` @@ -250,8 +501,12 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { // no-op } - /// Performs the IO operation `f`; if `f` returns a [`WouldBlock`] error, - /// the readiness state associated with this file descriptor is cleared. + /// Performs the provided IO operation. + /// + /// If `f` returns a [`WouldBlock`] error, the readiness state associated + /// with this file descriptor is cleared, and the method returns + /// `Err(TryIoError::WouldBlock)`. You will typically need to poll the + /// `AsyncFd` again when this happens. /// /// This method helps ensure that the readiness state of the underlying file /// descriptor remains in sync with the tokio-side readiness state, by @@ -262,8 +517,11 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { /// create this `AsyncFdReadyGuard`. /// /// [`WouldBlock`]: std::io::ErrorKind::WouldBlock - pub fn with_io<R>(&mut self, f: impl FnOnce() -> io::Result<R>) -> io::Result<R> { - let result = f(); + pub fn try_io<R>( + &mut self, + f: impl FnOnce(&AsyncFd<Inner>) -> io::Result<R>, + ) -> Result<io::Result<R>, TryIoError> { + let result = f(self.async_fd); if let Err(e) = result.as_ref() { if e.kind() == io::ErrorKind::WouldBlock { @@ -271,29 +529,71 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { } } - result + match result { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Err(TryIoError(())), + result => Ok(result), + } } +} - /// Performs the IO operation `f`; if `f` returns [`Pending`], the readiness - /// state associated with this file descriptor is cleared. +impl<'a, Inner: AsRawFd> AsyncFdReadyMutGuard<'a, Inner> { + /// Indicates to tokio that the file descriptor is no longer ready. The + /// internal readiness flag will be cleared, and tokio will wait for the + /// next edge-triggered readiness notification from the OS. + /// + /// It is critical that this function not be called unless your code + /// _actually observes_ that the file descriptor is _not_ ready. Do not call + /// it simply because, for example, a read succeeded; it should be called + /// when a read is observed to block. + /// + /// [`drop`]: method@std::mem::drop + pub fn clear_ready(&mut self) { + if let Some(event) = self.event.take() { + self.async_fd.registration.clear_readiness(event); + } + } + + /// This method should be invoked when you intentionally want to keep the + /// ready flag asserted. + /// + /// While this function is itself a no-op, it satisfies the `#[must_use]` + /// constraint on the [`AsyncFdReadyGuard`] type. + pub fn retain_ready(&mut self) { + // no-op + } + + /// Performs the provided IO operation. + /// + /// If `f` returns a [`WouldBlock`] error, the readiness state associated + /// with this file descriptor is cleared, and the method returns + /// `Err(TryIoError::WouldBlock)`. You will typically need to poll the + /// `AsyncFd` again when this happens. /// /// This method helps ensure that the readiness state of the underlying file /// descriptor remains in sync with the tokio-side readiness state, by - /// clearing the tokio-side state only when a [`Pending`] condition occurs. - /// It is the responsibility of the caller to ensure that `f` returns - /// [`Pending`] only if the file descriptor that originated this + /// clearing the tokio-side state only when a [`WouldBlock`] condition + /// occurs. It is the responsibility of the caller to ensure that `f` + /// returns [`WouldBlock`] only if the file descriptor that originated this /// `AsyncFdReadyGuard` no longer expresses the readiness state that was queried to /// create this `AsyncFdReadyGuard`. /// - /// [`Pending`]: std::task::Poll::Pending - pub fn with_poll<R>(&mut self, f: impl FnOnce() -> std::task::Poll<R>) -> std::task::Poll<R> { - let result = f(); + /// [`WouldBlock`]: std::io::ErrorKind::WouldBlock + pub fn try_io<R>( + &mut self, + f: impl FnOnce(&mut AsyncFd<Inner>) -> io::Result<R>, + ) -> Result<io::Result<R>, TryIoError> { + let result = f(&mut self.async_fd); - if result.is_pending() { - self.clear_ready(); + if let Err(e) = result.as_ref() { + if e.kind() == io::ErrorKind::WouldBlock { + self.clear_ready(); + } } - result + match result { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Err(TryIoError(())), + result => Ok(result), + } } } @@ -304,3 +604,20 @@ impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyGuard<'a, .finish() } } + +impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyMutGuard<'a, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MutReadyGuard") + .field("async_fd", &self.async_fd) + .finish() + } +} + +/// The error type returned by [`try_io`]. +/// +/// This error indicates that the IO resource returned a [`WouldBlock`] error. +/// +/// [`WouldBlock`]: std::io::ErrorKind::WouldBlock +/// [`try_io`]: method@AsyncFdReadyGuard::try_io +#[derive(Debug)] +pub struct TryIoError(()); diff --git a/src/io/mod.rs b/src/io/mod.rs index 14be3e0..3e7c943 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -144,7 +144,7 @@ //! that implements [`AsyncRead`] and [`AsyncWrite`] into a `Sink`/`Stream` of //! your structured data. //! -//! [tokio-util]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html +//! [tokio-util]: https://docs.rs/tokio-util/0.6/tokio_util/codec/index.html //! //! # Standard input and output //! @@ -169,16 +169,16 @@ //! [`AsyncWrite`]: trait@AsyncWrite //! [`AsyncReadExt`]: trait@AsyncReadExt //! [`AsyncWriteExt`]: trait@AsyncWriteExt -//! ["codec"]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html -//! [`Encoder`]: https://docs.rs/tokio-util/0.3/tokio_util/codec/trait.Encoder.html -//! [`Decoder`]: https://docs.rs/tokio-util/0.3/tokio_util/codec/trait.Decoder.html +//! ["codec"]: https://docs.rs/tokio-util/0.6/tokio_util/codec/index.html +//! [`Encoder`]: https://docs.rs/tokio-util/0.6/tokio_util/codec/trait.Encoder.html +//! [`Decoder`]: https://docs.rs/tokio-util/0.6/tokio_util/codec/trait.Decoder.html //! [`Error`]: struct@Error //! [`ErrorKind`]: enum@ErrorKind //! [`Result`]: type@Result //! [`Read`]: std::io::Read //! [`SeekFrom`]: enum@SeekFrom //! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html -//! [`Stream`]: crate::stream::Stream +//! [`Stream`]: https://docs.rs/futures/0.3/futures/stream/trait.Stream.html //! [`Write`]: std::io::Write cfg_io_blocking! { pub(crate) mod blocking; @@ -222,7 +222,7 @@ cfg_net_unix! { pub mod unix { //! Asynchronous IO structures specific to Unix-like operating systems. - pub use super::async_fd::{AsyncFd, AsyncFdReadyGuard}; + pub use super::async_fd::{AsyncFd, AsyncFdReadyGuard, AsyncFdReadyMutGuard, TryIoError}; } } diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs index 3a65961..0ecdb18 100644 --- a/src/io/poll_evented.rs +++ b/src/io/poll_evented.rs @@ -124,6 +124,14 @@ impl<E: Source> PollEvented<E> { pub(crate) fn registration(&self) -> &Registration { &self.registration } + + /// Deregister the inner io from the registration and returns a Result containing the inner io + #[cfg(feature = "net")] + pub(crate) fn into_inner(mut self) -> io::Result<E> { + let mut inner = self.io.take().unwrap(); // As io shouldn't ever be None, just unwrap here. + self.registration.deregister(&mut inner)?; + Ok(inner) + } } feature! { diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs index 9e87f2f..7977a0e 100644 --- a/src/io/util/async_buf_read_ext.rs +++ b/src/io/util/async_buf_read_ext.rs @@ -228,7 +228,6 @@ cfg_io_util! { /// /// ``` /// use tokio::io::AsyncBufReadExt; - /// use tokio::stream::StreamExt; /// /// use std::io::Cursor; /// @@ -236,12 +235,12 @@ cfg_io_util! { /// async fn main() { /// let cursor = Cursor::new(b"lorem\nipsum\r\ndolor"); /// - /// let mut lines = cursor.lines().map(|res| res.unwrap()); + /// let mut lines = cursor.lines(); /// - /// assert_eq!(lines.next().await, Some(String::from("lorem"))); - /// assert_eq!(lines.next().await, Some(String::from("ipsum"))); - /// assert_eq!(lines.next().await, Some(String::from("dolor"))); - /// assert_eq!(lines.next().await, None); + /// assert_eq!(lines.next_line().await.unwrap(), Some(String::from("lorem"))); + /// assert_eq!(lines.next_line().await.unwrap(), Some(String::from("ipsum"))); + /// assert_eq!(lines.next_line().await.unwrap(), Some(String::from("dolor"))); + /// assert_eq!(lines.next_line().await.unwrap(), None); /// } /// ``` /// diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index 1f918f1..ebcbce6 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -39,11 +39,9 @@ cfg_io_util! { /// [`AsyncRead`] types. Callers will tend to import this trait instead of /// [`AsyncRead`]. /// - /// As a convenience, this trait may be imported using the [`prelude`]: - /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::{self, AsyncReadExt}; /// /// #[tokio::main] /// async fn main() -> io::Result<()> { @@ -60,7 +58,6 @@ cfg_io_util! { /// See [module][crate::io] documentation for more details. /// /// [`AsyncRead`]: AsyncRead - /// [`prelude`]: crate::prelude pub trait AsyncReadExt: AsyncRead { /// Creates a new `AsyncRead` instance that chains this stream with /// `next`. diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs index 351900b..813913f 100644 --- a/src/io/util/async_seek_ext.rs +++ b/src/io/util/async_seek_ext.rs @@ -3,15 +3,13 @@ use crate::io::AsyncSeek; use std::io::SeekFrom; cfg_io_util! { - /// An extension trait which adds utility methods to [`AsyncSeek`] types. - /// - /// As a convenience, this trait may be imported using the [`prelude`]: + /// An extension trait that adds utility methods to [`AsyncSeek`] types. /// /// # Examples /// /// ``` - /// use std::io::{Cursor, SeekFrom}; - /// use tokio::prelude::*; + /// use std::io::{self, Cursor, SeekFrom}; + /// use tokio::io::{AsyncSeekExt, AsyncReadExt}; /// /// #[tokio::main] /// async fn main() -> io::Result<()> { @@ -32,7 +30,6 @@ cfg_io_util! { /// See [module][crate::io] documentation for more details. /// /// [`AsyncSeek`]: AsyncSeek - /// [`prelude`]: crate::prelude pub trait AsyncSeekExt: AsyncSeek { /// Creates a future which will seek an IO object, and then yield the /// new position in the object and the object itself. @@ -50,7 +47,7 @@ cfg_io_util! { /// /// ```no_run /// use tokio::fs::File; - /// use tokio::prelude::*; + /// use tokio::io::{AsyncSeekExt, AsyncReadExt}; /// /// use std::io::SeekFrom; /// diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index e6ef5b2..dc500f2 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -39,10 +39,8 @@ cfg_io_util! { /// [`AsyncWrite`] types. Callers will tend to import this trait instead of /// [`AsyncWrite`]. /// - /// As a convenience, this trait may be imported using the [`prelude`]: - /// /// ```no_run - /// use tokio::prelude::*; + /// use tokio::io::{self, AsyncWriteExt}; /// use tokio::fs::File; /// /// #[tokio::main] @@ -64,7 +62,6 @@ cfg_io_util! { /// See [module][crate::io] documentation for more details. /// /// [`AsyncWrite`]: AsyncWrite - /// [`prelude`]: crate::prelude pub trait AsyncWriteExt: AsyncWrite { /// Writes a buffer into this writer, returning how many bytes were /// written. diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs index b41f04a..25df78e 100644 --- a/src/io/util/lines.rs +++ b/src/io/util/lines.rs @@ -83,7 +83,23 @@ impl<R> Lines<R> where R: AsyncBufRead, { - fn poll_next_line( + /// Polls for the next line in the stream. + /// + /// This method returns: + /// + /// * `Poll::Pending` if the next line is not yet available. + /// * `Poll::Ready(Ok(Some(line)))` if the next line is available. + /// * `Poll::Ready(Ok(None))` if there are no more lines in this stream. + /// * `Poll::Ready(Err(err))` if an IO error occurred while reading the next line. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when more bytes become + /// available on the underlying IO resource. + /// + /// Note that on multiple calls to `poll_next_line`, only the `Waker` from + /// the `Context` passed to the most recent call is scheduled to receive a + /// wakeup. + pub fn poll_next_line( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<String>>> { @@ -108,19 +124,6 @@ where } } -#[cfg(feature = "stream")] -impl<R: AsyncBufRead> crate::stream::Stream for Lines<R> { - type Item = io::Result<String>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - Poll::Ready(match ready!(self.poll_next_line(cx)) { - Ok(Some(line)) => Some(Ok(line)), - Ok(None) => None, - Err(err) => Some(Err(err)), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/io/util/read_buf.rs b/src/io/util/read_buf.rs index 696deef..8ec57c0 100644 --- a/src/io/util/read_buf.rs +++ b/src/io/util/read_buf.rs @@ -50,7 +50,7 @@ where } let n = { - let dst = me.buf.bytes_mut(); + let dst = me.buf.chunk_mut(); let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; let mut buf = ReadBuf::uninit(dst); let ptr = buf.filled().as_ptr(); diff --git a/src/io/util/read_to_end.rs b/src/io/util/read_to_end.rs index a974625..1aee681 100644 --- a/src/io/util/read_to_end.rs +++ b/src/io/util/read_to_end.rs @@ -72,14 +72,13 @@ fn poll_read_to_end<R: AsyncRead + ?Sized>( let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf)); + let ptr = unused_capacity.filled().as_ptr(); ready!(read.poll_read(cx, &mut unused_capacity))?; + assert_eq!(ptr, unused_capacity.filled().as_ptr()); let n = unused_capacity.filled().len(); let new_len = buf.len() + n; - // This should no longer even be possible in safe Rust. An implementor - // would need to have unsafely *replaced* the buffer inside `ReadBuf`, - // which... yolo? assert!(new_len <= buf.capacity()); unsafe { buf.set_len(new_len); @@ -98,7 +97,7 @@ fn reserve(buf: &mut Vec<u8>, bytes: usize) { /// Returns the unused capacity of the provided vector. fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] { - let uninit = bytes::BufMut::bytes_mut(buf); + let uninit = bytes::BufMut::chunk_mut(buf); unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) } } diff --git a/src/io/util/split.rs b/src/io/util/split.rs index 492e26a..eb82865 100644 --- a/src/io/util/split.rs +++ b/src/io/util/split.rs @@ -65,7 +65,24 @@ impl<R> Split<R> where R: AsyncBufRead, { - fn poll_next_segment( + /// Polls for the next segment in the stream. + /// + /// This method returns: + /// + /// * `Poll::Pending` if the next segment is not yet available. + /// * `Poll::Ready(Ok(Some(segment)))` if the next segment is available. + /// * `Poll::Ready(Ok(None))` if there are no more segments in this stream. + /// * `Poll::Ready(Err(err))` if an IO error occurred while reading the + /// next segment. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when more bytes become + /// available on the underlying IO resource. + /// + /// Note that on multiple calls to `poll_next_segment`, only the `Waker` + /// from the `Context` passed to the most recent call is scheduled to + /// receive a wakeup. + pub fn poll_next_segment( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<Vec<u8>>>> { @@ -89,19 +106,6 @@ where } } -#[cfg(feature = "stream")] -impl<R: AsyncBufRead> crate::stream::Stream for Split<R> { - type Item = io::Result<Vec<u8>>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - Poll::Ready(match ready!(self.poll_next_segment(cx)) { - Ok(Some(segment)) => Some(Ok(segment)), - Ok(None) => None, - Err(err) => Some(Err(err)), - }) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/io/util/write_buf.rs b/src/io/util/write_buf.rs index 1310e5c..82fd7a7 100644 --- a/src/io/util/write_buf.rs +++ b/src/io/util/write_buf.rs @@ -48,7 +48,7 @@ where return Poll::Ready(Ok(0)); } - let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?; + let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.chunk()))?; me.buf.advance(n); Poll::Ready(Ok(n)) } @@ -1,4 +1,4 @@ -#![doc(html_root_url = "https://docs.rs/tokio/0.3.4")] +#![doc(html_root_url = "https://docs.rs/tokio/1.0.2")] #![allow( clippy::cognitive_complexity, clippy::large_enum_variant, @@ -57,7 +57,7 @@ //! enabling the `full` feature flag: //! //! ```toml -//! tokio = { version = "0.3", features = ["full"] } +//! tokio = { version = "1", features = ["full"] } //! ``` //! //! ### Authoring applications @@ -72,7 +72,7 @@ //! This example shows the quickest way to get started with Tokio. //! //! ```toml -//! tokio = { version = "0.3", features = ["full"] } +//! tokio = { version = "1", features = ["full"] } //! ``` //! //! ### Authoring libraries @@ -88,7 +88,7 @@ //! needs to `tokio::spawn` and use a `TcpStream`. //! //! ```toml -//! tokio = { version = "0.3", features = ["rt", "net"] } +//! tokio = { version = "1", features = ["rt", "net"] } //! ``` //! //! ## Working With Tasks @@ -173,16 +173,18 @@ //! combat this, Tokio provides two kinds of threads: Core threads and blocking //! threads. The core threads are where all asynchronous code runs, and Tokio //! will by default spawn one for each CPU core. The blocking threads are -//! spawned on demand, and can be used to run blocking code that would otherwise -//! block other tasks from running. Since it is not possible for Tokio to swap -//! out blocking tasks, like it can do with asynchronous code, the upper limit -//! on the number of blocking threads is very large. These limits can be -//! configured on the [`Builder`]. +//! spawned on demand, can be used to run blocking code that would otherwise +//! block other tasks from running and are kept alive when not used for a certain +//! amount of time which can be configured with [`thread_keep_alive`]. +//! Since it is not possible for Tokio to swap out blocking tasks, like it +//! can do with asynchronous code, the upper limit on the number of blocking +//! threads is very large. These limits can be configured on the [`Builder`]. //! //! To spawn a blocking task, you should use the [`spawn_blocking`] function. //! //! [`Builder`]: crate::runtime::Builder //! [`spawn_blocking`]: crate::task::spawn_blocking() +//! [`thread_keep_alive`]: crate::runtime::Builder::thread_keep_alive() //! //! ``` //! #[tokio::main] @@ -239,7 +241,7 @@ //! [`std::io`]: std::io //! [`tokio::net`]: crate::net //! [TCP]: crate::net::tcp -//! [UDP]: crate::net::udp +//! [UDP]: crate::net::UdpSocket //! [UDS]: crate::net::unix //! [`tokio::fs`]: crate::fs //! [`std::fs`]: std::fs @@ -252,7 +254,7 @@ //! //! ```no_run //! use tokio::net::TcpListener; -//! use tokio::prelude::*; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; //! //! #[tokio::main] //! async fn main() -> Result<(), Box<dyn std::error::Error>> { @@ -312,7 +314,6 @@ //! - `process`: Enables `tokio::process` types. //! - `macros`: Enables `#[tokio::main]` and `#[tokio::test]` macros. //! - `sync`: Enables all `tokio::sync` types. -//! - `stream`: Enables optional `Stream` implementations for types within Tokio. //! - `signal`: Enables all `tokio::signal` types. //! - `fs`: Enables `tokio::fs` types. //! - `test-util`: Enables testing based infrastructure for the Tokio runtime. @@ -330,6 +331,15 @@ //! synchronization primitives internally. MSRV may increase according to the //! _parking_lot_ release in use. //! +//! ### Unstable features +//! +//! These feature flags enable **unstable** features. The public API may break in 1.x +//! releases. To enable these features, the `--cfg tokio_unstable` must be passed to +//! `rustc` when compiling. This is easiest done using the `RUSTFLAGS` env variable: +//! `RUSTFLAGS="--cfg tokio_unstable"`. +//! +//! - `tracing`: Enables tracing events. +//! //! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section // Includes re-exports used by macros. @@ -352,8 +362,6 @@ pub mod net; mod loom; mod park; -pub mod prelude; - cfg_process! { pub mod process; } @@ -378,10 +386,6 @@ cfg_signal_internal! { pub(crate) mod signal; } -cfg_stream! { - pub mod stream; -} - cfg_sync! { pub mod sync; } @@ -400,6 +404,45 @@ cfg_time! { mod util; +/// Due to the `Stream` trait's inclusion in `std` landing later than Tokio's 1.0 +/// release, most of the Tokio stream utilities have been moved into the [`tokio-stream`] +/// crate. +/// +/// # Why was `Stream` not included in Tokio 1.0? +/// +/// Originally, we had planned to ship Tokio 1.0 with a stable `Stream` type +/// but unfortunetly the [RFC] had not been merged in time for `Stream` to +/// reach `std` on a stable compiler in time for the 1.0 release of Tokio. For +/// this reason, the team has decided to move all `Stream` based utilities to +/// the [`tokio-stream`] crate. While this is not ideal, once `Stream` has made +/// it into the standard library and the MSRV period has passed, we will implement +/// stream for our different types. +/// +/// While this may seem unfortunate, not all is lost as you can get much of the +/// `Stream` support with `async/await` and `while let` loops. It is also possible +/// to create a `impl Stream` from `async fn` using the [`async-stream`] crate. +/// +/// [`tokio-stream`]: https://docs.rs/tokio-stream +/// [`async-stream`]: https://docs.rs/async-stream +/// [RFC]: https://github.com/rust-lang/rfcs/pull/2996 +/// +/// # Example +/// +/// Convert a [`sync::mpsc::Receiver`] to an `impl Stream`. +/// +/// ```rust,no_run +/// use tokio::sync::mpsc; +/// +/// let (tx, mut rx) = mpsc::channel::<usize>(16); +/// +/// let stream = async_stream::stream! { +/// while let Some(item) = rx.recv().await { +/// yield item; +/// } +/// }; +/// ``` +pub mod stream {} + cfg_macros! { /// Implementation detail of the `select!` macro. This macro is **not** /// intended to be used as part of the public API and is permitted to @@ -408,17 +451,14 @@ cfg_macros! { pub use tokio_macros::select_priv_declare_output_enum; cfg_rt! { - cfg_rt_multi_thread! { - // This is the docs.rs case (with all features) so make sure macros - // is included in doc(cfg). + #[cfg(feature = "rt-multi-thread")] + #[cfg(not(test))] // Work around for rust-lang/rust#62127 + #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + pub use tokio_macros::main; - #[cfg(not(test))] // Work around for rust-lang/rust#62127 - #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] - pub use tokio_macros::main; - - #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] - pub use tokio_macros::test; - } + #[cfg(feature = "rt-multi-thread")] + #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + pub use tokio_macros::test; cfg_not_rt_multi_thread! { #[cfg(not(test))] // Work around for rust-lang/rust#62127 diff --git a/src/loom/std/mod.rs b/src/loom/std/mod.rs index 414ef90..c3f74ef 100644 --- a/src/loom/std/mod.rs +++ b/src/loom/std/mod.rs @@ -47,7 +47,7 @@ pub(crate) mod rand { } pub(crate) mod sync { - pub(crate) use std::sync::Arc; + pub(crate) use std::sync::{Arc, Weak}; // Below, make sure all the feature-influenced types are exported for // internal use. Note however that some are not _currently_ named by diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs index 1521656..9ae098f 100644 --- a/src/macros/cfg.rs +++ b/src/macros/cfg.rs @@ -241,16 +241,6 @@ macro_rules! cfg_not_signal_internal { } } -macro_rules! cfg_stream { - ($($item:item)*) => { - $( - #[cfg(feature = "stream")] - #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] - $item - )* - } -} - macro_rules! cfg_sync { ($($item:item)*) => { $( @@ -334,7 +324,7 @@ macro_rules! cfg_not_time { macro_rules! cfg_trace { ($($item:item)*) => { $( - #[cfg(feature = "tracing")] + #[cfg(all(tokio_unstable, feature = "tracing"))] #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] $item )* @@ -344,7 +334,7 @@ macro_rules! cfg_trace { macro_rules! cfg_not_trace { ($($item:item)*) => { $( - #[cfg(not(feature = "tracing"))] + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] $item )* } @@ -361,7 +351,6 @@ macro_rules! cfg_coop { feature = "rt", feature = "signal", feature = "sync", - feature = "stream", feature = "time", ))] $item diff --git a/src/macros/pin.rs b/src/macros/pin.rs index ed844ef..a32187e 100644 --- a/src/macros/pin.rs +++ b/src/macros/pin.rs @@ -71,7 +71,7 @@ /// /// ``` /// use tokio::{pin, select}; -/// use tokio::stream::{self, StreamExt}; +/// use tokio_stream::{self as stream, StreamExt}; /// /// async fn my_async_fn() { /// // async logic here diff --git a/src/macros/select.rs b/src/macros/select.rs index b63abdd..ca4f963 100644 --- a/src/macros/select.rs +++ b/src/macros/select.rs @@ -76,7 +76,8 @@ /// /// #[tokio::main] /// async fn main() { -/// let mut sleep = time::sleep(Duration::from_millis(50)); +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); /// /// while !sleep.is_elapsed() { /// tokio::select! { @@ -109,7 +110,8 @@ /// /// #[tokio::main] /// async fn main() { -/// let mut sleep = time::sleep(Duration::from_millis(50)); +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); /// /// loop { /// tokio::select! { @@ -167,7 +169,7 @@ /// Basic stream selecting. /// /// ``` -/// use tokio::stream::{self, StreamExt}; +/// use tokio_stream::{self as stream, StreamExt}; /// /// #[tokio::main] /// async fn main() { @@ -188,7 +190,7 @@ /// is complete, all calls to `next()` return `None`. /// /// ``` -/// use tokio::stream::{self, StreamExt}; +/// use tokio_stream::{self as stream, StreamExt}; /// /// #[tokio::main] /// async fn main() { @@ -220,13 +222,14 @@ /// Here, a stream is consumed for at most 1 second. /// /// ``` -/// use tokio::stream::{self, StreamExt}; +/// use tokio_stream::{self as stream, StreamExt}; /// use tokio::time::{self, Duration}; /// /// #[tokio::main] /// async fn main() { /// let mut stream = stream::iter(vec![1, 2, 3]); -/// let mut sleep = time::sleep(Duration::from_secs(1)); +/// let sleep = time::sleep(Duration::from_secs(1)); +/// tokio::pin!(sleep); /// /// loop { /// tokio::select! { diff --git a/src/net/mod.rs b/src/net/mod.rs index b7365e6..2f17f9e 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -36,8 +36,8 @@ cfg_net! { pub use tcp::socket::TcpSocket; pub use tcp::stream::TcpStream; - pub mod udp; - pub use udp::socket::UdpSocket; + mod udp; + pub use udp::UdpSocket; } cfg_net_unix! { diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index 8b0a480..a2a8637 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -11,10 +11,8 @@ use std::task::{Context, Poll}; cfg_net! { /// A TCP socket server, listening for connections. /// - /// You can accept a new connection by using the [`accept`](`TcpListener::accept`) method. Alternatively `TcpListener` - /// implements the [`Stream`](`crate::stream::Stream`) trait, which allows you to use the listener in places that want a - /// stream. The stream will never return `None` and will also not yield the peer's `SocketAddr` structure. Iterating over - /// it is equivalent to calling accept in a loop. + /// You can accept a new connection by using the [`accept`](`TcpListener::accept`) + /// method. /// /// # Errors /// @@ -47,24 +45,6 @@ cfg_net! { /// } /// } /// ``` - /// - /// Using `impl Stream`: - /// ```no_run - /// use tokio::{net::TcpListener, stream::StreamExt}; - /// - /// #[tokio::main] - /// async fn main() { - /// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); - /// while let Some(stream) = listener.next().await { - /// match stream { - /// Ok(stream) => { - /// println!("new client!"); - /// } - /// Err(e) => { /* connection failed */ } - /// } - /// } - /// } - /// ``` pub struct TcpListener { io: PollEvented<mio::net::TcpListener>, } @@ -175,11 +155,9 @@ impl TcpListener { /// Polls to accept a new incoming connection to this listener. /// /// If there is no connection to accept, `Poll::Pending` is returned and the - /// current task will be notified by a waker. - /// - /// When ready, the most recent task that called `poll_accept` is notified. - /// The caller is responsible to ensure that `poll_accept` is called from a - /// single task. Failing to do this could result in tasks hanging. + /// current task will be notified by a waker. Note that on multiple calls + /// to `poll_accept`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> { loop { let ev = ready!(self.io.registration().poll_read_ready(cx))?; @@ -323,16 +301,6 @@ impl TcpListener { } } -#[cfg(feature = "stream")] -impl crate::stream::Stream for TcpListener { - type Item = io::Result<TcpStream>; - - fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let (socket, _) = ready!(self.poll_accept(cx))?; - Poll::Ready(Some(Ok(socket))) - } -} - impl TryFrom<net::TcpListener> for TcpListener { type Error = io::Error; diff --git a/src/net/tcp/split.rs b/src/net/tcp/split.rs index 28c94eb..78bd688 100644 --- a/src/net/tcp/split.rs +++ b/src/net/tcp/split.rs @@ -20,12 +20,11 @@ use std::task::{Context, Poll}; /// Borrowed read half of a [`TcpStream`], created by [`split`]. /// /// Reading from a `ReadHalf` is usually done using the convenience methods found on the -/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// [`AsyncReadExt`] trait. /// /// [`TcpStream`]: TcpStream /// [`split`]: TcpStream::split() /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct ReadHalf<'a>(&'a TcpStream); @@ -35,14 +34,13 @@ pub struct ReadHalf<'a>(&'a TcpStream); /// shut down the TCP stream in the write direction. /// /// Writing to an `WriteHalf` is usually done using the convenience methods found -/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. +/// on the [`AsyncWriteExt`] trait. /// /// [`TcpStream`]: TcpStream /// [`split`]: TcpStream::split() /// [`AsyncWrite`]: trait@crate::io::AsyncWrite /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct WriteHalf<'a>(&'a TcpStream); @@ -55,12 +53,16 @@ impl ReadHalf<'_> { /// the queue, registering the current task for wakeup if data is not yet /// available. /// + /// Note that on multiple calls to `poll_peek` or `poll_read`, only the + /// `Waker` from the `Context` passed to the most recent call is scheduled + /// to receive a wakeup. + /// /// See the [`TcpStream::poll_peek`] level documenation for more details. /// /// # Examples /// /// ```no_run - /// use tokio::io; + /// use tokio::io::{self, ReadBuf}; /// use tokio::net::TcpStream; /// /// use futures::future::poll_fn; @@ -70,6 +72,7 @@ impl ReadHalf<'_> { /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?; /// let (mut read_half, _) = stream.split(); /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); /// /// poll_fn(|cx| { /// read_half.poll_peek(cx, &mut buf) @@ -80,7 +83,11 @@ impl ReadHalf<'_> { /// ``` /// /// [`TcpStream::poll_peek`]: TcpStream::poll_peek - pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + pub fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { self.0.poll_peek(cx, buf) } @@ -96,7 +103,7 @@ impl ReadHalf<'_> { /// /// ```no_run /// use tokio::net::TcpStream; - /// use tokio::prelude::*; + /// use tokio::io::AsyncReadExt; /// use std::error::Error; /// /// #[tokio::main] @@ -124,7 +131,8 @@ impl ReadHalf<'_> { /// [`read`]: fn@crate::io::AsyncReadExt::read /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_peek(cx, buf)).await + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await } } @@ -167,7 +175,7 @@ impl AsyncWrite for WriteHalf<'_> { // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - self.0.shutdown(Shutdown::Write).into() + self.0.shutdown_std(Shutdown::Write).into() } } diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs index 8d77c8c..d52c2f6 100644 --- a/src/net/tcp/split_owned.rs +++ b/src/net/tcp/split_owned.rs @@ -22,12 +22,11 @@ use std::{fmt, io}; /// Owned read half of a [`TcpStream`], created by [`into_split`]. /// /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found -/// on the [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// on the [`AsyncReadExt`] trait. /// /// [`TcpStream`]: TcpStream /// [`into_split`]: TcpStream::into_split() /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct OwnedReadHalf { inner: Arc<TcpStream>, @@ -40,14 +39,13 @@ pub struct OwnedReadHalf { /// will also shut down the write half of the TCP stream. /// /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found -/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. +/// on the [`AsyncWriteExt`] trait. /// /// [`TcpStream`]: TcpStream /// [`into_split`]: TcpStream::into_split() /// [`AsyncWrite`]: trait@crate::io::AsyncWrite /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct OwnedWriteHalf { inner: Arc<TcpStream>, @@ -110,12 +108,16 @@ impl OwnedReadHalf { /// the queue, registering the current task for wakeup if data is not yet /// available. /// + /// Note that on multiple calls to `poll_peek` or `poll_read`, only the + /// `Waker` from the `Context` passed to the most recent call is scheduled + /// to receive a wakeup. + /// /// See the [`TcpStream::poll_peek`] level documenation for more details. /// /// # Examples /// /// ```no_run - /// use tokio::io; + /// use tokio::io::{self, ReadBuf}; /// use tokio::net::TcpStream; /// /// use futures::future::poll_fn; @@ -125,6 +127,7 @@ impl OwnedReadHalf { /// let stream = TcpStream::connect("127.0.0.1:8000").await?; /// let (mut read_half, _) = stream.into_split(); /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); /// /// poll_fn(|cx| { /// read_half.poll_peek(cx, &mut buf) @@ -135,7 +138,11 @@ impl OwnedReadHalf { /// ``` /// /// [`TcpStream::poll_peek`]: TcpStream::poll_peek - pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + pub fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { self.inner.poll_peek(cx, buf) } @@ -151,7 +158,7 @@ impl OwnedReadHalf { /// /// ```no_run /// use tokio::net::TcpStream; - /// use tokio::prelude::*; + /// use tokio::io::AsyncReadExt; /// use std::error::Error; /// /// #[tokio::main] @@ -179,7 +186,8 @@ impl OwnedReadHalf { /// [`read`]: fn@crate::io::AsyncReadExt::read /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { - poll_fn(|cx| self.poll_peek(cx, buf)).await + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await } } @@ -215,7 +223,7 @@ impl OwnedWriteHalf { impl Drop for OwnedWriteHalf { fn drop(&mut self) { if self.shutdown_on_drop { - let _ = self.inner.shutdown(Shutdown::Write); + let _ = self.inner.shutdown_std(Shutdown::Write); } } } @@ -249,7 +257,7 @@ impl AsyncWrite for OwnedWriteHalf { // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - let res = self.inner.shutdown(Shutdown::Write); + let res = self.inner.shutdown_std(Shutdown::Write); if res.is_ok() { Pin::into_inner(self).shutdown_on_drop = false; } diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 28118f7..d4bfba4 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -9,10 +9,10 @@ use std::fmt; use std::io; use std::net::{Shutdown, SocketAddr}; #[cfg(windows)] -use std::os::windows::io::{AsRawSocket, FromRawSocket}; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket}; #[cfg(unix)] -use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; @@ -25,20 +25,19 @@ cfg_net! { /// /// Reading and writing to a `TcpStream` is usually done using the /// convenience methods found on the [`AsyncReadExt`] and [`AsyncWriteExt`] - /// traits. Examples import these traits through [the prelude]. + /// traits. /// /// [`connect`]: method@TcpStream::connect /// [accepting]: method@crate::net::TcpListener::accept /// [listener]: struct@crate::net::TcpListener /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - /// [the prelude]: crate::prelude /// /// # Examples /// /// ```no_run /// use tokio::net::TcpStream; - /// use tokio::prelude::*; + /// use tokio::io::AsyncWriteExt; /// use std::error::Error; /// /// #[tokio::main] @@ -57,6 +56,13 @@ cfg_net! { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt + /// + /// To shut down the stream in the write direction, you can call the + /// [`shutdown()`] method. This will cause the other peer to receive a read of + /// length 0, indicating that no more data will be sent. This only closes + /// the stream in one direction. + /// + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown pub struct TcpStream { io: PollEvented<mio::net::TcpStream>, } @@ -81,7 +87,7 @@ impl TcpStream { /// /// ```no_run /// use tokio::net::TcpStream; - /// use tokio::prelude::*; + /// use tokio::io::AsyncWriteExt; /// use std::error::Error; /// /// #[tokio::main] @@ -184,6 +190,58 @@ impl TcpStream { Ok(TcpStream { io }) } + /// Turn a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`]. + /// + /// The returned [`std::net::TcpStream`] will have `nonblocking mode` set as `true`. + /// Use [`set_nonblocking`] to change the blocking mode if needed. + /// + /// # Examples + /// + /// ``` + /// use std::error::Error; + /// use std::io::Read; + /// use tokio::net::TcpListener; + /// # use tokio::net::TcpStream; + /// # use tokio::io::AsyncWriteExt; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let mut data = [0u8; 12]; + /// let listener = TcpListener::bind("127.0.0.1:34254").await?; + /// # let handle = tokio::spawn(async { + /// # let mut stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap(); + /// # stream.write(b"Hello world!").await.unwrap(); + /// # }); + /// let (tokio_tcp_stream, _) = listener.accept().await?; + /// let mut std_tcp_stream = tokio_tcp_stream.into_std()?; + /// # handle.await.expect("The task being joined has panicked"); + /// std_tcp_stream.set_nonblocking(false)?; + /// std_tcp_stream.read_exact(&mut data)?; + /// # assert_eq!(b"Hello world!", &data); + /// Ok(()) + /// } + /// ``` + /// [`tokio::net::TcpStream`]: TcpStream + /// [`std::net::TcpStream`]: std::net::TcpStream + /// [`set_nonblocking`]: fn@std::net::TcpStream::set_nonblocking + pub fn into_std(self) -> io::Result<std::net::TcpStream> { + #[cfg(unix)] + { + self.io + .into_inner() + .map(|io| io.into_raw_fd()) + .map(|raw_fd| unsafe { std::net::TcpStream::from_raw_fd(raw_fd) }) + } + + #[cfg(windows)] + { + self.io + .into_inner() + .map(|io| io.into_raw_socket()) + .map(|raw_socket| unsafe { std::net::TcpStream::from_raw_socket(raw_socket) }) + } + } + /// Returns the local address that this stream is bound to. /// /// # Examples @@ -224,6 +282,11 @@ impl TcpStream { /// the queue, registering the current task for wakeup if data is not yet /// available. /// + /// Note that on multiple calls to `poll_peek`, `poll_read` or + /// `poll_read_ready`, only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. (However, + /// `poll_write` retains a second, independent waker.) + /// /// # Return value /// /// The function returns: @@ -239,7 +302,7 @@ impl TcpStream { /// # Examples /// /// ```no_run - /// use tokio::io; + /// use tokio::io::{self, ReadBuf}; /// use tokio::net::TcpStream; /// /// use futures::future::poll_fn; @@ -248,6 +311,7 @@ impl TcpStream { /// async fn main() -> io::Result<()> { /// let stream = TcpStream::connect("127.0.0.1:8000").await?; /// let mut buf = [0; 10]; + /// let mut buf = ReadBuf::new(&mut buf); /// /// poll_fn(|cx| { /// stream.poll_peek(cx, &mut buf) @@ -256,12 +320,24 @@ impl TcpStream { /// Ok(()) /// } /// ``` - pub fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> { + pub fn poll_peek( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<usize>> { loop { let ev = ready!(self.io.registration().poll_read_ready(cx))?; - match self.io.peek(buf) { - Ok(ret) => return Poll::Ready(Ok(ret)), + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + match self.io.peek(b) { + Ok(ret) => { + unsafe { buf.assume_init(ret) }; + buf.advance(ret); + return Poll::Ready(Ok(ret)); + } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { self.io.registration().clear_readiness(ev); } @@ -285,6 +361,7 @@ impl TcpStream { /// use tokio::io::Interest; /// use tokio::net::TcpStream; /// use std::error::Error; + /// use std::io; /// /// #[tokio::main] /// async fn main() -> Result<(), Box<dyn Error>> { @@ -294,17 +371,37 @@ impl TcpStream { /// let ready = stream.ready(Interest::READABLE | Interest::WRITABLE).await?; /// /// if ready.is_readable() { - /// // The buffer is **not** included in the async task and will only exist - /// // on the stack. - /// let mut data = [0; 1024]; - /// let n = stream.try_read(&mut data[..]).unwrap(); + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } /// - /// println!("GOT {:?}", &data[..n]); /// } /// /// if ready.is_writable() { - /// // Write some data - /// stream.try_write(b"hello world").unwrap(); + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } /// } /// } /// } @@ -316,7 +413,7 @@ impl TcpStream { /// Wait for the socket to become readable. /// - /// This function is equivalent to `ready(Interest::READABLE)` is usually + /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_read()`. /// /// # Examples @@ -364,10 +461,32 @@ impl TcpStream { /// Polls for read readiness. /// + /// If the tcp stream is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the tcp + /// stream becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_read_ready`, `poll_read` or + /// `poll_peek`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. (However, + /// `poll_write_ready` retains a second, independent waker.) + /// /// This function is intended for cases where creating and pinning a future /// via [`readable`] is not feasible. Where possible, using [`readable`] is /// preferred, as this supports polling from multiple tasks at once. /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the tcp stream is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the tcp stream is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// /// [`readable`]: method@Self::readable pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.io.registration().poll_read_ready(cx).map_ok(|_| ()) @@ -389,7 +508,7 @@ impl TcpStream { /// # Return /// /// If data is successfully read, `Ok(n)` is returned, where `n` is the - /// number of bytes read. `Ok(n)` indicates the stream's read half is closed + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed /// and will no longer yield data. If the stream is not ready to read data /// `Err(io::ErrorKind::WouldBlock)` is returned. /// @@ -442,7 +561,7 @@ impl TcpStream { /// Wait for the socket to become writable. /// - /// This function is equivalent to `ready(Interest::WRITABLE)` is usually + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually /// paired with `try_write()`. /// /// # Examples @@ -486,10 +605,32 @@ impl TcpStream { /// Polls for write readiness. /// + /// If the tcp stream is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the tcp + /// stream becomes ready for writing, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// /// This function is intended for cases where creating and pinning a future /// via [`writable`] is not feasible. Where possible, using [`writable`] is /// preferred, as this supports polling from multiple tasks at once. /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the tcp stream is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the tcp stream is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// /// [`writable`]: method@Self::writable pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.io.registration().poll_write_ready(cx).map_ok(|_| ()) @@ -562,7 +703,7 @@ impl TcpStream { /// /// ```no_run /// use tokio::net::TcpStream; - /// use tokio::prelude::*; + /// use tokio::io::AsyncReadExt; /// use std::error::Error; /// /// #[tokio::main] @@ -600,26 +741,7 @@ impl TcpStream { /// This function will cause all pending and future I/O on the specified /// portions to return immediately with an appropriate value (see the /// documentation of `Shutdown`). - /// - /// # Examples - /// - /// ```no_run - /// use tokio::net::TcpStream; - /// use std::error::Error; - /// use std::net::Shutdown; - /// - /// #[tokio::main] - /// async fn main() -> Result<(), Box<dyn Error>> { - /// // Connect to a peer - /// let stream = TcpStream::connect("127.0.0.1:8080").await?; - /// - /// // Shutdown the stream - /// stream.shutdown(Shutdown::Write)?; - /// - /// Ok(()) - /// } - /// ``` - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> { self.io.shutdown(how) } @@ -797,25 +919,14 @@ impl TcpStream { /// this comes at the cost of a heap allocation. /// /// **Note:** Dropping the write half will shut down the write half of the TCP - /// stream. This is equivalent to calling [`shutdown(Write)`] on the `TcpStream`. + /// stream. This is equivalent to calling [`shutdown()`] on the `TcpStream`. /// /// [`split`]: TcpStream::split() - /// [`shutdown(Write)`]: fn@crate::net::TcpStream::shutdown + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { split_owned(self) } - // == Poll IO functions that takes `&self` == - // - // They are not public because (taken from the doc of `PollEvented`): - // - // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the - // caller must ensure that there are at most two tasks that use a - // `PollEvented` instance concurrently. One for reading and one for writing. - // While violating this requirement is "safe" from a Rust memory model point - // of view, it will result in unexpected behavior in the form of lost - // notifications and tasks hanging. - pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, @@ -894,7 +1005,7 @@ impl AsyncWrite for TcpStream { } fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - self.shutdown(std::net::Shutdown::Write)?; + self.shutdown_std(std::net::Shutdown::Write)?; Poll::Ready(Ok(())) } } diff --git a/src/net/udp/socket.rs b/src/net/udp.rs index a9c5c86..23abe98 100644 --- a/src/net/udp/socket.rs +++ b/src/net/udp.rs @@ -23,10 +23,11 @@ cfg_net! { /// /// # Streams /// - /// If you need to listen over UDP and produce a [`Stream`](`crate::stream::Stream`), you can look + /// If you need to listen over UDP and produce a [`Stream`], you can look /// at [`UdpFramed`]. /// /// [`UdpFramed`]: https://docs.rs/tokio-util/latest/tokio_util/udp/struct.UdpFramed.html + /// [`Stream`]: https://docs.rs/futures/0.3/futures/stream/trait.Stream.html /// /// # Example: one to many (bind) /// @@ -745,11 +746,11 @@ impl UdpSocket { &self, cx: &mut Context<'_>, buf: &[u8], - target: &SocketAddr, + target: SocketAddr, ) -> Poll<io::Result<usize>> { self.io .registration() - .poll_write_io(cx, || self.io.send_to(buf, *target)) + .poll_write_io(cx, || self.io.send_to(buf, target)) } /// Try to send data on the socket to the given address, but if the send is @@ -915,8 +916,8 @@ impl UdpSocket { /// /// // Try to recv data, this may still fail with `WouldBlock` /// // if the readiness event is a false positive. - /// match socket.try_recv(&mut buf) { - /// Ok(n) => { + /// match socket.try_recv_from(&mut buf) { + /// Ok((n, _addr)) => { /// println!("GOT {:?}", &buf[..n]); /// break; /// } diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs deleted file mode 100644 index c9bb0f8..0000000 --- a/src/net/udp/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! UDP utility types. - -pub(crate) mod socket; diff --git a/src/net/unix/datagram/socket.rs b/src/net/unix/datagram/socket.rs index f9e9321..fb5f602 100644 --- a/src/net/unix/datagram/socket.rs +++ b/src/net/unix/datagram/socket.rs @@ -1,4 +1,4 @@ -use crate::io::{Interest, PollEvented}; +use crate::io::{Interest, PollEvented, ReadBuf, Ready}; use crate::net::unix::SocketAddr; use std::convert::TryFrom; @@ -8,6 +8,7 @@ use std::net::Shutdown; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net; use std::path::Path; +use std::task::{Context, Poll}; cfg_net_unix! { /// An I/O object representing a Unix datagram socket. @@ -83,6 +84,178 @@ cfg_net_unix! { } impl UnixDatagram { + /// Wait for any of the requested ready states. + /// + /// This function is usually paired with `try_recv()` or `try_send()`. It + /// can be used to concurrently recv / send to the same socket on a single + /// task without splitting the socket. + /// + /// The function may complete without the socket being ready. This is a + /// false-positive and attempting an operation will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Examples + /// + /// Concurrently receive from and send to the socket on the same task + /// without splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// let ready = socket.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = [0; 1024]; + /// match socket.try_recv(&mut data[..]) { + /// Ok(n) => { + /// println!("received {:?}", &data[..n]); + /// } + /// // False-positive, continue + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Write some data + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// println!("sent {} bytes", n); + /// } + /// // False-positive, continue + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Wait for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is + /// usually paired with `try_send()` or `try_send_to()`. + /// + /// The function may complete without the socket being writable. This is a + /// false-positive and attempting a `try_send()` will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Wait for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_recv()`. + /// + /// The function may complete without the socket being readable. This is a + /// false-positive and attempting a `try_recv()` will return with + /// `io::ErrorKind::WouldBlock`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + /// Creates a new `UnixDatagram` bound to the specified path. /// /// # Examples @@ -309,68 +482,91 @@ impl UnixDatagram { /// Try to send a datagram to the peer without waiting. /// /// # Examples - /// ``` - /// # #[tokio::main] - /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { - /// use tokio::net::UnixDatagram; /// - /// let bytes = b"bytes"; - /// // We use a socket pair so that they are assigned - /// // each other as a peer. - /// let (first, second) = UnixDatagram::pair()?; - /// - /// let size = first.try_send(bytes)?; - /// assert_eq!(size, bytes.len()); - /// - /// let mut buffer = vec![0u8; 24]; - /// let size = second.try_recv(&mut buffer)?; - /// - /// let dgram = &buffer[..size]; - /// assert_eq!(dgram, bytes); - /// # Ok(()) - /// # } + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } /// ``` pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> { - self.io.send(buf) + self.io + .registration() + .try_io(Interest::WRITABLE, || self.io.send(buf)) } /// Try to send a datagram to the peer without waiting. /// /// # Examples - /// ``` - /// # #[tokio::main] - /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { - /// use tokio::net::UnixDatagram; - /// use tempfile::tempdir; - /// - /// let bytes = b"bytes"; - /// // We use a temporary directory so that the socket - /// // files left by the bound sockets will get cleaned up. - /// let tmp = tempdir().unwrap(); /// - /// let server_path = tmp.path().join("server"); - /// let server = UnixDatagram::bind(&server_path)?; - /// - /// let client_path = tmp.path().join("client"); - /// let client = UnixDatagram::bind(&client_path)?; - /// - /// let size = client.try_send_to(bytes, &server_path)?; - /// assert_eq!(size, bytes.len()); - /// - /// let mut buffer = vec![0u8; 24]; - /// let (size, addr) = server.try_recv_from(&mut buffer)?; - /// - /// let dgram = &buffer[..size]; - /// assert_eq!(dgram, bytes); - /// assert_eq!(addr.as_pathname().unwrap(), &client_path); - /// # Ok(()) - /// # } + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// + /// loop { + /// // Wait for the socket to be writable + /// socket.writable().await?; + /// + /// // Try to send data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_send_to(b"hello world", &server_path) { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } /// ``` pub fn try_send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize> where P: AsRef<Path>, { - self.io.send_to(buf, target) + self.io + .registration() + .try_io(Interest::WRITABLE, || self.io.send_to(buf, target)) } /// Receives data from the socket. @@ -409,29 +605,51 @@ impl UnixDatagram { /// Try to receive a datagram from the peer without waiting. /// /// # Examples - /// ``` - /// # #[tokio::main] - /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { - /// use tokio::net::UnixDatagram; - /// - /// let bytes = b"bytes"; - /// // We use a socket pair so that they are assigned - /// // each other as a peer. - /// let (first, second) = UnixDatagram::pair()?; - /// - /// let size = first.try_send(bytes)?; - /// assert_eq!(size, bytes.len()); - /// - /// let mut buffer = vec![0u8; 24]; - /// let size = second.try_recv(&mut buffer)?; /// - /// let dgram = &buffer[..size]; - /// assert_eq!(dgram, bytes); - /// # Ok(()) - /// # } + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// socket.connect(&server_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv(&mut buf) { + /// Ok(n) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } /// ``` pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> { - self.io.recv(buf) + self.io + .registration() + .try_io(Interest::READABLE, || self.io.recv(buf)) } /// Sends data on the socket to the specified address. @@ -520,40 +738,195 @@ impl UnixDatagram { Ok((n, SocketAddr(addr))) } - /// Try to receive data from the socket without waiting. + /// Attempts to receive a single datagram on the specified address. /// - /// # Examples - /// ``` - /// # #[tokio::main] - /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { - /// use tokio::net::UnixDatagram; - /// use tempfile::tempdir; + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. /// - /// let bytes = b"bytes"; - /// // We use a temporary directory so that the socket - /// // files left by the bound sockets will get cleaned up. - /// let tmp = tempdir().unwrap(); + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(addr))` reads data from `addr` into `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. /// - /// let server_path = tmp.path().join("server"); - /// let server = UnixDatagram::bind(&server_path)?; + /// # Errors /// - /// let client_path = tmp.path().join("client"); - /// let client = UnixDatagram::bind(&client_path)?; + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<SocketAddr>> { + let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.recv_from(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(SocketAddr(addr))) + } + + /// Attempts to send data to the specified address. /// - /// let size = client.try_send_to(bytes, &server_path)?; - /// assert_eq!(size, bytes.len()); + /// Note that on multiple calls to a `poll_*` method in the send direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. /// - /// let mut buffer = vec![0u8; 24]; - /// let (size, addr) = server.try_recv_from(&mut buffer)?; + /// # Return value /// - /// let dgram = &buffer[..size]; - /// assert_eq!(dgram, bytes); - /// assert_eq!(addr.as_pathname().unwrap(), &client_path); - /// # Ok(()) - /// # } + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + pub fn poll_send_to<P>( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: P, + ) -> Poll<io::Result<usize>> + where + P: AsRef<Path>, + { + self.io + .registration() + .poll_write_io(cx, || self.io.send_to(buf, target.as_ref())) + } + + /// Attempts to send data on the socket to the remote address to which it + /// was previously `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the send direction, + /// only the `Waker` from the `Context` passed to the most recent call will + /// be scheduled to receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not available to write + /// * `Poll::Ready(Ok(n))` `n` is the number of bytes sent + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> { + self.io + .registration() + .poll_write_io(cx, || self.io.send(buf)) + } + + /// Attempts to receive a single datagram message on the socket from the remote + /// address to which it is `connect`ed. + /// + /// The [`connect`] method will connect this socket to a remote address. This method + /// resolves to an error if the socket is not connected. + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready to read + /// * `Poll::Ready(Ok(()))` reads data `ReadBuf` if the socket is ready + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`connect`]: method@Self::connect + pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { + let n = ready!(self.io.registration().poll_read_io(cx, || { + // Safety: will not read the maybe uinitialized bytes. + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) + }; + + self.io.recv(b) + }))?; + + // Safety: We trust `recv` to have filled up `n` bytes in the buffer. + unsafe { + buf.assume_init(n); + } + buf.advance(n); + Poll::Ready(Ok(())) + } + + /// Try to receive data from the socket without waiting. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixDatagram; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let client_path = dir.path().join("client.sock"); + /// let server_path = dir.path().join("server.sock"); + /// let socket = UnixDatagram::bind(&client_path)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// socket.readable().await?; + /// + /// // The buffer is **not** included in the async task and will + /// // only exist on the stack. + /// let mut buf = [0; 1024]; + /// + /// // Try to recv data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match socket.try_recv_from(&mut buf) { + /// Ok((n, _addr)) => { + /// println!("GOT {:?}", &buf[..n]); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e); + /// } + /// } + /// } + /// + /// Ok(()) + /// } /// ``` pub fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - let (n, addr) = self.io.recv_from(buf)?; + let (n, addr) = self + .io + .registration() + .try_io(Interest::READABLE, || self.io.recv_from(buf))?; + Ok((n, SocketAddr(addr))) } diff --git a/src/net/unix/listener.rs b/src/net/unix/listener.rs index b1da0e3..9ed4ce1 100644 --- a/src/net/unix/listener.rs +++ b/src/net/unix/listener.rs @@ -12,10 +12,7 @@ use std::task::{Context, Poll}; cfg_net_unix! { /// A Unix socket which can accept connections from other Unix sockets. /// - /// You can accept a new connection by using the [`accept`](`UnixListener::accept`) method. Alternatively `UnixListener` - /// implements the [`Stream`](`crate::stream::Stream`) trait, which allows you to use the listener in places that want a - /// stream. The stream will never return `None` and will also not yield the peer's `SocketAddr` structure. Iterating over - /// it is equivalent to calling accept in a loop. + /// You can accept a new connection by using the [`accept`](`UnixListener::accept`) method. /// /// # Errors /// @@ -29,14 +26,13 @@ cfg_net_unix! { /// /// ```no_run /// use tokio::net::UnixListener; - /// use tokio::stream::StreamExt; /// /// #[tokio::main] /// async fn main() { - /// let mut listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// while let Some(stream) = listener.next().await { - /// match stream { - /// Ok(stream) => { + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// loop { + /// match listener.accept().await { + /// Ok((stream, _addr)) => { /// println!("new client!"); /// } /// Err(e) => { /* connection failed */ } @@ -113,12 +109,10 @@ impl UnixListener { /// Polls to accept a new incoming connection to this listener. /// - /// If there is no connection to accept, `Poll::Pending` is returned and - /// the current task will be notified by a waker. - /// - /// When ready, the most recent task that called `poll_accept` is notified. - /// The caller is responsible to ensure that `poll_accept` is called from a - /// single task. Failing to do this could result in tasks hanging. + /// If there is no connection to accept, `Poll::Pending` is returned and the + /// current task will be notified by a waker. Note that on multiple calls + /// to `poll_accept`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UnixStream, SocketAddr)>> { let (sock, addr) = ready!(self.io.registration().poll_read_io(cx, || self.io.accept()))?; let addr = SocketAddr(addr); @@ -127,16 +121,6 @@ impl UnixListener { } } -#[cfg(feature = "stream")] -impl crate::stream::Stream for UnixListener { - type Item = io::Result<UnixStream>; - - fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let (socket, _) = ready!(self.poll_accept(cx))?; - Poll::Ready(Some(Ok(socket))) - } -} - impl TryFrom<std::os::unix::net::UnixListener> for UnixListener { type Error = io::Error; diff --git a/src/net/unix/split.rs b/src/net/unix/split.rs index af9c762..24a711b 100644 --- a/src/net/unix/split.rs +++ b/src/net/unix/split.rs @@ -19,12 +19,11 @@ use std::task::{Context, Poll}; /// Borrowed read half of a [`UnixStream`], created by [`split`]. /// /// Reading from a `ReadHalf` is usually done using the convenience methods found on the -/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// [`AsyncReadExt`] trait. /// /// [`UnixStream`]: UnixStream /// [`split`]: UnixStream::split() /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct ReadHalf<'a>(&'a UnixStream); @@ -34,14 +33,13 @@ pub struct ReadHalf<'a>(&'a UnixStream); /// shut down the UnixStream stream in the write direction. /// /// Writing to an `WriteHalf` is usually done using the convenience methods found -/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude]. +/// on the [`AsyncWriteExt`] trait. /// /// [`UnixStream`]: UnixStream /// [`split`]: UnixStream::split() /// [`AsyncWrite`]: trait@crate::io::AsyncWrite /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct WriteHalf<'a>(&'a UnixStream); @@ -85,7 +83,7 @@ impl AsyncWrite for WriteHalf<'_> { } fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - self.0.shutdown(Shutdown::Write).into() + self.0.shutdown_std(Shutdown::Write).into() } } diff --git a/src/net/unix/split_owned.rs b/src/net/unix/split_owned.rs index 5f0a259..3d6ac6a 100644 --- a/src/net/unix/split_owned.rs +++ b/src/net/unix/split_owned.rs @@ -21,12 +21,11 @@ use std::{fmt, io}; /// Owned read half of a [`UnixStream`], created by [`into_split`]. /// /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found -/// on the [`AsyncReadExt`] trait. Examples import this trait through [the prelude]. +/// on the [`AsyncReadExt`] trait. /// /// [`UnixStream`]: crate::net::UnixStream /// [`into_split`]: crate::net::UnixStream::into_split() /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct OwnedReadHalf { inner: Arc<UnixStream>, @@ -39,15 +38,13 @@ pub struct OwnedReadHalf { /// Dropping the write half will also shut down the write half of the stream. /// /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods -/// found on the [`AsyncWriteExt`] trait. Examples import this trait through -/// [the prelude]. +/// found on the [`AsyncWriteExt`] trait. /// /// [`UnixStream`]: crate::net::UnixStream /// [`into_split`]: crate::net::UnixStream::into_split() /// [`AsyncWrite`]: trait@crate::io::AsyncWrite /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt -/// [the prelude]: crate::prelude #[derive(Debug)] pub struct OwnedWriteHalf { inner: Arc<UnixStream>, @@ -139,7 +136,7 @@ impl OwnedWriteHalf { impl Drop for OwnedWriteHalf { fn drop(&mut self) { if self.shutdown_on_drop { - let _ = self.inner.shutdown(Shutdown::Write); + let _ = self.inner.shutdown_std(Shutdown::Write); } } } @@ -173,7 +170,7 @@ impl AsyncWrite for OwnedWriteHalf { // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - let res = self.inner.shutdown(Shutdown::Write); + let res = self.inner.shutdown_std(Shutdown::Write); if res.is_ok() { Pin::into_inner(self).shutdown_on_drop = false; } diff --git a/src/net/unix/stream.rs b/src/net/unix/stream.rs index f961994..dc929dc 100644 --- a/src/net/unix/stream.rs +++ b/src/net/unix/stream.rs @@ -1,5 +1,5 @@ use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf}; +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; use crate::net::unix::split::{split, ReadHalf, WriteHalf}; use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf}; use crate::net::unix::ucred::{self, UCred}; @@ -7,7 +7,7 @@ use crate::net::unix::SocketAddr; use std::convert::TryFrom; use std::fmt; -use std::io; +use std::io::{self, Read, Write}; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net; @@ -21,6 +21,13 @@ cfg_net_unix! { /// This socket can be connected directly with `UnixStream::connect` or accepted /// from a listener with `UnixListener::incoming`. Additionally, a pair of /// anonymous Unix sockets can be created with `UnixStream::pair`. + /// + /// To shut down the stream in the write direction, you can call the + /// [`shutdown()`] method. This will cause the other peer to receive a read of + /// length 0, indicating that no more data will be sent. This only closes + /// the stream in one direction. + /// + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown pub struct UnixStream { io: PollEvented<mio::net::UnixStream>, } @@ -43,6 +50,358 @@ impl UnixStream { Ok(stream) } + /// Wait for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Examples + /// + /// Concurrently read and write to the stream on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// let ready = stream.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Wait for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the unix stream is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the unix + /// stream becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the unix stream is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the unix stream is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Try to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`ready()`]: UnixStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Wait for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the unix stream is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the unix + /// stream becomes ready for writing, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the unix stream is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the unix stream is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Try to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`. /// /// This function is intended to be used to wrap a UnixStream from the @@ -107,7 +466,7 @@ impl UnixStream { /// This function will cause all pending and future I/O calls on the /// specified portions to immediately return with an appropriate value /// (see the documentation of `Shutdown`). - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> { self.io.shutdown(how) } @@ -132,10 +491,10 @@ impl UnixStream { /// this comes at the cost of a heap allocation. /// /// **Note:** Dropping the write half will shut down the write half of the - /// stream. This is equivalent to calling [`shutdown(Write)`] on the `UnixStream`. + /// stream. This is equivalent to calling [`shutdown()`] on the `UnixStream`. /// /// [`split`]: Self::split() - /// [`shutdown(Write)`]: fn@Self::shutdown + /// [`shutdown()`]: fn@crate::io::AsyncWriteExt::shutdown pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { split_owned(self) } @@ -189,7 +548,7 @@ impl AsyncWrite for UnixStream { } fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { - self.shutdown(std::net::Shutdown::Write)?; + self.shutdown_std(std::net::Shutdown::Write)?; Poll::Ready(Ok(())) } } diff --git a/src/prelude.rs b/src/prelude.rs deleted file mode 100644 index 1909f9d..0000000 --- a/src/prelude.rs +++ /dev/null @@ -1,21 +0,0 @@ -#![cfg(not(loom))] - -//! A "prelude" for users of the `tokio` crate. -//! -//! This prelude is similar to the standard library's prelude in that you'll -//! almost always want to import its entire contents, but unlike the standard -//! library's prelude you'll have to do so manually: -//! -//! ``` -//! # #![allow(warnings)] -//! use tokio::prelude::*; -//! ``` -//! -//! The prelude may grow over time as additional items see ubiquitous use. - -pub use crate::io::{self, AsyncBufRead, AsyncRead, AsyncWrite}; - -cfg_io_util! { - #[doc(no_inline)] - pub use crate::io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _}; -} diff --git a/src/process/unix/mod.rs b/src/process/unix/mod.rs index 966c2a2..3608b9f 100644 --- a/src/process/unix/mod.rs +++ b/src/process/unix/mod.rs @@ -36,6 +36,7 @@ use crate::signal::unix::{signal, Signal, SignalKind}; use mio::event::Source; use mio::unix::SourceFd; +use once_cell::sync::Lazy; use std::fmt; use std::fs::File; use std::future::Future; @@ -62,9 +63,7 @@ impl Kill for StdChild { } } -lazy_static::lazy_static! { - static ref ORPHAN_QUEUE: OrphanQueueImpl<StdChild> = OrphanQueueImpl::new(); -} +static ORPHAN_QUEUE: Lazy<OrphanQueueImpl<StdChild>> = Lazy::new(OrphanQueueImpl::new); pub(crate) struct GlobalOrphanQueue; diff --git a/src/runtime/blocking/shutdown.rs b/src/runtime/blocking/shutdown.rs index 3b6cc59..0cf2285 100644 --- a/src/runtime/blocking/shutdown.rs +++ b/src/runtime/blocking/shutdown.rs @@ -38,7 +38,7 @@ impl Receiver { use crate::runtime::enter::try_enter; if timeout == Some(Duration::from_nanos(0)) { - return true; + return false; } let mut e = match try_enter(false) { diff --git a/src/runtime/builder.rs b/src/runtime/builder.rs index e792c7d..1f8892e 100644 --- a/src/runtime/builder.rs +++ b/src/runtime/builder.rs @@ -53,7 +53,7 @@ pub struct Builder { worker_threads: Option<usize>, /// Cap on thread usage. - max_threads: usize, + max_blocking_threads: usize, /// Name fn used for threads spawned by the runtime. pub(super) thread_name: ThreadNameFn, @@ -113,7 +113,7 @@ impl Builder { // Default to lazy auto-detection (one thread per CPU core) worker_threads: None, - max_threads: 512, + max_blocking_threads: 512, // Default thread name thread_name: std::sync::Arc::new(|| "tokio-runtime-worker".into()), @@ -209,22 +209,22 @@ impl Builder { self } - /// Specifies limit for threads, spawned by the Runtime. + /// Specifies limit for threads spawned by the Runtime used for blocking operations. /// - /// This is number of threads to be used by Runtime, including `core_threads` - /// Having `max_threads` less than `worker_threads` results in invalid configuration - /// when building multi-threaded `Runtime`, which would cause a panic. /// - /// Similarly to the `worker_threads`, this number should be between 0 and 32,768. + /// Similarly to the `worker_threads`, this number should be between 1 and 32,768. /// /// The default value is 512. /// - /// When multi-threaded runtime is not used, will act as limit on additional threads. + /// Otherwise as `worker_threads` are always active, it limits additional threads (e.g. for + /// blocking annotations). /// - /// Otherwise as `core_threads` are always active, it limits additional threads (e.g. for - /// blocking annotations) as `max_threads - core_threads`. - pub fn max_threads(&mut self, val: usize) -> &mut Self { - self.max_threads = val; + /// # Panic + /// + /// This will panic if `val` is not larger than `0`. + pub fn max_blocking_threads(&mut self, val: usize) -> &mut Self { + assert!(val > 0, "Max blocking threads cannot be set to 0"); + self.max_blocking_threads = val; self } @@ -379,6 +379,11 @@ impl Builder { fn get_cfg(&self) -> driver::Cfg { driver::Cfg { + enable_pause_time: match self.kind { + Kind::CurrentThread => true, + #[cfg(feature = "rt-multi-thread")] + Kind::MultiThread => false, + }, enable_io: self.enable_io, enable_time: self.enable_time, } @@ -419,7 +424,7 @@ impl Builder { let spawner = Spawner::Basic(scheduler.spawner().clone()); // Blocking pool - let blocking_pool = blocking::create_blocking_pool(self, self.max_threads); + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); let blocking_spawner = blocking_pool.spawner().clone(); Ok(Runtime { @@ -490,10 +495,8 @@ cfg_rt_multi_thread! { use crate::loom::sys::num_cpus; use crate::runtime::{Kind, ThreadPool}; use crate::runtime::park::Parker; - use std::cmp; - let core_threads = self.worker_threads.unwrap_or_else(|| cmp::min(self.max_threads, num_cpus())); - assert!(core_threads <= self.max_threads, "Core threads number cannot be above max limit"); + let core_threads = self.worker_threads.unwrap_or_else(num_cpus); let (driver, resources) = driver::Driver::new(self.get_cfg())?; @@ -501,7 +504,7 @@ cfg_rt_multi_thread! { let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); // Create the blocking pool - let blocking_pool = blocking::create_blocking_pool(self, self.max_threads); + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads + core_threads); let blocking_spawner = blocking_pool.spawner().clone(); // Create the runtime handle @@ -531,7 +534,7 @@ impl fmt::Debug for Builder { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Builder") .field("worker_threads", &self.worker_threads) - .field("max_threads", &self.max_threads) + .field("max_blocking_threads", &self.max_blocking_threads) .field( "thread_name", &"<dyn Fn() -> String + Send + Sync + 'static>", diff --git a/src/runtime/driver.rs b/src/runtime/driver.rs index e89de9d..b89fa4f 100644 --- a/src/runtime/driver.rs +++ b/src/runtime/driver.rs @@ -103,8 +103,8 @@ cfg_time! { pub(crate) type Clock = crate::time::Clock; pub(crate) type TimeHandle = Option<crate::time::driver::Handle>; - fn create_clock() -> Clock { - crate::time::Clock::new() + fn create_clock(enable_pausing: bool) -> Clock { + crate::time::Clock::new(enable_pausing) } fn create_time_driver( @@ -131,7 +131,7 @@ cfg_not_time! { pub(crate) type Clock = (); pub(crate) type TimeHandle = (); - fn create_clock() -> Clock { + fn create_clock(_enable_pausing: bool) -> Clock { () } @@ -161,13 +161,14 @@ pub(crate) struct Resources { pub(crate) struct Cfg { pub(crate) enable_io: bool, pub(crate) enable_time: bool, + pub(crate) enable_pause_time: bool, } impl Driver { pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Resources)> { let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io)?; - let clock = create_clock(); + let clock = create_clock(cfg.enable_pause_time); let (time_driver, time_handle) = create_time_driver(cfg.enable_time, io_stack, clock.clone()); diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs index 138d13b..6ff3c39 100644 --- a/src/runtime/handle.rs +++ b/src/runtime/handle.rs @@ -142,7 +142,7 @@ impl Handle { F: Future + Send + 'static, F::Output: Send + 'static, { - #[cfg(feature = "tracing")] + #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task"); self.spawner.spawn(future) } @@ -172,7 +172,7 @@ impl Handle { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { - #[cfg(feature = "tracing")] + #[cfg(all(tokio_unstable, feature = "tracing"))] let func = { #[cfg(tokio_track_caller)] let location = std::panic::Location::caller(); diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index d7f068e..2c90acb 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -20,7 +20,7 @@ //! //! ```no_run //! use tokio::net::TcpListener; -//! use tokio::prelude::*; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; //! //! #[tokio::main] //! async fn main() -> Result<(), Box<dyn std::error::Error>> { @@ -63,7 +63,7 @@ //! //! ```no_run //! use tokio::net::TcpListener; -//! use tokio::prelude::*; +//! use tokio::io::{AsyncReadExt, AsyncWriteExt}; //! use tokio::runtime::Runtime; //! //! fn main() -> Result<(), Box<dyn std::error::Error>> { diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs index bc544c9..31712e4 100644 --- a/src/runtime/thread_pool/worker.rs +++ b/src/runtime/thread_pool/worker.rs @@ -78,11 +78,12 @@ pub(super) struct Shared { /// Coordinates idle workers idle: Idle, - /// Workers have have observed the shutdown signal + /// Cores that have observed the shutdown signal /// /// The core is **not** placed back in the worker to avoid it from being /// stolen by a thread that was spawned as part of `block_in_place`. - shutdown_workers: Mutex<Vec<(Box<Core>, Arc<Worker>)>>, + #[allow(clippy::vec_box)] // we're moving an already-boxed value + shutdown_cores: Mutex<Vec<Box<Core>>>, } /// Used to communicate with a worker from other threads. @@ -157,7 +158,7 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc<Shared>, Launch) { remotes: remotes.into_boxed_slice(), inject: queue::Inject::new(), idle: Idle::new(size), - shutdown_workers: Mutex::new(vec![]), + shutdown_cores: Mutex::new(vec![]), }); let mut launch = Launch(vec![]); @@ -328,8 +329,10 @@ impl Context { } } + core.pre_shutdown(&self.worker); + // Signal shutdown - self.worker.shared.shutdown(core, self.worker.clone()); + self.worker.shared.shutdown(core); Err(()) } @@ -546,11 +549,9 @@ impl Core { } } - // Shutdown the core - fn shutdown(&mut self, worker: &Worker) { - // Take the core - let mut park = self.park.take().expect("park missing"); - + // Signals all tasks to shut down, and waits for them to complete. Must run + // before we enter the single-threaded phase of shutdown processing. + fn pre_shutdown(&mut self, worker: &Worker) { // Signal to all tasks to shut down. for header in self.tasks.iter() { header.shutdown(); @@ -564,8 +565,17 @@ impl Core { } // Wait until signalled + let park = self.park.as_mut().expect("park missing"); park.park().expect("park failed"); } + } + + // Shutdown the core + fn shutdown(&mut self) { + assert!(self.tasks.is_empty()); + + // Take the core + let mut park = self.park.take().expect("park missing"); // Drain the queue while self.next_local_task().is_some() {} @@ -629,52 +639,73 @@ impl task::Schedule for Arc<Worker> { fn release(&self, task: &Task) -> Option<Task> { use std::ptr::NonNull; - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); + enum Immediate { + // Task has been synchronously removed from the Core owned by the + // current thread + Removed(Option<Task>), + // Task is owned by another thread, so we need to notify it to clean + // up the task later. + MaybeRemote, + } - if self.eq(&cx.worker) { - let mut maybe_core = cx.core.borrow_mut(); + let immediate = CURRENT.with(|maybe_cx| { + let cx = match maybe_cx { + Some(cx) => cx, + None => return Immediate::MaybeRemote, + }; - if let Some(core) = &mut *maybe_core { - // Directly remove the task - // - // safety: the task is inserted in the list in `bind`. - unsafe { - let ptr = NonNull::from(task.header()); - return core.tasks.remove(ptr); - } - } + if !self.eq(&cx.worker) { + // Task owned by another core, so we need to notify it. + return Immediate::MaybeRemote; } - // Track the task to be released by the worker that owns it - // - // Safety: We get a new handle without incrementing the ref-count. - // A ref-count is held by the "owned" linked list and it is only - // ever removed from that list as part of the release process: this - // method or popping the task from `pending_drop`. Thus, we can rely - // on the ref-count held by the linked-list to keep the memory - // alive. - // - // When the task is removed from the stack, it is forgotten instead - // of dropped. - let task = unsafe { Task::from_raw(task.header().into()) }; - - self.remote().pending_drop.push(task); + let mut maybe_core = cx.core.borrow_mut(); - if cx.core.borrow().is_some() { - return None; + if let Some(core) = &mut *maybe_core { + // Directly remove the task + // + // safety: the task is inserted in the list in `bind`. + unsafe { + let ptr = NonNull::from(task.header()); + return Immediate::Removed(core.tasks.remove(ptr)); + } } - // The worker core has been handed off to another thread. In the - // event that the scheduler is currently shutting down, the thread - // that owns the task may be waiting on the release to complete - // shutdown. - if self.inject().is_closed() { - self.remote().unpark.unpark(); - } + Immediate::MaybeRemote + }); - None - }) + // Checks if we were called from within a worker, allowing for immediate + // removal of a scheduled task. Else we have to go through the slower + // process below where we remotely mark a task as dropped. + match immediate { + Immediate::Removed(task) => return task, + Immediate::MaybeRemote => (), + }; + + // Track the task to be released by the worker that owns it + // + // Safety: We get a new handle without incrementing the ref-count. + // A ref-count is held by the "owned" linked list and it is only + // ever removed from that list as part of the release process: this + // method or popping the task from `pending_drop`. Thus, we can rely + // on the ref-count held by the linked-list to keep the memory + // alive. + // + // When the task is removed from the stack, it is forgotten instead + // of dropped. + let task = unsafe { Task::from_raw(task.header().into()) }; + + self.remote().pending_drop.push(task); + + // The worker core has been handed off to another thread. In the + // event that the scheduler is currently shutting down, the thread + // that owns the task may be waiting on the release to complete + // shutdown. + if self.inject().is_closed() { + self.remote().unpark.unpark(); + } + + None } fn schedule(&self, task: Notified) { @@ -779,16 +810,16 @@ impl Shared { /// its core back into its handle. /// /// If all workers have reached this point, the final cleanup is performed. - fn shutdown(&self, core: Box<Core>, worker: Arc<Worker>) { - let mut workers = self.shutdown_workers.lock(); - workers.push((core, worker)); + fn shutdown(&self, core: Box<Core>) { + let mut cores = self.shutdown_cores.lock(); + cores.push(core); - if workers.len() != self.remotes.len() { + if cores.len() != self.remotes.len() { return; } - for (mut core, worker) in workers.drain(..) { - core.shutdown(&worker); + for mut core in cores.drain(..) { + core.shutdown(); } // Drain the injection queue diff --git a/src/signal/mod.rs b/src/signal/mod.rs index 6e5e350..d347e6e 100644 --- a/src/signal/mod.rs +++ b/src/signal/mod.rs @@ -5,7 +5,7 @@ //! signal handling, but it should be evaluated for your own applications' needs //! to see if it's suitable. //! -//! The are some fundamental limitations of this crate documented on the OS +//! There are some fundamental limitations of this crate documented on the OS //! specific structures, as well. //! //! # Examples diff --git a/src/signal/registry.rs b/src/signal/registry.rs index 5d6f608..55ee8c5 100644 --- a/src/signal/registry.rs +++ b/src/signal/registry.rs @@ -4,7 +4,7 @@ use crate::signal::os::{OsExtraData, OsStorage}; use crate::sync::mpsc::Sender; -use lazy_static::lazy_static; +use once_cell::sync::Lazy; use std::ops; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; @@ -165,12 +165,12 @@ where OsExtraData: 'static + Send + Sync + Init, OsStorage: 'static + Send + Sync + Init, { - lazy_static! { - static ref GLOBALS: Pin<Box<Globals>> = Box::pin(Globals { + static GLOBALS: Lazy<Pin<Box<Globals>>> = Lazy::new(|| { + Box::pin(Globals { extra: OsExtraData::init(), registry: Registry::new(OsStorage::init()), - }); - } + }) + }); GLOBALS.as_ref() } diff --git a/src/signal/unix.rs b/src/signal/unix.rs index aaaa75e..fc0f16d 100644 --- a/src/signal/unix.rs +++ b/src/signal/unix.rs @@ -407,16 +407,6 @@ impl Signal { } } -cfg_stream! { - impl crate::stream::Stream for Signal { - type Item = (); - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>> { - self.poll_recv(cx) - } - } -} - // Work around for abstracting streams internally pub(crate) trait InternalStream: Unpin { fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>; diff --git a/src/signal/windows.rs b/src/signal/windows.rs index 1e78362..43af290 100644 --- a/src/signal/windows.rs +++ b/src/signal/windows.rs @@ -1,9 +1,9 @@ //! Windows-specific types for signal handling. //! -//! This module is only defined on Windows and contains the primary `Event` type -//! for receiving notifications of events. These events are listened for via the +//! This module is only defined on Windows and allows receiving "ctrl-c" +//! and "ctrl-break" notifications. These events are listened for via the //! `SetConsoleCtrlHandler` function which receives events of the type -//! `CTRL_C_EVENT` and `CTRL_BREAK_EVENT` +//! `CTRL_C_EVENT` and `CTRL_BREAK_EVENT`. #![cfg(windows)] @@ -79,10 +79,6 @@ pub(crate) struct Event { rx: Receiver<()>, } -pub(crate) fn ctrl_c() -> io::Result<Event> { - Event::new(CTRL_C_EVENT) -} - impl Event { fn new(signum: DWORD) -> io::Result<Self> { global_init()?; @@ -135,6 +131,106 @@ unsafe extern "system" fn handler(ty: DWORD) -> BOOL { } } +/// Creates a new stream which receives "ctrl-c" notifications sent to the +/// process. +/// +/// # Examples +/// +/// ```rust,no_run +/// use tokio::signal::windows::ctrl_c; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box<dyn std::error::Error>> { +/// // An infinite stream of CTRL-C events. +/// let mut stream = ctrl_c()?; +/// +/// // Print whenever a CTRL-C event is received. +/// for countdown in (0..3).rev() { +/// stream.recv().await; +/// println!("got CTRL-C. {} more to exit", countdown); +/// } +/// +/// Ok(()) +/// } +/// ``` +pub fn ctrl_c() -> io::Result<CtrlC> { + Event::new(CTRL_C_EVENT).map(|inner| CtrlC { inner }) +} + +/// Represents a stream which receives "ctrl-c" notifications sent to the process +/// via `SetConsoleCtrlHandler`. +/// +/// A notification to this process notifies *all* streams listening for +/// this event. Moreover, the notifications **are coalesced** if they aren't processed +/// quickly enough. This means that if two notifications are received back-to-back, +/// then the stream may only receive one item about the two notifications. +#[must_use = "streams do nothing unless polled"] +#[derive(Debug)] +pub struct CtrlC { + inner: Event, +} + +impl CtrlC { + /// Receives the next signal notification event. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// ```rust,no_run + /// use tokio::signal::windows::ctrl_c; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn std::error::Error>> { + /// // An infinite stream of CTRL-C events. + /// let mut stream = ctrl_c()?; + /// + /// // Print whenever a CTRL-C event is received. + /// for countdown in (0..3).rev() { + /// stream.recv().await; + /// println!("got CTRL-C. {} more to exit", countdown); + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn recv(&mut self) -> Option<()> { + self.inner.recv().await + } + + /// Polls to receive the next signal notification event, outside of an + /// `async` context. + /// + /// `None` is returned if no more events can be received by this stream. + /// + /// # Examples + /// + /// Polling from a manually implemented future + /// + /// ```rust,no_run + /// use std::pin::Pin; + /// use std::future::Future; + /// use std::task::{Context, Poll}; + /// use tokio::signal::windows::CtrlC; + /// + /// struct MyFuture { + /// ctrl_c: CtrlC, + /// } + /// + /// impl Future for MyFuture { + /// type Output = Option<()>; + /// + /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + /// println!("polling MyFuture"); + /// self.ctrl_c.poll_recv(cx) + /// } + /// } + /// ``` + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { + self.inner.rx.poll_recv(cx) + } +} + /// Represents a stream which receives "ctrl-break" notifications sent to the process /// via `SetConsoleCtrlHandler`. /// @@ -163,7 +259,7 @@ impl CtrlBreak { /// // An infinite stream of CTRL-BREAK events. /// let mut stream = ctrl_break()?; /// - /// // Print whenever a CTRL-BREAK event is received + /// // Print whenever a CTRL-BREAK event is received. /// loop { /// stream.recv().await; /// println!("got signal CTRL-BREAK"); @@ -171,8 +267,7 @@ impl CtrlBreak { /// } /// ``` pub async fn recv(&mut self) -> Option<()> { - use crate::future::poll_fn; - poll_fn(|cx| self.poll_recv(cx)).await + self.inner.recv().await } /// Polls to receive the next signal notification event, outside of an @@ -208,16 +303,6 @@ impl CtrlBreak { } } -cfg_stream! { - impl crate::stream::Stream for CtrlBreak { - type Item = (); - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>> { - self.poll_recv(cx) - } - } -} - /// Creates a new stream which receives "ctrl-break" notifications sent to the /// process. /// @@ -231,7 +316,7 @@ cfg_stream! { /// // An infinite stream of CTRL-BREAK events. /// let mut stream = ctrl_break()?; /// -/// // Print whenever a CTRL-BREAK event is received +/// // Print whenever a CTRL-BREAK event is received. /// loop { /// stream.recv().await; /// println!("got signal CTRL-BREAK"); @@ -246,7 +331,6 @@ pub fn ctrl_break() -> io::Result<CtrlBreak> { mod tests { use super::*; use crate::runtime::Runtime; - use crate::stream::StreamExt; use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; @@ -283,7 +367,7 @@ mod tests { super::handler(CTRL_BREAK_EVENT); } - ctrl_break.next().await.unwrap(); + ctrl_break.recv().await.unwrap(); }); } diff --git a/src/stream/all.rs b/src/stream/all.rs deleted file mode 100644 index 353d61a..0000000 --- a/src/stream/all.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::stream::Stream; - -use core::future::Future; -use core::marker::PhantomPinned; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Future for the [`all`](super::StreamExt::all) method. - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct AllFuture<'a, St: ?Sized, F> { - stream: &'a mut St, - f: F, - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, - } -} - -impl<'a, St: ?Sized, F> AllFuture<'a, St, F> { - pub(super) fn new(stream: &'a mut St, f: F) -> Self { - Self { - stream, - f, - _pin: PhantomPinned, - } - } -} - -impl<St, F> Future for AllFuture<'_, St, F> -where - St: ?Sized + Stream + Unpin, - F: FnMut(St::Item) -> bool, -{ - type Output = bool; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = self.project(); - let next = futures_core::ready!(Pin::new(me.stream).poll_next(cx)); - - match next { - Some(v) => { - if !(me.f)(v) { - Poll::Ready(false) - } else { - cx.waker().wake_by_ref(); - Poll::Pending - } - } - None => Poll::Ready(true), - } - } -} diff --git a/src/stream/any.rs b/src/stream/any.rs deleted file mode 100644 index aac0ec7..0000000 --- a/src/stream/any.rs +++ /dev/null @@ -1,55 +0,0 @@ -use crate::stream::Stream; - -use core::future::Future; -use core::marker::PhantomPinned; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Future for the [`any`](super::StreamExt::any) method. - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct AnyFuture<'a, St: ?Sized, F> { - stream: &'a mut St, - f: F, - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, - } -} - -impl<'a, St: ?Sized, F> AnyFuture<'a, St, F> { - pub(super) fn new(stream: &'a mut St, f: F) -> Self { - Self { - stream, - f, - _pin: PhantomPinned, - } - } -} - -impl<St, F> Future for AnyFuture<'_, St, F> -where - St: ?Sized + Stream + Unpin, - F: FnMut(St::Item) -> bool, -{ - type Output = bool; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = self.project(); - let next = futures_core::ready!(Pin::new(me.stream).poll_next(cx)); - - match next { - Some(v) => { - if (me.f)(v) { - Poll::Ready(true) - } else { - cx.waker().wake_by_ref(); - Poll::Pending - } - } - None => Poll::Ready(false), - } - } -} diff --git a/src/stream/chain.rs b/src/stream/chain.rs deleted file mode 100644 index 6124c91..0000000 --- a/src/stream/chain.rs +++ /dev/null @@ -1,49 +0,0 @@ -use crate::stream::{Fuse, Stream}; - -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream returned by the [`chain`](super::StreamExt::chain) method. - pub struct Chain<T, U> { - #[pin] - a: Fuse<T>, - #[pin] - b: U, - } -} - -impl<T, U> Chain<T, U> { - pub(super) fn new(a: T, b: U) -> Chain<T, U> - where - T: Stream, - U: Stream, - { - Chain { a: Fuse::new(a), b } - } -} - -impl<T, U> Stream for Chain<T, U> -where - T: Stream, - U: Stream<Item = T::Item>, -{ - type Item = T::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> { - use Poll::Ready; - - let me = self.project(); - - if let Some(v) = ready!(me.a.poll_next(cx)) { - return Ready(Some(v)); - } - - me.b.poll_next(cx) - } - - fn size_hint(&self) -> (usize, Option<usize>) { - super::merge_size_hints(self.a.size_hint(), self.b.size_hint()) - } -} diff --git a/src/stream/collect.rs b/src/stream/collect.rs deleted file mode 100644 index 1aafc30..0000000 --- a/src/stream/collect.rs +++ /dev/null @@ -1,233 +0,0 @@ -use crate::stream::Stream; - -use core::future::Future; -use core::marker::PhantomPinned; -use core::mem; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -// Do not export this struct until `FromStream` can be unsealed. -pin_project! { - /// Future returned by the [`collect`](super::StreamExt::collect) method. - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[derive(Debug)] - pub struct Collect<T, U> - where - T: Stream, - U: FromStream<T::Item>, - { - #[pin] - stream: T, - collection: U::InternalCollection, - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, - } -} - -/// Convert from a [`Stream`](crate::stream::Stream). -/// -/// This trait is not intended to be used directly. Instead, call -/// [`StreamExt::collect()`](super::StreamExt::collect). -/// -/// # Implementing -/// -/// Currently, this trait may not be implemented by third parties. The trait is -/// sealed in order to make changes in the future. Stabilization is pending -/// enhancements to the Rust language. -pub trait FromStream<T>: sealed::FromStreamPriv<T> {} - -impl<T, U> Collect<T, U> -where - T: Stream, - U: FromStream<T::Item>, -{ - pub(super) fn new(stream: T) -> Collect<T, U> { - let (lower, upper) = stream.size_hint(); - let collection = U::initialize(sealed::Internal, lower, upper); - - Collect { - stream, - collection, - _pin: PhantomPinned, - } - } -} - -impl<T, U> Future for Collect<T, U> -where - T: Stream, - U: FromStream<T::Item>, -{ - type Output = U; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<U> { - use Poll::Ready; - - loop { - let mut me = self.as_mut().project(); - - let item = match ready!(me.stream.poll_next(cx)) { - Some(item) => item, - None => { - return Ready(U::finalize(sealed::Internal, &mut me.collection)); - } - }; - - if !U::extend(sealed::Internal, &mut me.collection, item) { - return Ready(U::finalize(sealed::Internal, &mut me.collection)); - } - } - } -} - -// ===== FromStream implementations - -impl FromStream<()> for () {} - -impl sealed::FromStreamPriv<()> for () { - type InternalCollection = (); - - fn initialize(_: sealed::Internal, _lower: usize, _upper: Option<usize>) {} - - fn extend(_: sealed::Internal, _collection: &mut (), _item: ()) -> bool { - true - } - - fn finalize(_: sealed::Internal, _collection: &mut ()) {} -} - -impl<T: AsRef<str>> FromStream<T> for String {} - -impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String { - type InternalCollection = String; - - fn initialize(_: sealed::Internal, _lower: usize, _upper: Option<usize>) -> String { - String::new() - } - - fn extend(_: sealed::Internal, collection: &mut String, item: T) -> bool { - collection.push_str(item.as_ref()); - true - } - - fn finalize(_: sealed::Internal, collection: &mut String) -> String { - mem::replace(collection, String::new()) - } -} - -impl<T> FromStream<T> for Vec<T> {} - -impl<T> sealed::FromStreamPriv<T> for Vec<T> { - type InternalCollection = Vec<T>; - - fn initialize(_: sealed::Internal, lower: usize, _upper: Option<usize>) -> Vec<T> { - Vec::with_capacity(lower) - } - - fn extend(_: sealed::Internal, collection: &mut Vec<T>, item: T) -> bool { - collection.push(item); - true - } - - fn finalize(_: sealed::Internal, collection: &mut Vec<T>) -> Vec<T> { - mem::replace(collection, vec![]) - } -} - -impl<T> FromStream<T> for Box<[T]> {} - -impl<T> sealed::FromStreamPriv<T> for Box<[T]> { - type InternalCollection = Vec<T>; - - fn initialize(_: sealed::Internal, lower: usize, upper: Option<usize>) -> Vec<T> { - <Vec<T> as sealed::FromStreamPriv<T>>::initialize(sealed::Internal, lower, upper) - } - - fn extend(_: sealed::Internal, collection: &mut Vec<T>, item: T) -> bool { - <Vec<T> as sealed::FromStreamPriv<T>>::extend(sealed::Internal, collection, item) - } - - fn finalize(_: sealed::Internal, collection: &mut Vec<T>) -> Box<[T]> { - <Vec<T> as sealed::FromStreamPriv<T>>::finalize(sealed::Internal, collection) - .into_boxed_slice() - } -} - -impl<T, U, E> FromStream<Result<T, E>> for Result<U, E> where U: FromStream<T> {} - -impl<T, U, E> sealed::FromStreamPriv<Result<T, E>> for Result<U, E> -where - U: FromStream<T>, -{ - type InternalCollection = Result<U::InternalCollection, E>; - - fn initialize( - _: sealed::Internal, - lower: usize, - upper: Option<usize>, - ) -> Result<U::InternalCollection, E> { - Ok(U::initialize(sealed::Internal, lower, upper)) - } - - fn extend( - _: sealed::Internal, - collection: &mut Self::InternalCollection, - item: Result<T, E>, - ) -> bool { - assert!(collection.is_ok()); - match item { - Ok(item) => { - let collection = collection.as_mut().ok().expect("invalid state"); - U::extend(sealed::Internal, collection, item) - } - Err(err) => { - *collection = Err(err); - false - } - } - } - - fn finalize(_: sealed::Internal, collection: &mut Self::InternalCollection) -> Result<U, E> { - if let Ok(collection) = collection.as_mut() { - Ok(U::finalize(sealed::Internal, collection)) - } else { - let res = mem::replace(collection, Ok(U::initialize(sealed::Internal, 0, Some(0)))); - - if let Err(err) = res { - Err(err) - } else { - unreachable!(); - } - } - } -} - -pub(crate) mod sealed { - #[doc(hidden)] - pub trait FromStreamPriv<T> { - /// Intermediate type used during collection process - /// - /// The name of this type is internal and cannot be relied upon. - type InternalCollection; - - /// Initialize the collection - fn initialize( - internal: Internal, - lower: usize, - upper: Option<usize>, - ) -> Self::InternalCollection; - - /// Extend the collection with the received item - /// - /// Return `true` to continue streaming, `false` complete collection. - fn extend(internal: Internal, collection: &mut Self::InternalCollection, item: T) -> bool; - - /// Finalize collection into target type. - fn finalize(internal: Internal, collection: &mut Self::InternalCollection) -> Self; - } - - #[allow(missing_debug_implementations)] - pub struct Internal; -} diff --git a/src/stream/empty.rs b/src/stream/empty.rs deleted file mode 100644 index 2f56ac6..0000000 --- a/src/stream/empty.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::stream::Stream; - -use core::marker::PhantomData; -use core::pin::Pin; -use core::task::{Context, Poll}; - -/// Stream for the [`empty`](fn@empty) function. -#[derive(Debug)] -#[must_use = "streams do nothing unless polled"] -pub struct Empty<T>(PhantomData<T>); - -impl<T> Unpin for Empty<T> {} -unsafe impl<T> Send for Empty<T> {} -unsafe impl<T> Sync for Empty<T> {} - -/// Creates a stream that yields nothing. -/// -/// The returned stream is immediately ready and returns `None`. Use -/// [`stream::pending()`](super::pending()) to obtain a stream that is never -/// ready. -/// -/// # Examples -/// -/// Basic usage: -/// -/// ``` -/// use tokio::stream::{self, StreamExt}; -/// -/// #[tokio::main] -/// async fn main() { -/// let mut none = stream::empty::<i32>(); -/// -/// assert_eq!(None, none.next().await); -/// } -/// ``` -pub const fn empty<T>() -> Empty<T> { - Empty(PhantomData) -} - -impl<T> Stream for Empty<T> { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<T>> { - Poll::Ready(None) - } - - fn size_hint(&self) -> (usize, Option<usize>) { - (0, Some(0)) - } -} diff --git a/src/stream/filter.rs b/src/stream/filter.rs deleted file mode 100644 index 799630b..0000000 --- a/src/stream/filter.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::stream::Stream; - -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream returned by the [`filter`](super::StreamExt::filter) method. - #[must_use = "streams do nothing unless polled"] - pub struct Filter<St, F> { - #[pin] - stream: St, - f: F, - } -} - -impl<St, F> fmt::Debug for Filter<St, F> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Filter") - .field("stream", &self.stream) - .finish() - } -} - -impl<St, F> Filter<St, F> { - pub(super) fn new(stream: St, f: F) -> Self { - Self { stream, f } - } -} - -impl<St, F> Stream for Filter<St, F> -where - St: Stream, - F: FnMut(&St::Item) -> bool, -{ - type Item = St::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St::Item>> { - loop { - match ready!(self.as_mut().project().stream.poll_next(cx)) { - Some(e) => { - if (self.as_mut().project().f)(&e) { - return Poll::Ready(Some(e)); - } - } - None => return Poll::Ready(None), - } - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - (0, self.stream.size_hint().1) // can't know a lower bound, due to the predicate - } -} diff --git a/src/stream/filter_map.rs b/src/stream/filter_map.rs deleted file mode 100644 index 8dc05a5..0000000 --- a/src/stream/filter_map.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::stream::Stream; - -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream returned by the [`filter_map`](super::StreamExt::filter_map) method. - #[must_use = "streams do nothing unless polled"] - pub struct FilterMap<St, F> { - #[pin] - stream: St, - f: F, - } -} - -impl<St, F> fmt::Debug for FilterMap<St, F> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FilterMap") - .field("stream", &self.stream) - .finish() - } -} - -impl<St, F> FilterMap<St, F> { - pub(super) fn new(stream: St, f: F) -> Self { - Self { stream, f } - } -} - -impl<St, F, T> Stream for FilterMap<St, F> -where - St: Stream, - F: FnMut(St::Item) -> Option<T>, -{ - type Item = T; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - loop { - match ready!(self.as_mut().project().stream.poll_next(cx)) { - Some(e) => { - if let Some(e) = (self.as_mut().project().f)(e) { - return Poll::Ready(Some(e)); - } - } - None => return Poll::Ready(None), - } - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - (0, self.stream.size_hint().1) // can't know a lower bound, due to the predicate - } -} diff --git a/src/stream/fold.rs b/src/stream/fold.rs deleted file mode 100644 index 5cf2bfa..0000000 --- a/src/stream/fold.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::stream::Stream; - -use core::future::Future; -use core::marker::PhantomPinned; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Future returned by the [`fold`](super::StreamExt::fold) method. - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct FoldFuture<St, B, F> { - #[pin] - stream: St, - acc: Option<B>, - f: F, - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, - } -} - -impl<St, B, F> FoldFuture<St, B, F> { - pub(super) fn new(stream: St, init: B, f: F) -> Self { - Self { - stream, - acc: Some(init), - f, - _pin: PhantomPinned, - } - } -} - -impl<St, B, F> Future for FoldFuture<St, B, F> -where - St: Stream, - F: FnMut(B, St::Item) -> B, -{ - type Output = B; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let mut me = self.project(); - loop { - let next = ready!(me.stream.as_mut().poll_next(cx)); - - match next { - Some(v) => { - let old = me.acc.take().unwrap(); - let new = (me.f)(old, v); - *me.acc = Some(new); - } - None => return Poll::Ready(me.acc.take().unwrap()), - } - } - } -} diff --git a/src/stream/fuse.rs b/src/stream/fuse.rs deleted file mode 100644 index 6c9e02d..0000000 --- a/src/stream/fuse.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::stream::Stream; - -use pin_project_lite::pin_project; -use std::pin::Pin; -use std::task::{Context, Poll}; - -pin_project! { - /// Stream returned by [`fuse()`][super::StreamExt::fuse]. - #[derive(Debug)] - pub struct Fuse<T> { - #[pin] - stream: Option<T>, - } -} - -impl<T> Fuse<T> -where - T: Stream, -{ - pub(crate) fn new(stream: T) -> Fuse<T> { - Fuse { - stream: Some(stream), - } - } -} - -impl<T> Stream for Fuse<T> -where - T: Stream, -{ - type Item = T::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> { - let res = match Option::as_pin_mut(self.as_mut().project().stream) { - Some(stream) => ready!(stream.poll_next(cx)), - None => return Poll::Ready(None), - }; - - if res.is_none() { - // Do not poll the stream anymore - self.as_mut().project().stream.set(None); - } - - Poll::Ready(res) - } - - fn size_hint(&self) -> (usize, Option<usize>) { - match self.stream { - Some(ref stream) => stream.size_hint(), - None => (0, Some(0)), - } - } -} diff --git a/src/stream/iter.rs b/src/stream/iter.rs deleted file mode 100644 index bc0388a..0000000 --- a/src/stream/iter.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::stream::Stream; - -use core::pin::Pin; -use core::task::{Context, Poll}; - -/// Stream for the [`iter`](fn@iter) function. -#[derive(Debug)] -#[must_use = "streams do nothing unless polled"] -pub struct Iter<I> { - iter: I, -} - -impl<I> Unpin for Iter<I> {} - -/// Converts an `Iterator` into a `Stream` which is always ready -/// to yield the next value. -/// -/// Iterators in Rust don't express the ability to block, so this adapter -/// simply always calls `iter.next()` and returns that. -/// -/// ``` -/// # async fn dox() { -/// use tokio::stream::{self, StreamExt}; -/// -/// let mut stream = stream::iter(vec![17, 19]); -/// -/// assert_eq!(stream.next().await, Some(17)); -/// assert_eq!(stream.next().await, Some(19)); -/// assert_eq!(stream.next().await, None); -/// # } -/// ``` -pub fn iter<I>(i: I) -> Iter<I::IntoIter> -where - I: IntoIterator, -{ - Iter { - iter: i.into_iter(), - } -} - -impl<I> Stream for Iter<I> -where - I: Iterator, -{ - type Item = I::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<I::Item>> { - let coop = ready!(crate::coop::poll_proceed(cx)); - coop.made_progress(); - Poll::Ready(self.iter.next()) - } - - fn size_hint(&self) -> (usize, Option<usize>) { - self.iter.size_hint() - } -} diff --git a/src/stream/map.rs b/src/stream/map.rs deleted file mode 100644 index dfac5a2..0000000 --- a/src/stream/map.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::stream::Stream; - -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream for the [`map`](super::StreamExt::map) method. - #[must_use = "streams do nothing unless polled"] - pub struct Map<St, F> { - #[pin] - stream: St, - f: F, - } -} - -impl<St, F> fmt::Debug for Map<St, F> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Map").field("stream", &self.stream).finish() - } -} - -impl<St, F> Map<St, F> { - pub(super) fn new(stream: St, f: F) -> Self { - Map { stream, f } - } -} - -impl<St, F, T> Stream for Map<St, F> -where - St: Stream, - F: FnMut(St::Item) -> T, -{ - type Item = T; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.as_mut() - .project() - .stream - .poll_next(cx) - .map(|opt| opt.map(|x| (self.as_mut().project().f)(x))) - } - - fn size_hint(&self) -> (usize, Option<usize>) { - self.stream.size_hint() - } -} diff --git a/src/stream/merge.rs b/src/stream/merge.rs deleted file mode 100644 index 50ba518..0000000 --- a/src/stream/merge.rs +++ /dev/null @@ -1,89 +0,0 @@ -use crate::stream::{Fuse, Stream}; - -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream returned by the [`merge`](super::StreamExt::merge) method. - pub struct Merge<T, U> { - #[pin] - a: Fuse<T>, - #[pin] - b: Fuse<U>, - // When `true`, poll `a` first, otherwise, `poll` b`. - a_first: bool, - } -} - -impl<T, U> Merge<T, U> { - pub(super) fn new(a: T, b: U) -> Merge<T, U> - where - T: Stream, - U: Stream, - { - Merge { - a: Fuse::new(a), - b: Fuse::new(b), - a_first: true, - } - } -} - -impl<T, U> Stream for Merge<T, U> -where - T: Stream, - U: Stream<Item = T::Item>, -{ - type Item = T::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> { - let me = self.project(); - let a_first = *me.a_first; - - // Toggle the flag - *me.a_first = !a_first; - - if a_first { - poll_next(me.a, me.b, cx) - } else { - poll_next(me.b, me.a, cx) - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - super::merge_size_hints(self.a.size_hint(), self.b.size_hint()) - } -} - -fn poll_next<T, U>( - first: Pin<&mut T>, - second: Pin<&mut U>, - cx: &mut Context<'_>, -) -> Poll<Option<T::Item>> -where - T: Stream, - U: Stream<Item = T::Item>, -{ - use Poll::*; - - let mut done = true; - - match first.poll_next(cx) { - Ready(Some(val)) => return Ready(Some(val)), - Ready(None) => {} - Pending => done = false, - } - - match second.poll_next(cx) { - Ready(Some(val)) => return Ready(Some(val)), - Ready(None) => {} - Pending => done = false, - } - - if done { - Ready(None) - } else { - Pending - } -} diff --git a/src/stream/mod.rs b/src/stream/mod.rs deleted file mode 100644 index 81afe7a..0000000 --- a/src/stream/mod.rs +++ /dev/null @@ -1,971 +0,0 @@ -//! Stream utilities for Tokio. -//! -//! A `Stream` is an asynchronous sequence of values. It can be thought of as -//! an asynchronous version of the standard library's `Iterator` trait. -//! -//! This module provides helpers to work with them. For examples of usage and a more in-depth -//! description of streams you can also refer to the [streams -//! tutorial](https://tokio.rs/tokio/tutorial/streams) on the tokio website. -//! -//! # Iterating over a Stream -//! -//! Due to similarities with the standard library's `Iterator` trait, some new -//! users may assume that they can use `for in` syntax to iterate over a -//! `Stream`, but this is unfortunately not possible. Instead, you can use a -//! `while let` loop as follows: -//! -//! ```rust -//! use tokio::stream::{self, StreamExt}; -//! -//! #[tokio::main] -//! async fn main() { -//! let mut stream = stream::iter(vec![0, 1, 2]); -//! -//! while let Some(value) = stream.next().await { -//! println!("Got {}", value); -//! } -//! } -//! ``` -//! -//! # Returning a Stream from a function -//! -//! A common way to stream values from a function is to pass in the sender -//! half of a channel and use the receiver as the stream. This requires awaiting -//! both futures to ensure progress is made. Another alternative is the -//! [async-stream] crate, which contains macros that provide a `yield` keyword -//! and allow you to return an `impl Stream`. -//! -//! [async-stream]: https://docs.rs/async-stream -//! -//! # Conversion to and from AsyncRead/AsyncWrite -//! -//! It is often desirable to convert a `Stream` into an [`AsyncRead`], -//! especially when dealing with plaintext formats streamed over the network. -//! The opposite conversion from an [`AsyncRead`] into a `Stream` is also -//! another commonly required feature. To enable these conversions, -//! [`tokio-util`] provides the [`StreamReader`] and [`ReaderStream`] -//! types when the io feature is enabled. -//! -//! [tokio-util]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html -//! [`tokio::io`]: crate::io -//! [`AsyncRead`]: crate::io::AsyncRead -//! [`AsyncWrite`]: crate::io::AsyncWrite -//! [`ReaderStream`]: https://docs.rs/tokio-util/0.4/tokio_util/io/struct.ReaderStream.html -//! [`StreamReader`]: https://docs.rs/tokio-util/0.4/tokio_util/io/struct.StreamReader.html - -mod all; -use all::AllFuture; - -mod any; -use any::AnyFuture; - -mod chain; -use chain::Chain; - -mod collect; -use collect::Collect; -pub use collect::FromStream; - -mod empty; -pub use empty::{empty, Empty}; - -mod filter; -use filter::Filter; - -mod filter_map; -use filter_map::FilterMap; - -mod fold; -use fold::FoldFuture; - -mod fuse; -use fuse::Fuse; - -mod iter; -pub use iter::{iter, Iter}; - -mod map; -use map::Map; - -mod merge; -use merge::Merge; - -mod next; -use next::Next; - -mod once; -pub use once::{once, Once}; - -mod pending; -pub use pending::{pending, Pending}; - -mod stream_map; -pub use stream_map::StreamMap; - -mod skip; -use skip::Skip; - -mod skip_while; -use skip_while::SkipWhile; - -mod try_next; -use try_next::TryNext; - -mod take; -use take::Take; - -mod take_while; -use take_while::TakeWhile; - -cfg_time! { - mod timeout; - use timeout::Timeout; - use crate::time::Duration; - mod throttle; - use crate::stream::throttle::{throttle, Throttle}; -} - -#[doc(no_inline)] -pub use futures_core::Stream; - -/// An extension trait for the [`Stream`] trait that provides a variety of -/// convenient combinator functions. -/// -/// Be aware that the `Stream` trait in Tokio is a re-export of the trait found -/// in the [futures] crate, however both Tokio and futures provide separate -/// `StreamExt` utility traits, and some utilities are only available on one of -/// these traits. Click [here][futures-StreamExt] to see the other `StreamExt` -/// trait in the futures crate. -/// -/// If you need utilities from both `StreamExt` traits, you should prefer to -/// import one of them, and use the other through the fully qualified call -/// syntax. For example: -/// ``` -/// // import one of the traits: -/// use futures::stream::StreamExt; -/// # #[tokio::main(flavor = "current_thread")] -/// # async fn main() { -/// -/// let a = tokio::stream::iter(vec![1, 3, 5]); -/// let b = tokio::stream::iter(vec![2, 4, 6]); -/// -/// // use the fully qualified call syntax for the other trait: -/// let merged = tokio::stream::StreamExt::merge(a, b); -/// -/// // use normal call notation for futures::stream::StreamExt::collect -/// let output: Vec<_> = merged.collect().await; -/// assert_eq!(output, vec![1, 2, 3, 4, 5, 6]); -/// # } -/// ``` -/// -/// [`Stream`]: crate::stream::Stream -/// [futures]: https://docs.rs/futures -/// [futures-StreamExt]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html -pub trait StreamExt: Stream { - /// Consumes and returns the next value in the stream or `None` if the - /// stream is finished. - /// - /// Equivalent to: - /// - /// ```ignore - /// async fn next(&mut self) -> Option<Self::Item>; - /// ``` - /// - /// Note that because `next` doesn't take ownership over the stream, - /// the [`Stream`] type must be [`Unpin`]. If you want to use `next` with a - /// [`!Unpin`](Unpin) stream, you'll first have to pin the stream. This can - /// be done by boxing the stream using [`Box::pin`] or - /// pinning it to the stack using the `pin_mut!` macro from the `pin_utils` - /// crate. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let mut stream = stream::iter(1..=3); - /// - /// assert_eq!(stream.next().await, Some(1)); - /// assert_eq!(stream.next().await, Some(2)); - /// assert_eq!(stream.next().await, Some(3)); - /// assert_eq!(stream.next().await, None); - /// # } - /// ``` - fn next(&mut self) -> Next<'_, Self> - where - Self: Unpin, - { - Next::new(self) - } - - /// Consumes and returns the next item in the stream. If an error is - /// encountered before the next item, the error is returned instead. - /// - /// Equivalent to: - /// - /// ```ignore - /// async fn try_next(&mut self) -> Result<Option<T>, E>; - /// ``` - /// - /// This is similar to the [`next`](StreamExt::next) combinator, - /// but returns a [`Result<Option<T>, E>`](Result) rather than - /// an [`Option<Result<T, E>>`](Option), making for easy use - /// with the [`?`](std::ops::Try) operator. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let mut stream = stream::iter(vec![Ok(1), Ok(2), Err("nope")]); - /// - /// assert_eq!(stream.try_next().await, Ok(Some(1))); - /// assert_eq!(stream.try_next().await, Ok(Some(2))); - /// assert_eq!(stream.try_next().await, Err("nope")); - /// # } - /// ``` - fn try_next<T, E>(&mut self) -> TryNext<'_, Self> - where - Self: Stream<Item = Result<T, E>> + Unpin, - { - TryNext::new(self) - } - - /// Maps this stream's items to a different type, returning a new stream of - /// the resulting type. - /// - /// The provided closure is executed over all elements of this stream as - /// they are made available. It is executed inline with calls to - /// [`poll_next`](Stream::poll_next). - /// - /// Note that this function consumes the stream passed into it and returns a - /// wrapped version of it, similar to the existing `map` methods in the - /// standard library. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let stream = stream::iter(1..=3); - /// let mut stream = stream.map(|x| x + 3); - /// - /// assert_eq!(stream.next().await, Some(4)); - /// assert_eq!(stream.next().await, Some(5)); - /// assert_eq!(stream.next().await, Some(6)); - /// # } - /// ``` - fn map<T, F>(self, f: F) -> Map<Self, F> - where - F: FnMut(Self::Item) -> T, - Self: Sized, - { - Map::new(self, f) - } - - /// Combine two streams into one by interleaving the output of both as it - /// is produced. - /// - /// Values are produced from the merged stream in the order they arrive from - /// the two source streams. If both source streams provide values - /// simultaneously, the merge stream alternates between them. This provides - /// some level of fairness. You should not chain calls to `merge`, as this - /// will break the fairness of the merging. - /// - /// The merged stream completes once **both** source streams complete. When - /// one source stream completes before the other, the merge stream - /// exclusively polls the remaining stream. - /// - /// For merging multiple streams, consider using [`StreamMap`] instead. - /// - /// [`StreamMap`]: crate::stream::StreamMap - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::StreamExt; - /// use tokio::sync::mpsc; - /// use tokio::time; - /// - /// use std::time::Duration; - /// - /// # /* - /// #[tokio::main] - /// # */ - /// # #[tokio::main(flavor = "current_thread")] - /// async fn main() { - /// # time::pause(); - /// let (tx1, rx1) = mpsc::channel(10); - /// let (tx2, rx2) = mpsc::channel(10); - /// - /// let mut rx = rx1.merge(rx2); - /// - /// tokio::spawn(async move { - /// // Send some values immediately - /// tx1.send(1).await.unwrap(); - /// tx1.send(2).await.unwrap(); - /// - /// // Let the other task send values - /// time::sleep(Duration::from_millis(20)).await; - /// - /// tx1.send(4).await.unwrap(); - /// }); - /// - /// tokio::spawn(async move { - /// // Wait for the first task to send values - /// time::sleep(Duration::from_millis(5)).await; - /// - /// tx2.send(3).await.unwrap(); - /// - /// time::sleep(Duration::from_millis(25)).await; - /// - /// // Send the final value - /// tx2.send(5).await.unwrap(); - /// }); - /// - /// assert_eq!(1, rx.next().await.unwrap()); - /// assert_eq!(2, rx.next().await.unwrap()); - /// assert_eq!(3, rx.next().await.unwrap()); - /// assert_eq!(4, rx.next().await.unwrap()); - /// assert_eq!(5, rx.next().await.unwrap()); - /// - /// // The merged stream is consumed - /// assert!(rx.next().await.is_none()); - /// } - /// ``` - fn merge<U>(self, other: U) -> Merge<Self, U> - where - U: Stream<Item = Self::Item>, - Self: Sized, - { - Merge::new(self, other) - } - - /// Filters the values produced by this stream according to the provided - /// predicate. - /// - /// As values of this stream are made available, the provided predicate `f` - /// will be run against them. If the predicate - /// resolves to `true`, then the stream will yield the value, but if the - /// predicate resolves to `false`, then the value - /// will be discarded and the next value will be produced. - /// - /// Note that this function consumes the stream passed into it and returns a - /// wrapped version of it, similar to [`Iterator::filter`] method in the - /// standard library. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let stream = stream::iter(1..=8); - /// let mut evens = stream.filter(|x| x % 2 == 0); - /// - /// assert_eq!(Some(2), evens.next().await); - /// assert_eq!(Some(4), evens.next().await); - /// assert_eq!(Some(6), evens.next().await); - /// assert_eq!(Some(8), evens.next().await); - /// assert_eq!(None, evens.next().await); - /// # } - /// ``` - fn filter<F>(self, f: F) -> Filter<Self, F> - where - F: FnMut(&Self::Item) -> bool, - Self: Sized, - { - Filter::new(self, f) - } - - /// Filters the values produced by this stream while simultaneously mapping - /// them to a different type according to the provided closure. - /// - /// As values of this stream are made available, the provided function will - /// be run on them. If the predicate `f` resolves to - /// [`Some(item)`](Some) then the stream will yield the value `item`, but if - /// it resolves to [`None`], then the value will be skipped. - /// - /// Note that this function consumes the stream passed into it and returns a - /// wrapped version of it, similar to [`Iterator::filter_map`] method in the - /// standard library. - /// - /// # Examples - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let stream = stream::iter(1..=8); - /// let mut evens = stream.filter_map(|x| { - /// if x % 2 == 0 { Some(x + 1) } else { None } - /// }); - /// - /// assert_eq!(Some(3), evens.next().await); - /// assert_eq!(Some(5), evens.next().await); - /// assert_eq!(Some(7), evens.next().await); - /// assert_eq!(Some(9), evens.next().await); - /// assert_eq!(None, evens.next().await); - /// # } - /// ``` - fn filter_map<T, F>(self, f: F) -> FilterMap<Self, F> - where - F: FnMut(Self::Item) -> Option<T>, - Self: Sized, - { - FilterMap::new(self, f) - } - - /// Creates a stream which ends after the first `None`. - /// - /// After a stream returns `None`, behavior is undefined. Future calls to - /// `poll_next` may or may not return `Some(T)` again or they may panic. - /// `fuse()` adapts a stream, ensuring that after `None` is given, it will - /// return `None` forever. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{Stream, StreamExt}; - /// - /// use std::pin::Pin; - /// use std::task::{Context, Poll}; - /// - /// // a stream which alternates between Some and None - /// struct Alternate { - /// state: i32, - /// } - /// - /// impl Stream for Alternate { - /// type Item = i32; - /// - /// fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<i32>> { - /// let val = self.state; - /// self.state = self.state + 1; - /// - /// // if it's even, Some(i32), else None - /// if val % 2 == 0 { - /// Poll::Ready(Some(val)) - /// } else { - /// Poll::Ready(None) - /// } - /// } - /// } - /// - /// #[tokio::main] - /// async fn main() { - /// let mut stream = Alternate { state: 0 }; - /// - /// // the stream goes back and forth - /// assert_eq!(stream.next().await, Some(0)); - /// assert_eq!(stream.next().await, None); - /// assert_eq!(stream.next().await, Some(2)); - /// assert_eq!(stream.next().await, None); - /// - /// // however, once it is fused - /// let mut stream = stream.fuse(); - /// - /// assert_eq!(stream.next().await, Some(4)); - /// assert_eq!(stream.next().await, None); - /// - /// // it will always return `None` after the first time. - /// assert_eq!(stream.next().await, None); - /// assert_eq!(stream.next().await, None); - /// assert_eq!(stream.next().await, None); - /// } - /// ``` - fn fuse(self) -> Fuse<Self> - where - Self: Sized, - { - Fuse::new(self) - } - - /// Creates a new stream of at most `n` items of the underlying stream. - /// - /// Once `n` items have been yielded from this stream then it will always - /// return that the stream is done. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let mut stream = stream::iter(1..=10).take(3); - /// - /// assert_eq!(Some(1), stream.next().await); - /// assert_eq!(Some(2), stream.next().await); - /// assert_eq!(Some(3), stream.next().await); - /// assert_eq!(None, stream.next().await); - /// # } - /// ``` - fn take(self, n: usize) -> Take<Self> - where - Self: Sized, - { - Take::new(self, n) - } - - /// Take elements from this stream while the provided predicate - /// resolves to `true`. - /// - /// This function, like `Iterator::take_while`, will take elements from the - /// stream until the predicate `f` resolves to `false`. Once one element - /// returns false it will always return that the stream is done. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let mut stream = stream::iter(1..=10).take_while(|x| *x <= 3); - /// - /// assert_eq!(Some(1), stream.next().await); - /// assert_eq!(Some(2), stream.next().await); - /// assert_eq!(Some(3), stream.next().await); - /// assert_eq!(None, stream.next().await); - /// # } - /// ``` - fn take_while<F>(self, f: F) -> TakeWhile<Self, F> - where - F: FnMut(&Self::Item) -> bool, - Self: Sized, - { - TakeWhile::new(self, f) - } - - /// Creates a new stream that will skip the `n` first items of the - /// underlying stream. - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let mut stream = stream::iter(1..=10).skip(7); - /// - /// assert_eq!(Some(8), stream.next().await); - /// assert_eq!(Some(9), stream.next().await); - /// assert_eq!(Some(10), stream.next().await); - /// assert_eq!(None, stream.next().await); - /// # } - /// ``` - fn skip(self, n: usize) -> Skip<Self> - where - Self: Sized, - { - Skip::new(self, n) - } - - /// Skip elements from the underlying stream while the provided predicate - /// resolves to `true`. - /// - /// This function, like [`Iterator::skip_while`], will ignore elemets from the - /// stream until the predicate `f` resolves to `false`. Once one element - /// returns false, the rest of the elements will be yielded. - /// - /// [`Iterator::skip_while`]: std::iter::Iterator::skip_while() - /// - /// # Examples - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// let mut stream = stream::iter(vec![1,2,3,4,1]).skip_while(|x| *x < 3); - /// - /// assert_eq!(Some(3), stream.next().await); - /// assert_eq!(Some(4), stream.next().await); - /// assert_eq!(Some(1), stream.next().await); - /// assert_eq!(None, stream.next().await); - /// # } - /// ``` - fn skip_while<F>(self, f: F) -> SkipWhile<Self, F> - where - F: FnMut(&Self::Item) -> bool, - Self: Sized, - { - SkipWhile::new(self, f) - } - - /// Tests if every element of the stream matches a predicate. - /// - /// Equivalent to: - /// - /// ```ignore - /// async fn all<F>(&mut self, f: F) -> bool; - /// ``` - /// - /// `all()` takes a closure that returns `true` or `false`. It applies - /// this closure to each element of the stream, and if they all return - /// `true`, then so does `all`. If any of them return `false`, it - /// returns `false`. An empty stream returns `true`. - /// - /// `all()` is short-circuiting; in other words, it will stop processing - /// as soon as it finds a `false`, given that no matter what else happens, - /// the result will also be `false`. - /// - /// An empty stream returns `true`. - /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let a = [1, 2, 3]; - /// - /// assert!(stream::iter(&a).all(|&x| x > 0).await); - /// - /// assert!(!stream::iter(&a).all(|&x| x > 2).await); - /// # } - /// ``` - /// - /// Stopping at the first `false`: - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let a = [1, 2, 3]; - /// - /// let mut iter = stream::iter(&a); - /// - /// assert!(!iter.all(|&x| x != 2).await); - /// - /// // we can still use `iter`, as there are more elements. - /// assert_eq!(iter.next().await, Some(&3)); - /// # } - /// ``` - fn all<F>(&mut self, f: F) -> AllFuture<'_, Self, F> - where - Self: Unpin, - F: FnMut(Self::Item) -> bool, - { - AllFuture::new(self, f) - } - - /// Tests if any element of the stream matches a predicate. - /// - /// Equivalent to: - /// - /// ```ignore - /// async fn any<F>(&mut self, f: F) -> bool; - /// ``` - /// - /// `any()` takes a closure that returns `true` or `false`. It applies - /// this closure to each element of the stream, and if any of them return - /// `true`, then so does `any()`. If they all return `false`, it - /// returns `false`. - /// - /// `any()` is short-circuiting; in other words, it will stop processing - /// as soon as it finds a `true`, given that no matter what else happens, - /// the result will also be `true`. - /// - /// An empty stream returns `false`. - /// - /// Basic usage: - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let a = [1, 2, 3]; - /// - /// assert!(stream::iter(&a).any(|&x| x > 0).await); - /// - /// assert!(!stream::iter(&a).any(|&x| x > 5).await); - /// # } - /// ``` - /// - /// Stopping at the first `true`: - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// - /// let a = [1, 2, 3]; - /// - /// let mut iter = stream::iter(&a); - /// - /// assert!(iter.any(|&x| x != 2).await); - /// - /// // we can still use `iter`, as there are more elements. - /// assert_eq!(iter.next().await, Some(&2)); - /// # } - /// ``` - fn any<F>(&mut self, f: F) -> AnyFuture<'_, Self, F> - where - Self: Unpin, - F: FnMut(Self::Item) -> bool, - { - AnyFuture::new(self, f) - } - - /// Combine two streams into one by first returning all values from the - /// first stream then all values from the second stream. - /// - /// As long as `self` still has values to emit, no values from `other` are - /// emitted, even if some are ready. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{self, StreamExt}; - /// - /// #[tokio::main] - /// async fn main() { - /// let one = stream::iter(vec![1, 2, 3]); - /// let two = stream::iter(vec![4, 5, 6]); - /// - /// let mut stream = one.chain(two); - /// - /// assert_eq!(stream.next().await, Some(1)); - /// assert_eq!(stream.next().await, Some(2)); - /// assert_eq!(stream.next().await, Some(3)); - /// assert_eq!(stream.next().await, Some(4)); - /// assert_eq!(stream.next().await, Some(5)); - /// assert_eq!(stream.next().await, Some(6)); - /// assert_eq!(stream.next().await, None); - /// } - /// ``` - fn chain<U>(self, other: U) -> Chain<Self, U> - where - U: Stream<Item = Self::Item>, - Self: Sized, - { - Chain::new(self, other) - } - - /// A combinator that applies a function to every element in a stream - /// producing a single, final value. - /// - /// Equivalent to: - /// - /// ```ignore - /// async fn fold<B, F>(self, init: B, f: F) -> B; - /// ``` - /// - /// # Examples - /// Basic usage: - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, *}; - /// - /// let s = stream::iter(vec![1u8, 2, 3]); - /// let sum = s.fold(0, |acc, x| acc + x).await; - /// - /// assert_eq!(sum, 6); - /// # } - /// ``` - fn fold<B, F>(self, init: B, f: F) -> FoldFuture<Self, B, F> - where - Self: Sized, - F: FnMut(B, Self::Item) -> B, - { - FoldFuture::new(self, init, f) - } - - /// Drain stream pushing all emitted values into a collection. - /// - /// Equivalent to: - /// - /// ```ignore - /// async fn collect<T>(self) -> T; - /// ``` - /// - /// `collect` streams all values, awaiting as needed. Values are pushed into - /// a collection. A number of different target collection types are - /// supported, including [`Vec`](std::vec::Vec), - /// [`String`](std::string::String), and [`Bytes`](bytes::Bytes). - /// - /// # `Result` - /// - /// `collect()` can also be used with streams of type `Result<T, E>` where - /// `T: FromStream<_>`. In this case, `collect()` will stream as long as - /// values yielded from the stream are `Ok(_)`. If `Err(_)` is encountered, - /// streaming is terminated and `collect()` returns the `Err`. - /// - /// # Notes - /// - /// `FromStream` is currently a sealed trait. Stabilization is pending - /// enhancements to the Rust language. - /// - /// # Examples - /// - /// Basic usage: - /// - /// ``` - /// use tokio::stream::{self, StreamExt}; - /// - /// #[tokio::main] - /// async fn main() { - /// let doubled: Vec<i32> = - /// stream::iter(vec![1, 2, 3]) - /// .map(|x| x * 2) - /// .collect() - /// .await; - /// - /// assert_eq!(vec![2, 4, 6], doubled); - /// } - /// ``` - /// - /// Collecting a stream of `Result` values - /// - /// ``` - /// use tokio::stream::{self, StreamExt}; - /// - /// #[tokio::main] - /// async fn main() { - /// // A stream containing only `Ok` values will be collected - /// let values: Result<Vec<i32>, &str> = - /// stream::iter(vec![Ok(1), Ok(2), Ok(3)]) - /// .collect() - /// .await; - /// - /// assert_eq!(Ok(vec![1, 2, 3]), values); - /// - /// // A stream containing `Err` values will return the first error. - /// let results = vec![Ok(1), Err("no"), Ok(2), Ok(3), Err("nein")]; - /// - /// let values: Result<Vec<i32>, &str> = - /// stream::iter(results) - /// .collect() - /// .await; - /// - /// assert_eq!(Err("no"), values); - /// } - /// ``` - fn collect<T>(self) -> Collect<Self, T> - where - T: FromStream<Self::Item>, - Self: Sized, - { - Collect::new(self) - } - - /// Applies a per-item timeout to the passed stream. - /// - /// `timeout()` takes a `Duration` that represents the maximum amount of - /// time each element of the stream has to complete before timing out. - /// - /// If the wrapped stream yields a value before the deadline is reached, the - /// value is returned. Otherwise, an error is returned. The caller may decide - /// to continue consuming the stream and will eventually get the next source - /// stream value once it becomes available. - /// - /// # Notes - /// - /// This function consumes the stream passed into it and returns a - /// wrapped version of it. - /// - /// Polling the returned stream will continue to poll the inner stream even - /// if one or more items time out. - /// - /// # Examples - /// - /// Suppose we have a stream `int_stream` that yields 3 numbers (1, 2, 3): - /// - /// ``` - /// # #[tokio::main] - /// # async fn main() { - /// use tokio::stream::{self, StreamExt}; - /// use std::time::Duration; - /// # let int_stream = stream::iter(1..=3); - /// - /// let mut int_stream = int_stream.timeout(Duration::from_secs(1)); - /// - /// // When no items time out, we get the 3 elements in succession: - /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); - /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); - /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); - /// assert_eq!(int_stream.try_next().await, Ok(None)); - /// - /// // If the second item times out, we get an error and continue polling the stream: - /// # let mut int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); - /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); - /// assert!(int_stream.try_next().await.is_err()); - /// assert_eq!(int_stream.try_next().await, Ok(Some(2))); - /// assert_eq!(int_stream.try_next().await, Ok(Some(3))); - /// assert_eq!(int_stream.try_next().await, Ok(None)); - /// - /// // If we want to stop consuming the source stream the first time an - /// // element times out, we can use the `take_while` operator: - /// # let int_stream = stream::iter(vec![Ok(1), Err(()), Ok(2), Ok(3)]); - /// let mut int_stream = int_stream.take_while(Result::is_ok); - /// - /// assert_eq!(int_stream.try_next().await, Ok(Some(1))); - /// assert_eq!(int_stream.try_next().await, Ok(None)); - /// # } - /// ``` - #[cfg(all(feature = "time"))] - #[cfg_attr(docsrs, doc(cfg(feature = "time")))] - fn timeout(self, duration: Duration) -> Timeout<Self> - where - Self: Sized, - { - Timeout::new(self, duration) - } - - /// Slows down a stream by enforcing a delay between items. - /// - /// # Example - /// - /// Create a throttled stream. - /// ```rust,no_run - /// use std::time::Duration; - /// use tokio::stream::StreamExt; - /// - /// # async fn dox() { - /// let mut item_stream = futures::stream::repeat("one").throttle(Duration::from_secs(2)); - /// - /// loop { - /// // The string will be produced at most every 2 seconds - /// println!("{:?}", item_stream.next().await); - /// } - /// # } - /// ``` - #[cfg(all(feature = "time"))] - #[cfg_attr(docsrs, doc(cfg(feature = "time")))] - fn throttle(self, duration: Duration) -> Throttle<Self> - where - Self: Sized, - { - throttle(duration, self) - } -} - -impl<St: ?Sized> StreamExt for St where St: Stream {} - -/// Merge the size hints from two streams. -fn merge_size_hints( - (left_low, left_high): (usize, Option<usize>), - (right_low, right_hign): (usize, Option<usize>), -) -> (usize, Option<usize>) { - let low = left_low.saturating_add(right_low); - let high = match (left_high, right_hign) { - (Some(h1), Some(h2)) => h1.checked_add(h2), - _ => None, - }; - (low, high) -} diff --git a/src/stream/next.rs b/src/stream/next.rs deleted file mode 100644 index d9b1f92..0000000 --- a/src/stream/next.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::stream::Stream; - -use core::future::Future; -use core::marker::PhantomPinned; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Future for the [`next`](super::StreamExt::next) method. - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct Next<'a, St: ?Sized> { - stream: &'a mut St, - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, - } -} - -impl<'a, St: ?Sized> Next<'a, St> { - pub(super) fn new(stream: &'a mut St) -> Self { - Next { - stream, - _pin: PhantomPinned, - } - } -} - -impl<St: ?Sized + Stream + Unpin> Future for Next<'_, St> { - type Output = Option<St::Item>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = self.project(); - Pin::new(me.stream).poll_next(cx) - } -} diff --git a/src/stream/once.rs b/src/stream/once.rs deleted file mode 100644 index 7fe204c..0000000 --- a/src/stream/once.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::stream::{self, Iter, Stream}; - -use core::option; -use core::pin::Pin; -use core::task::{Context, Poll}; - -/// Stream for the [`once`](fn@once) function. -#[derive(Debug)] -#[must_use = "streams do nothing unless polled"] -pub struct Once<T> { - iter: Iter<option::IntoIter<T>>, -} - -impl<I> Unpin for Once<I> {} - -/// Creates a stream that emits an element exactly once. -/// -/// The returned stream is immediately ready and emits the provided value once. -/// -/// # Examples -/// -/// ``` -/// use tokio::stream::{self, StreamExt}; -/// -/// #[tokio::main] -/// async fn main() { -/// // one is the loneliest number -/// let mut one = stream::once(1); -/// -/// assert_eq!(Some(1), one.next().await); -/// -/// // just one, that's all we get -/// assert_eq!(None, one.next().await); -/// } -/// ``` -pub fn once<T>(value: T) -> Once<T> { - Once { - iter: stream::iter(Some(value).into_iter()), - } -} - -impl<T> Stream for Once<T> { - type Item = T; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - Pin::new(&mut self.iter).poll_next(cx) - } - - fn size_hint(&self) -> (usize, Option<usize>) { - self.iter.size_hint() - } -} diff --git a/src/stream/pending.rs b/src/stream/pending.rs deleted file mode 100644 index 21224c3..0000000 --- a/src/stream/pending.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::stream::Stream; - -use core::marker::PhantomData; -use core::pin::Pin; -use core::task::{Context, Poll}; - -/// Stream for the [`pending`](fn@pending) function. -#[derive(Debug)] -#[must_use = "streams do nothing unless polled"] -pub struct Pending<T>(PhantomData<T>); - -impl<T> Unpin for Pending<T> {} -unsafe impl<T> Send for Pending<T> {} -unsafe impl<T> Sync for Pending<T> {} - -/// Creates a stream that is never ready -/// -/// The returned stream is never ready. Attempting to call -/// [`next()`](crate::stream::StreamExt::next) will never complete. Use -/// [`stream::empty()`](super::empty()) to obtain a stream that is is -/// immediately empty but returns no values. -/// -/// # Examples -/// -/// Basic usage: -/// -/// ```no_run -/// use tokio::stream::{self, StreamExt}; -/// -/// #[tokio::main] -/// async fn main() { -/// let mut never = stream::pending::<i32>(); -/// -/// // This will never complete -/// never.next().await; -/// -/// unreachable!(); -/// } -/// ``` -pub const fn pending<T>() -> Pending<T> { - Pending(PhantomData) -} - -impl<T> Stream for Pending<T> { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<T>> { - Poll::Pending - } - - fn size_hint(&self) -> (usize, Option<usize>) { - (0, None) - } -} diff --git a/src/stream/skip.rs b/src/stream/skip.rs deleted file mode 100644 index 39540cc..0000000 --- a/src/stream/skip.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::stream::Stream; - -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream for the [`skip`](super::StreamExt::skip) method. - #[must_use = "streams do nothing unless polled"] - pub struct Skip<St> { - #[pin] - stream: St, - remaining: usize, - } -} - -impl<St> fmt::Debug for Skip<St> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Skip") - .field("stream", &self.stream) - .finish() - } -} - -impl<St> Skip<St> { - pub(super) fn new(stream: St, remaining: usize) -> Self { - Self { stream, remaining } - } -} - -impl<St> Stream for Skip<St> -where - St: Stream, -{ - type Item = St::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - loop { - match ready!(self.as_mut().project().stream.poll_next(cx)) { - Some(e) => { - if self.remaining == 0 { - return Poll::Ready(Some(e)); - } - *self.as_mut().project().remaining -= 1; - } - None => return Poll::Ready(None), - } - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - let (lower, upper) = self.stream.size_hint(); - - let lower = lower.saturating_sub(self.remaining); - let upper = upper.map(|x| x.saturating_sub(self.remaining)); - - (lower, upper) - } -} diff --git a/src/stream/skip_while.rs b/src/stream/skip_while.rs deleted file mode 100644 index 4e05007..0000000 --- a/src/stream/skip_while.rs +++ /dev/null @@ -1,73 +0,0 @@ -use crate::stream::Stream; - -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream for the [`skip_while`](super::StreamExt::skip_while) method. - #[must_use = "streams do nothing unless polled"] - pub struct SkipWhile<St, F> { - #[pin] - stream: St, - predicate: Option<F>, - } -} - -impl<St, F> fmt::Debug for SkipWhile<St, F> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SkipWhile") - .field("stream", &self.stream) - .finish() - } -} - -impl<St, F> SkipWhile<St, F> { - pub(super) fn new(stream: St, predicate: F) -> Self { - Self { - stream, - predicate: Some(predicate), - } - } -} - -impl<St, F> Stream for SkipWhile<St, F> -where - St: Stream, - F: FnMut(&St::Item) -> bool, -{ - type Item = St::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let mut this = self.project(); - if let Some(predicate) = this.predicate { - loop { - match ready!(this.stream.as_mut().poll_next(cx)) { - Some(item) => { - if !(predicate)(&item) { - *this.predicate = None; - return Poll::Ready(Some(item)); - } - } - None => return Poll::Ready(None), - } - } - } else { - this.stream.poll_next(cx) - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - let (lower, upper) = self.stream.size_hint(); - - if self.predicate.is_some() { - return (0, upper); - } - - (lower, upper) - } -} diff --git a/src/stream/stream_map.rs b/src/stream/stream_map.rs deleted file mode 100644 index 9fed3c1..0000000 --- a/src/stream/stream_map.rs +++ /dev/null @@ -1,555 +0,0 @@ -use crate::stream::Stream; - -use std::borrow::Borrow; -use std::hash::Hash; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// Combine many streams into one, indexing each source stream with a unique -/// key. -/// -/// `StreamMap` is similar to [`StreamExt::merge`] in that it combines source -/// streams into a single merged stream that yields values in the order that -/// they arrive from the source streams. However, `StreamMap` has a lot more -/// flexibility in usage patterns. -/// -/// `StreamMap` can: -/// -/// * Merge an arbitrary number of streams. -/// * Track which source stream the value was received from. -/// * Handle inserting and removing streams from the set of managed streams at -/// any point during iteration. -/// -/// All source streams held by `StreamMap` are indexed using a key. This key is -/// included with the value when a source stream yields a value. The key is also -/// used to remove the stream from the `StreamMap` before the stream has -/// completed streaming. -/// -/// # `Unpin` -/// -/// Because the `StreamMap` API moves streams during runtime, both streams and -/// keys must be `Unpin`. In order to insert a `!Unpin` stream into a -/// `StreamMap`, use [`pin!`] to pin the stream to the stack or [`Box::pin`] to -/// pin the stream in the heap. -/// -/// # Implementation -/// -/// `StreamMap` is backed by a `Vec<(K, V)>`. There is no guarantee that this -/// internal implementation detail will persist in future versions, but it is -/// important to know the runtime implications. In general, `StreamMap` works -/// best with a "smallish" number of streams as all entries are scanned on -/// insert, remove, and polling. In cases where a large number of streams need -/// to be merged, it may be advisable to use tasks sending values on a shared -/// [`mpsc`] channel. -/// -/// [`StreamExt::merge`]: crate::stream::StreamExt::merge -/// [`mpsc`]: crate::sync::mpsc -/// [`pin!`]: macro@pin -/// [`Box::pin`]: std::boxed::Box::pin -/// -/// # Examples -/// -/// Merging two streams, then remove them after receiving the first value -/// -/// ``` -/// use tokio::stream::{StreamExt, StreamMap}; -/// use tokio::sync::mpsc; -/// -/// #[tokio::main] -/// async fn main() { -/// let (tx1, rx1) = mpsc::channel(10); -/// let (tx2, rx2) = mpsc::channel(10); -/// -/// tokio::spawn(async move { -/// tx1.send(1).await.unwrap(); -/// -/// // This value will never be received. The send may or may not return -/// // `Err` depending on if the remote end closed first or not. -/// let _ = tx1.send(2).await; -/// }); -/// -/// tokio::spawn(async move { -/// tx2.send(3).await.unwrap(); -/// let _ = tx2.send(4).await; -/// }); -/// -/// let mut map = StreamMap::new(); -/// -/// // Insert both streams -/// map.insert("one", rx1); -/// map.insert("two", rx2); -/// -/// // Read twice -/// for _ in 0..2 { -/// let (key, val) = map.next().await.unwrap(); -/// -/// if key == "one" { -/// assert_eq!(val, 1); -/// } else { -/// assert_eq!(val, 3); -/// } -/// -/// // Remove the stream to prevent reading the next value -/// map.remove(key); -/// } -/// } -/// ``` -/// -/// This example models a read-only client to a chat system with channels. The -/// client sends commands to join and leave channels. `StreamMap` is used to -/// manage active channel subscriptions. -/// -/// For simplicity, messages are displayed with `println!`, but they could be -/// sent to the client over a socket. -/// -/// ```no_run -/// use tokio::stream::{Stream, StreamExt, StreamMap}; -/// -/// enum Command { -/// Join(String), -/// Leave(String), -/// } -/// -/// fn commands() -> impl Stream<Item = Command> { -/// // Streams in user commands by parsing `stdin`. -/// # tokio::stream::pending() -/// } -/// -/// // Join a channel, returns a stream of messages received on the channel. -/// fn join(channel: &str) -> impl Stream<Item = String> + Unpin { -/// // left as an exercise to the reader -/// # tokio::stream::pending() -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let mut channels = StreamMap::new(); -/// -/// // Input commands (join / leave channels). -/// let cmds = commands(); -/// tokio::pin!(cmds); -/// -/// loop { -/// tokio::select! { -/// Some(cmd) = cmds.next() => { -/// match cmd { -/// Command::Join(chan) => { -/// // Join the channel and add it to the `channels` -/// // stream map -/// let msgs = join(&chan); -/// channels.insert(chan, msgs); -/// } -/// Command::Leave(chan) => { -/// channels.remove(&chan); -/// } -/// } -/// } -/// Some((chan, msg)) = channels.next() => { -/// // Received a message, display it on stdout with the channel -/// // it originated from. -/// println!("{}: {}", chan, msg); -/// } -/// // Both the `commands` stream and the `channels` stream are -/// // complete. There is no more work to do, so leave the loop. -/// else => break, -/// } -/// } -/// } -/// ``` -#[derive(Debug)] -pub struct StreamMap<K, V> { - /// Streams stored in the map - entries: Vec<(K, V)>, -} - -impl<K, V> StreamMap<K, V> { - /// An iterator visiting all key-value pairs in arbitrary order. - /// - /// The iterator element type is &'a (K, V). - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// - /// map.insert("a", pending::<i32>()); - /// map.insert("b", pending()); - /// map.insert("c", pending()); - /// - /// for (key, stream) in map.iter() { - /// println!("({}, {:?})", key, stream); - /// } - /// ``` - pub fn iter(&self) -> impl Iterator<Item = &(K, V)> { - self.entries.iter() - } - - /// An iterator visiting all key-value pairs mutably in arbitrary order. - /// - /// The iterator element type is &'a mut (K, V). - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// - /// map.insert("a", pending::<i32>()); - /// map.insert("b", pending()); - /// map.insert("c", pending()); - /// - /// for (key, stream) in map.iter_mut() { - /// println!("({}, {:?})", key, stream); - /// } - /// ``` - pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (K, V)> { - self.entries.iter_mut() - } - - /// Creates an empty `StreamMap`. - /// - /// The stream map is initially created with a capacity of `0`, so it will - /// not allocate until it is first inserted into. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, Pending}; - /// - /// let map: StreamMap<&str, Pending<()>> = StreamMap::new(); - /// ``` - pub fn new() -> StreamMap<K, V> { - StreamMap { entries: vec![] } - } - - /// Creates an empty `StreamMap` with the specified capacity. - /// - /// The stream map will be able to hold at least `capacity` elements without - /// reallocating. If `capacity` is 0, the stream map will not allocate. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, Pending}; - /// - /// let map: StreamMap<&str, Pending<()>> = StreamMap::with_capacity(10); - /// ``` - pub fn with_capacity(capacity: usize) -> StreamMap<K, V> { - StreamMap { - entries: Vec::with_capacity(capacity), - } - } - - /// Returns an iterator visiting all keys in arbitrary order. - /// - /// The iterator element type is &'a K. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// - /// map.insert("a", pending::<i32>()); - /// map.insert("b", pending()); - /// map.insert("c", pending()); - /// - /// for key in map.keys() { - /// println!("{}", key); - /// } - /// ``` - pub fn keys(&self) -> impl Iterator<Item = &K> { - self.iter().map(|(k, _)| k) - } - - /// An iterator visiting all values in arbitrary order. - /// - /// The iterator element type is &'a V. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// - /// map.insert("a", pending::<i32>()); - /// map.insert("b", pending()); - /// map.insert("c", pending()); - /// - /// for stream in map.values() { - /// println!("{:?}", stream); - /// } - /// ``` - pub fn values(&self) -> impl Iterator<Item = &V> { - self.iter().map(|(_, v)| v) - } - - /// An iterator visiting all values mutably in arbitrary order. - /// - /// The iterator element type is &'a mut V. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// - /// map.insert("a", pending::<i32>()); - /// map.insert("b", pending()); - /// map.insert("c", pending()); - /// - /// for stream in map.values_mut() { - /// println!("{:?}", stream); - /// } - /// ``` - pub fn values_mut(&mut self) -> impl Iterator<Item = &mut V> { - self.iter_mut().map(|(_, v)| v) - } - - /// Returns the number of streams the map can hold without reallocating. - /// - /// This number is a lower bound; the `StreamMap` might be able to hold - /// more, but is guaranteed to be able to hold at least this many. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, Pending}; - /// - /// let map: StreamMap<i32, Pending<()>> = StreamMap::with_capacity(100); - /// assert!(map.capacity() >= 100); - /// ``` - pub fn capacity(&self) -> usize { - self.entries.capacity() - } - - /// Returns the number of streams in the map. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut a = StreamMap::new(); - /// assert_eq!(a.len(), 0); - /// a.insert(1, pending::<i32>()); - /// assert_eq!(a.len(), 1); - /// ``` - pub fn len(&self) -> usize { - self.entries.len() - } - - /// Returns `true` if the map contains no elements. - /// - /// # Examples - /// - /// ``` - /// use std::collections::HashMap; - /// - /// let mut a = HashMap::new(); - /// assert!(a.is_empty()); - /// a.insert(1, "a"); - /// assert!(!a.is_empty()); - /// ``` - pub fn is_empty(&self) -> bool { - self.entries.is_empty() - } - - /// Clears the map, removing all key-stream pairs. Keeps the allocated - /// memory for reuse. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut a = StreamMap::new(); - /// a.insert(1, pending::<i32>()); - /// a.clear(); - /// assert!(a.is_empty()); - /// ``` - pub fn clear(&mut self) { - self.entries.clear(); - } - - /// Insert a key-stream pair into the map. - /// - /// If the map did not have this key present, `None` is returned. - /// - /// If the map did have this key present, the new `stream` replaces the old - /// one and the old stream is returned. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// - /// assert!(map.insert(37, pending::<i32>()).is_none()); - /// assert!(!map.is_empty()); - /// - /// map.insert(37, pending()); - /// assert!(map.insert(37, pending()).is_some()); - /// ``` - pub fn insert(&mut self, k: K, stream: V) -> Option<V> - where - K: Hash + Eq, - { - let ret = self.remove(&k); - self.entries.push((k, stream)); - - ret - } - - /// Removes a key from the map, returning the stream at the key if the key was previously in the map. - /// - /// The key may be any borrowed form of the map's key type, but `Hash` and - /// `Eq` on the borrowed form must match those for the key type. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// map.insert(1, pending::<i32>()); - /// assert!(map.remove(&1).is_some()); - /// assert!(map.remove(&1).is_none()); - /// ``` - pub fn remove<Q: ?Sized>(&mut self, k: &Q) -> Option<V> - where - K: Borrow<Q>, - Q: Hash + Eq, - { - for i in 0..self.entries.len() { - if self.entries[i].0.borrow() == k { - return Some(self.entries.swap_remove(i).1); - } - } - - None - } - - /// Returns `true` if the map contains a stream for the specified key. - /// - /// The key may be any borrowed form of the map's key type, but `Hash` and - /// `Eq` on the borrowed form must match those for the key type. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::{StreamMap, pending}; - /// - /// let mut map = StreamMap::new(); - /// map.insert(1, pending::<i32>()); - /// assert_eq!(map.contains_key(&1), true); - /// assert_eq!(map.contains_key(&2), false); - /// ``` - pub fn contains_key<Q: ?Sized>(&self, k: &Q) -> bool - where - K: Borrow<Q>, - Q: Hash + Eq, - { - for i in 0..self.entries.len() { - if self.entries[i].0.borrow() == k { - return true; - } - } - - false - } -} - -impl<K, V> StreamMap<K, V> -where - K: Unpin, - V: Stream + Unpin, -{ - /// Polls the next value, includes the vec entry index - fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<Option<(usize, V::Item)>> { - use Poll::*; - - let start = crate::util::thread_rng_n(self.entries.len() as u32) as usize; - let mut idx = start; - - for _ in 0..self.entries.len() { - let (_, stream) = &mut self.entries[idx]; - - match Pin::new(stream).poll_next(cx) { - Ready(Some(val)) => return Ready(Some((idx, val))), - Ready(None) => { - // Remove the entry - self.entries.swap_remove(idx); - - // Check if this was the last entry, if so the cursor needs - // to wrap - if idx == self.entries.len() { - idx = 0; - } else if idx < start && start <= self.entries.len() { - // The stream being swapped into the current index has - // already been polled, so skip it. - idx = idx.wrapping_add(1) % self.entries.len(); - } - } - Pending => { - idx = idx.wrapping_add(1) % self.entries.len(); - } - } - } - - // If the map is empty, then the stream is complete. - if self.entries.is_empty() { - Ready(None) - } else { - Pending - } - } -} - -impl<K, V> Default for StreamMap<K, V> { - fn default() -> Self { - Self::new() - } -} - -impl<K, V> Stream for StreamMap<K, V> -where - K: Clone + Unpin, - V: Stream + Unpin, -{ - type Item = (K, V::Item); - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - if let Some((idx, val)) = ready!(self.poll_next_entry(cx)) { - let key = self.entries[idx].0.clone(); - Poll::Ready(Some((key, val))) - } else { - Poll::Ready(None) - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - let mut ret = (0, Some(0)); - - for (_, stream) in &self.entries { - let hint = stream.size_hint(); - - ret.0 += hint.0; - - match (ret.1, hint.1) { - (Some(a), Some(b)) => ret.1 = Some(a + b), - (Some(_), None) => ret.1 = None, - _ => {} - } - } - - ret - } -} diff --git a/src/stream/take.rs b/src/stream/take.rs deleted file mode 100644 index a92430b..0000000 --- a/src/stream/take.rs +++ /dev/null @@ -1,76 +0,0 @@ -use crate::stream::Stream; - -use core::cmp; -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream for the [`take`](super::StreamExt::take) method. - #[must_use = "streams do nothing unless polled"] - pub struct Take<St> { - #[pin] - stream: St, - remaining: usize, - } -} - -impl<St> fmt::Debug for Take<St> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Take") - .field("stream", &self.stream) - .finish() - } -} - -impl<St> Take<St> { - pub(super) fn new(stream: St, remaining: usize) -> Self { - Self { stream, remaining } - } -} - -impl<St> Stream for Take<St> -where - St: Stream, -{ - type Item = St::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - if *self.as_mut().project().remaining > 0 { - self.as_mut().project().stream.poll_next(cx).map(|ready| { - match &ready { - Some(_) => { - *self.as_mut().project().remaining -= 1; - } - None => { - *self.as_mut().project().remaining = 0; - } - } - ready - }) - } else { - Poll::Ready(None) - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - if self.remaining == 0 { - return (0, Some(0)); - } - - let (lower, upper) = self.stream.size_hint(); - - let lower = cmp::min(lower, self.remaining as usize); - - let upper = match upper { - Some(x) if x < self.remaining as usize => Some(x), - _ => Some(self.remaining as usize), - }; - - (lower, upper) - } -} diff --git a/src/stream/take_while.rs b/src/stream/take_while.rs deleted file mode 100644 index cf1e160..0000000 --- a/src/stream/take_while.rs +++ /dev/null @@ -1,79 +0,0 @@ -use crate::stream::Stream; - -use core::fmt; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Stream for the [`take_while`](super::StreamExt::take_while) method. - #[must_use = "streams do nothing unless polled"] - pub struct TakeWhile<St, F> { - #[pin] - stream: St, - predicate: F, - done: bool, - } -} - -impl<St, F> fmt::Debug for TakeWhile<St, F> -where - St: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TakeWhile") - .field("stream", &self.stream) - .field("done", &self.done) - .finish() - } -} - -impl<St, F> TakeWhile<St, F> { - pub(super) fn new(stream: St, predicate: F) -> Self { - Self { - stream, - predicate, - done: false, - } - } -} - -impl<St, F> Stream for TakeWhile<St, F> -where - St: Stream, - F: FnMut(&St::Item) -> bool, -{ - type Item = St::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - if !*self.as_mut().project().done { - self.as_mut().project().stream.poll_next(cx).map(|ready| { - let ready = ready.and_then(|item| { - if !(self.as_mut().project().predicate)(&item) { - None - } else { - Some(item) - } - }); - - if ready.is_none() { - *self.as_mut().project().done = true; - } - - ready - }) - } else { - Poll::Ready(None) - } - } - - fn size_hint(&self) -> (usize, Option<usize>) { - if self.done { - return (0, Some(0)); - } - - let (_, upper) = self.stream.size_hint(); - - (0, upper) - } -} diff --git a/src/stream/throttle.rs b/src/stream/throttle.rs deleted file mode 100644 index 8f4a256..0000000 --- a/src/stream/throttle.rs +++ /dev/null @@ -1,97 +0,0 @@ -//! Slow down a stream by enforcing a delay between items. - -use crate::stream::Stream; -use crate::time::{Duration, Instant, Sleep}; - -use std::future::Future; -use std::marker::Unpin; -use std::pin::Pin; -use std::task::{self, Poll}; - -use pin_project_lite::pin_project; - -pub(super) fn throttle<T>(duration: Duration, stream: T) -> Throttle<T> -where - T: Stream, -{ - let delay = if duration == Duration::from_millis(0) { - None - } else { - Some(Sleep::new_timeout(Instant::now() + duration, duration)) - }; - - Throttle { - delay, - duration, - has_delayed: true, - stream, - } -} - -pin_project! { - /// Stream for the [`throttle`](throttle) function. - #[derive(Debug)] - #[must_use = "streams do nothing unless polled"] - pub struct Throttle<T> { - // `None` when duration is zero. - delay: Option<Sleep>, - duration: Duration, - - // Set to true when `delay` has returned ready, but `stream` hasn't. - has_delayed: bool, - - // The stream to throttle - #[pin] - stream: T, - } -} - -// XXX: are these safe if `T: !Unpin`? -impl<T: Unpin> Throttle<T> { - /// Acquires a reference to the underlying stream that this combinator is - /// pulling from. - pub fn get_ref(&self) -> &T { - &self.stream - } - - /// Acquires a mutable reference to the underlying stream that this combinator - /// is pulling from. - /// - /// Note that care must be taken to avoid tampering with the state of the stream - /// which may otherwise confuse this combinator. - pub fn get_mut(&mut self) -> &mut T { - &mut self.stream - } - - /// Consumes this combinator, returning the underlying stream. - /// - /// Note that this may discard intermediate state of this combinator, so care - /// should be taken to avoid losing resources when this is called. - pub fn into_inner(self) -> T { - self.stream - } -} - -impl<T: Stream> Stream for Throttle<T> { - type Item = T::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> { - if !self.has_delayed && self.delay.is_some() { - ready!(Pin::new(self.as_mut().project().delay.as_mut().unwrap()).poll(cx)); - *self.as_mut().project().has_delayed = true; - } - - let value = ready!(self.as_mut().project().stream.poll_next(cx)); - - if value.is_some() { - let dur = self.duration; - if let Some(ref mut delay) = self.as_mut().project().delay { - delay.reset(Instant::now() + dur); - } - - *self.as_mut().project().has_delayed = false; - } - - Poll::Ready(value) - } -} diff --git a/src/stream/timeout.rs b/src/stream/timeout.rs deleted file mode 100644 index 669973f..0000000 --- a/src/stream/timeout.rs +++ /dev/null @@ -1,65 +0,0 @@ -use crate::stream::{Fuse, Stream}; -use crate::time::{error::Elapsed, Instant, Sleep}; - -use core::future::Future; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; -use std::time::Duration; - -pin_project! { - /// Stream returned by the [`timeout`](super::StreamExt::timeout) method. - #[must_use = "streams do nothing unless polled"] - #[derive(Debug)] - pub struct Timeout<S> { - #[pin] - stream: Fuse<S>, - deadline: Sleep, - duration: Duration, - poll_deadline: bool, - } -} - -impl<S: Stream> Timeout<S> { - pub(super) fn new(stream: S, duration: Duration) -> Self { - let next = Instant::now() + duration; - let deadline = Sleep::new_timeout(next, duration); - - Timeout { - stream: Fuse::new(stream), - deadline, - duration, - poll_deadline: true, - } - } -} - -impl<S: Stream> Stream for Timeout<S> { - type Item = Result<S::Item, Elapsed>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - match self.as_mut().project().stream.poll_next(cx) { - Poll::Ready(v) => { - if v.is_some() { - let next = Instant::now() + self.duration; - self.as_mut().project().deadline.reset(next); - *self.as_mut().project().poll_deadline = true; - } - return Poll::Ready(v.map(Ok)); - } - Poll::Pending => {} - }; - - if self.poll_deadline { - ready!(Pin::new(self.as_mut().project().deadline).poll(cx)); - *self.as_mut().project().poll_deadline = false; - return Poll::Ready(Some(Err(Elapsed::new()))); - } - - Poll::Pending - } - - fn size_hint(&self) -> (usize, Option<usize>) { - self.stream.size_hint() - } -} diff --git a/src/stream/try_next.rs b/src/stream/try_next.rs deleted file mode 100644 index b21d279..0000000 --- a/src/stream/try_next.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::stream::{Next, Stream}; - -use core::future::Future; -use core::marker::PhantomPinned; -use core::pin::Pin; -use core::task::{Context, Poll}; -use pin_project_lite::pin_project; - -pin_project! { - /// Future for the [`try_next`](super::StreamExt::try_next) method. - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct TryNext<'a, St: ?Sized> { - #[pin] - inner: Next<'a, St>, - // Make this future `!Unpin` for compatibility with async trait methods. - #[pin] - _pin: PhantomPinned, - } -} - -impl<'a, St: ?Sized> TryNext<'a, St> { - pub(super) fn new(stream: &'a mut St) -> Self { - Self { - inner: Next::new(stream), - _pin: PhantomPinned, - } - } -} - -impl<T, E, St: ?Sized + Stream<Item = Result<T, E>> + Unpin> Future for TryNext<'_, St> { - type Output = Result<Option<T>, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = self.project(); - me.inner.poll(cx).map(Option::transpose) - } -} diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs index 0b50e4f..803f2a1 100644 --- a/src/sync/batch_semaphore.rs +++ b/src/sync/batch_semaphore.rs @@ -41,15 +41,28 @@ struct Waitlist { closed: bool, } -/// Error returned by `Semaphore::try_acquire`. -#[derive(Debug)] -pub(crate) enum TryAcquireError { +/// Error returned from the [`Semaphore::try_acquire`] function. +/// +/// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire +#[derive(Debug, PartialEq)] +pub enum TryAcquireError { + /// The semaphore has been [closed] and cannot issue new permits. + /// + /// [closed]: crate::sync::Semaphore::close Closed, + + /// The semaphore has no available permits. NoPermits, } -/// Error returned by `Semaphore::acquire`. +/// Error returned from the [`Semaphore::acquire`] function. +/// +/// An `acquire` operation can only fail if the semaphore has been +/// [closed]. +/// +/// [closed]: crate::sync::Semaphore::close +/// [`Semaphore::acquire`]: crate::sync::Semaphore::acquire #[derive(Debug)] -pub(crate) struct AcquireError(()); +pub struct AcquireError(()); pub(crate) struct Acquire<'a> { node: Waiter, @@ -164,8 +177,6 @@ impl Semaphore { /// Closes the semaphore. This prevents the semaphore from issuing new /// permits and notifies all pending waiters. - // This will be used once the bounded MPSC is updated to use the new - // semaphore implementation. pub(crate) fn close(&self) { let mut waiters = self.waiters.lock(); // If the semaphore's permits counter has enough permits for an @@ -253,9 +264,9 @@ impl Semaphore { } if rem > 0 && is_empty { - let permits = rem << Self::PERMIT_SHIFT; + let permits = rem; assert!( - permits < Self::MAX_PERMITS, + permits <= Self::MAX_PERMITS, "cannot add more than MAX_PERMITS permits ({})", Self::MAX_PERMITS ); diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index ee9aba0..1b94600 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -940,48 +940,6 @@ impl<T: Clone> Receiver<T> { let guard = self.recv_ref(None)?; guard.clone_value().ok_or(TryRecvError::Closed) } - - /// Convert the receiver into a `Stream`. - /// - /// The conversion allows using `Receiver` with APIs that require stream - /// values. - /// - /// # Examples - /// - /// ``` - /// use tokio::stream::StreamExt; - /// use tokio::sync::broadcast; - /// - /// #[tokio::main] - /// async fn main() { - /// let (tx, rx) = broadcast::channel(128); - /// - /// tokio::spawn(async move { - /// for i in 0..10_i32 { - /// tx.send(i).unwrap(); - /// } - /// }); - /// - /// // Streams must be pinned to iterate. - /// tokio::pin! { - /// let stream = rx - /// .into_stream() - /// .filter(Result::is_ok) - /// .map(Result::unwrap) - /// .filter(|v| v % 2 == 0) - /// .map(|v| v + 1); - /// } - /// - /// while let Some(i) = stream.next().await { - /// println!("{}", i); - /// } - /// } - /// ``` - #[cfg(feature = "stream")] - #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] - pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> { - Recv::new(Borrow(self)) - } } impl<T> Drop for Receiver<T> { @@ -1058,31 +1016,6 @@ where } } -cfg_stream! { - use futures_core::Stream; - - impl<R, T: Clone> Stream for Recv<R, T> - where - R: AsMut<Receiver<T>>, - T: Clone, - { - type Item = Result<T, RecvError>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - let (receiver, waiter) = self.project(); - - let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) { - Ok(value) => value, - Err(TryRecvError::Empty) => return Poll::Pending, - Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))), - Err(TryRecvError::Closed) => return Poll::Ready(None), - }; - - Poll::Ready(guard.clone_value().map(Ok)) - } - } -} - impl<R, T> Drop for Recv<R, T> where R: AsMut<Receiver<T>>, diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 57ae277..a953c66 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -359,7 +359,8 @@ //! let mut conf = rx.borrow().clone(); //! //! let mut op_start = Instant::now(); -//! let mut sleep = time::sleep_until(op_start + conf.timeout); +//! let sleep = time::sleep_until(op_start + conf.timeout); +//! tokio::pin!(sleep); //! //! loop { //! tokio::select! { @@ -371,14 +372,14 @@ //! op_start = Instant::now(); //! //! // Restart the timeout -//! sleep = time::sleep_until(op_start + conf.timeout); +//! sleep.set(time::sleep_until(op_start + conf.timeout)); //! } //! _ = rx.changed() => { //! conf = rx.borrow().clone(); //! //! // The configuration has been updated. Update the //! // `sleep` using the new `timeout` value. -//! sleep.reset(op_start + conf.timeout); +//! sleep.as_mut().reset(op_start + conf.timeout); //! } //! _ = &mut op => { //! // The operation completed! @@ -443,6 +444,8 @@ cfg_sync! { pub mod oneshot; pub(crate) mod batch_semaphore; + pub use batch_semaphore::{AcquireError, TryAcquireError}; + mod semaphore; pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit}; diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index 06b3717..2dae7e2 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -1,6 +1,9 @@ use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; +#[cfg(unix)] +#[cfg(any(feature = "signal", feature = "process"))] +use crate::sync::mpsc::error::TryRecvError; +use crate::sync::mpsc::error::{SendError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -8,7 +11,6 @@ cfg_time! { } use std::fmt; -#[cfg(any(feature = "signal", feature = "process", feature = "stream"))] use std::task::{Context, Poll}; /// Send values to the associated `Receiver`. @@ -144,11 +146,6 @@ impl<T> Receiver<T> { poll_fn(|cx| self.chan.recv(cx)).await } - #[cfg(any(feature = "signal", feature = "process"))] - pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.chan.recv(cx) - } - /// Blocking receive to call outside of asynchronous contexts. /// /// # Panics @@ -194,7 +191,9 @@ impl<T> Receiver<T> { /// /// Compared with recv, this function has two failure cases instead of /// one (one for disconnection, one for an empty buffer). - pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + #[cfg(unix)] + #[cfg(any(feature = "signal", feature = "process"))] + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { self.chan.try_recv() } @@ -238,6 +237,25 @@ impl<T> Receiver<T> { pub fn close(&mut self) { self.chan.close(); } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) + } } impl<T> fmt::Debug for Receiver<T> { @@ -250,16 +268,6 @@ impl<T> fmt::Debug for Receiver<T> { impl<T> Unpin for Receiver<T> {} -cfg_stream! { - impl<T> crate::stream::Stream for Receiver<T> { - type Item = T; - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.chan.recv(cx) - } - } -} - impl<T> Sender<T> { pub(crate) fn new(chan: chan::Tx<T, Semaphore>) -> Sender<T> { Sender { chan } diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs index a40f5c3..f34eb0f 100644 --- a/src/sync/mpsc/chan.rs +++ b/src/sync/mpsc/chan.rs @@ -2,7 +2,6 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; -use crate::sync::mpsc::error::TryRecvError; use crate::sync::mpsc::list; use crate::sync::notify::Notify; @@ -259,21 +258,29 @@ impl<T, S: Semaphore> Rx<T, S> { } }) } +} - /// Receives the next value without blocking - pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { - use super::block::Read::*; - self.inner.rx_fields.with_mut(|rx_fields_ptr| { - let rx_fields = unsafe { &mut *rx_fields_ptr }; - match rx_fields.list.pop(&self.inner.tx) { - Some(Value(value)) => { - self.inner.semaphore.add_permit(); - Ok(value) +feature! { + #![all(unix, any(feature = "signal", feature = "process"))] + + use crate::sync::mpsc::error::TryRecvError; + + impl<T, S: Semaphore> Rx<T, S> { + /// Receives the next value without blocking + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { + use super::block::Read::*; + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + match rx_fields.list.pop(&self.inner.tx) { + Some(Value(value)) => { + self.inner.semaphore.add_permit(); + Ok(value) + } + Some(Closed) => Err(TryRecvError::Closed), + None => Err(TryRecvError::Empty), } - Some(Closed) => Err(TryRecvError::Closed), - None => Err(TryRecvError::Empty), - } - }) + }) + } } } diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs index 7705452..d23255b 100644 --- a/src/sync/mpsc/error.rs +++ b/src/sync/mpsc/error.rs @@ -67,32 +67,36 @@ impl Error for RecvError {} // ===== TryRecvError ===== -/// This enumeration is the list of the possible reasons that try_recv -/// could not return data when called. -#[derive(Debug, PartialEq)] -pub enum TryRecvError { - /// This channel is currently empty, but the Sender(s) have not yet - /// disconnected, so data may yet become available. - Empty, - /// The channel's sending half has been closed, and there will - /// never be any more data received on it. - Closed, -} +feature! { + #![all(unix, any(feature = "signal", feature = "process"))] + + /// This enumeration is the list of the possible reasons that try_recv + /// could not return data when called. + #[derive(Debug, PartialEq)] + pub(crate) enum TryRecvError { + /// This channel is currently empty, but the Sender(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The channel's sending half has been closed, and there will + /// never be any more data received on it. + Closed, + } -impl fmt::Display for TryRecvError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - fmt, - "{}", - match self { - TryRecvError::Empty => "channel empty", - TryRecvError::Closed => "channel closed", - } - ) + impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TryRecvError::Empty => "channel empty", + TryRecvError::Closed => "channel closed", + } + ) + } } -} -impl Error for TryRecvError {} + impl Error for TryRecvError {} +} cfg_time! { // ===== SendTimeoutError ===== diff --git a/src/sync/mpsc/mod.rs b/src/sync/mpsc/mod.rs index a2bcf83..e7033f6 100644 --- a/src/sync/mpsc/mod.rs +++ b/src/sync/mpsc/mod.rs @@ -14,10 +14,8 @@ //! Similar to the `mpsc` channels provided by `std`, the channel constructor //! functions provide separate send and receive handles, [`Sender`] and //! [`Receiver`] for the bounded channel, [`UnboundedSender`] and -//! [`UnboundedReceiver`] for the unbounded channel. Both [`Receiver`] and -//! [`UnboundedReceiver`] implement [`Stream`] and allow a task to read -//! values out of the channel. If there is no message to read, the current task -//! will be notified when a new value is sent. [`Sender`] and +//! [`UnboundedReceiver`] for the unbounded channel. If there is no message to read, +//! the current task will be notified when a new value is sent. [`Sender`] and //! [`UnboundedSender`] allow sending values into the channel. If the bounded //! channel is at capacity, the send is rejected and the task will be notified //! when additional capacity is available. In other words, the channel provides @@ -62,7 +60,6 @@ //! //! [`Sender`]: crate::sync::mpsc::Sender //! [`Receiver`]: crate::sync::mpsc::Receiver -//! [`Stream`]: crate::stream::Stream //! [bounded-send]: crate::sync::mpsc::Sender::send() //! [bounded-recv]: crate::sync::mpsc::Receiver::recv() //! [blocking-send]: crate::sync::mpsc::Sender::blocking_send() diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs index fe882d5..38953b8 100644 --- a/src/sync/mpsc/unbounded.rs +++ b/src/sync/mpsc/unbounded.rs @@ -1,6 +1,6 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{SendError, TryRecvError}; +use crate::sync::mpsc::error::SendError; use std::fmt; use std::task::{Context, Poll}; @@ -73,10 +73,6 @@ impl<T> UnboundedReceiver<T> { UnboundedReceiver { chan } } - fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.chan.recv(cx) - } - /// Receives the next value for this receiver. /// /// `None` is returned when all `Sender` halves have dropped, indicating @@ -122,19 +118,34 @@ impl<T> UnboundedReceiver<T> { poll_fn(|cx| self.poll_recv(cx)).await } - /// Attempts to return a pending value on this receiver without blocking. + /// Blocking receive to call outside of asynchronous contexts. /// - /// This method will never block the caller in order to wait for data to - /// become available. Instead, this will always return immediately with - /// a possible option of pending data on the channel. + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Examples /// - /// This is useful for a flavor of "optimistic check" before deciding to - /// block on a receiver. + /// ``` + /// use std::thread; + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel::<u8>(); /// - /// Compared with recv, this function has two failure cases instead of - /// one (one for disconnection, one for an empty buffer). - pub fn try_recv(&mut self) -> Result<T, TryRecvError> { - self.chan.try_recv() + /// let sync_code = thread::spawn(move || { + /// assert_eq!(Some(10), rx.blocking_recv()); + /// }); + /// + /// let _ = tx.send(10); + /// sync_code.join().unwrap(); + /// } + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_recv(&mut self) -> Option<T> { + crate::future::block_on(self.recv()) } /// Closes the receiving half of a channel, without dropping it. @@ -144,14 +155,24 @@ impl<T> UnboundedReceiver<T> { pub fn close(&mut self) { self.chan.close(); } -} -#[cfg(feature = "stream")] -impl<T> crate::stream::Stream for UnboundedReceiver<T> { - type Item = T; - - fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { - self.poll_recv(cx) + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not + /// closed. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages + /// sent before it was closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { + self.chan.recv(cx) } } diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index 750765f..2e72cf7 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -237,114 +237,6 @@ pub struct RwLockWriteGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { - /// Make a new `RwLockWriteGuard` for a component of the locked data. - /// - /// This operation cannot fail as the `RwLockWriteGuard` passed in already - /// locked the data. - /// - /// This is an associated function that needs to be used as - /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of - /// the same name on the contents of the locked data. - /// - /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the - /// [`parking_lot` crate]. - /// - /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map - /// [`parking_lot` crate]: https://crates.io/crates/parking_lot - /// - /// # Examples - /// - /// ``` - /// use tokio::sync::{RwLock, RwLockWriteGuard}; - /// - /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] - /// struct Foo(u32); - /// - /// # #[tokio::main] - /// # async fn main() { - /// let lock = RwLock::new(Foo(1)); - /// - /// { - /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0); - /// *mapped = 2; - /// } - /// - /// assert_eq!(Foo(2), *lock.read().await); - /// # } - /// ``` - #[inline] - pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockWriteGuard<'a, U> - where - F: FnOnce(&mut T) -> &mut U, - { - let data = f(&mut *this) as *mut U; - let s = this.s; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); - RwLockWriteGuard { - s, - data, - marker: marker::PhantomData, - } - } - - /// Attempts to make a new [`RwLockWriteGuard`] for a component of - /// the locked data. The original guard is returned if the closure returns - /// `None`. - /// - /// This operation cannot fail as the `RwLockWriteGuard` passed in already - /// locked the data. - /// - /// This is an associated function that needs to be - /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with - /// methods of the same name on the contents of the locked data. - /// - /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from - /// the [`parking_lot` crate]. - /// - /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map - /// [`parking_lot` crate]: https://crates.io/crates/parking_lot - /// - /// # Examples - /// - /// ``` - /// use tokio::sync::{RwLock, RwLockWriteGuard}; - /// - /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] - /// struct Foo(u32); - /// - /// # #[tokio::main] - /// # async fn main() { - /// let lock = RwLock::new(Foo(1)); - /// - /// { - /// let guard = lock.write().await; - /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail"); - /// *guard = 2; - /// } - /// - /// assert_eq!(Foo(2), *lock.read().await); - /// # } - /// ``` - #[inline] - pub fn try_map<F, U: ?Sized>(mut this: Self, f: F) -> Result<RwLockWriteGuard<'a, U>, Self> - where - F: FnOnce(&mut T) -> Option<&mut U>, - { - let data = match f(&mut *this) { - Some(data) => data as *mut U, - None => return Err(this), - }; - let s = this.s; - // NB: Forget to avoid drop impl from being called. - mem::forget(this); - Ok(RwLockWriteGuard { - s, - data, - marker: marker::PhantomData, - }) - } - /// Atomically downgrades a write lock into a read lock without allowing /// any writers to take exclusive access of the lock in the meantime. /// diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs index 2acccfa..5555bdf 100644 --- a/src/sync/semaphore.rs +++ b/src/sync/semaphore.rs @@ -1,4 +1,5 @@ use super::batch_semaphore as ll; // low level implementation +use super::{AcquireError, TryAcquireError}; use std::sync::Arc; /// Counting semaphore performing asynchronous permit acquisition. @@ -42,15 +43,6 @@ pub struct OwnedSemaphorePermit { permits: u32, } -/// Error returned from the [`Semaphore::try_acquire`] function. -/// -/// A `try_acquire` operation can only fail if the semaphore has no available -/// permits. -/// -/// [`Semaphore::try_acquire`]: Semaphore::try_acquire -#[derive(Debug)] -pub struct TryAcquireError(()); - #[test] #[cfg(not(loom))] fn bounds() { @@ -95,73 +87,148 @@ impl Semaphore { self.ll_sem.release(n); } - /// Acquires permit from the semaphore. - pub async fn acquire(&self) -> SemaphorePermit<'_> { - self.ll_sem.acquire(1).await.unwrap(); - SemaphorePermit { + /// Acquires a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permit. + /// + /// [`AcquireError`]: crate::sync::AcquireError + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { + self.ll_sem.acquire(1).await?; + Ok(SemaphorePermit { sem: &self, permits: 1, - } + }) } - /// Acquires `n` permits from the semaphore - pub async fn acquire_many(&self, n: u32) -> SemaphorePermit<'_> { - self.ll_sem.acquire(n).await.unwrap(); - SemaphorePermit { + /// Acquires `n` permits from the semaphore. + /// + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`SemaphorePermit`] representing the + /// acquired permits. + /// + /// [`AcquireError`]: crate::sync::AcquireError + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit + pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { + self.ll_sem.acquire(n).await?; + Ok(SemaphorePermit { sem: &self, permits: n, - } + }) } /// Tries to acquire a permit from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, + /// this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> { match self.ll_sem.try_acquire(1) { Ok(_) => Ok(SemaphorePermit { sem: self, permits: 1, }), - Err(_) => Err(TryAcquireError(())), + Err(e) => Err(e), } } - /// Tries to acquire `n` permits from the semaphore. + /// Tries to acquire n permits from the semaphore. + /// + /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, + /// this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> { match self.ll_sem.try_acquire(n) { Ok(_) => Ok(SemaphorePermit { sem: self, permits: n, }), - Err(_) => Err(TryAcquireError(())), + Err(e) => Err(e), } } - /// Acquires permit from the semaphore. + /// Acquires a permit from the semaphore. /// /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. /// /// [`Arc`]: std::sync::Arc - pub async fn acquire_owned(self: Arc<Self>) -> OwnedSemaphorePermit { - self.ll_sem.acquire(1).await.unwrap(); - OwnedSemaphorePermit { + /// [`AcquireError`]: crate::sync::AcquireError + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub async fn acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, AcquireError> { + self.ll_sem.acquire(1).await?; + Ok(OwnedSemaphorePermit { sem: self, permits: 1, - } + }) } /// Tries to acquire a permit from the semaphore. /// - /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. /// /// [`Arc`]: std::sync::Arc + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> { match self.ll_sem.try_acquire(1) { Ok(_) => Ok(OwnedSemaphorePermit { sem: self, permits: 1, }), - Err(_) => Err(TryAcquireError(())), + Err(e) => Err(e), } } + + /// Closes the semaphore. + /// + /// This prevents the semaphore from issuing new permits and notifies all pending waiters. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// use std::sync::Arc; + /// use tokio::sync::TryAcquireError; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(1)); + /// let semaphore2 = semaphore.clone(); + /// + /// tokio::spawn(async move { + /// let permit = semaphore.acquire_many(2).await; + /// assert!(permit.is_err()); + /// println!("waiter received error"); + /// }); + /// + /// println!("closing semaphore"); + /// semaphore2.close(); + /// + /// // Cannot obtain more permits + /// assert_eq!(semaphore2.try_acquire().err(), Some(TryAcquireError::Closed)) + /// } + /// ``` + pub fn close(&self) { + self.ll_sem.close(); + } } impl<'a> SemaphorePermit<'a> { diff --git a/src/sync/watch.rs b/src/sync/watch.rs index b377ca7..6732d38 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -53,10 +53,10 @@ use crate::sync::Notify; +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::atomic::Ordering::{Relaxed, SeqCst}; +use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::ops; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::{Relaxed, SeqCst}; -use std::sync::{Arc, RwLock, RwLockReadGuard}; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -241,19 +241,19 @@ impl<T> Receiver<T> { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { - // In order to avoid a race condition, we first request a notification, - // **then** check the current value's version. If a new version exists, - // the notification request is dropped. - let notified = self.shared.notify_rx.notified(); - - if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { - return ret; + loop { + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. + let notified = self.shared.notify_rx.notified(); + + if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { + return ret; + } + + notified.await; + // loop around again in case the wake-up was spurious } - - notified.await; - - maybe_changed(&self.shared, &mut self.version) - .expect("[bug] failed to observe change after notificaton.") } } @@ -322,6 +322,25 @@ impl<T> Sender<T> { Ok(()) } + /// Returns a reference to the most recently sent value + /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _) = watch::channel("hello"); + /// assert_eq!(*tx.borrow(), "hello"); + /// ``` + pub fn borrow(&self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + Ref { inner } + } + /// Checks if the channel has been closed. This happens when all receivers /// have dropped. /// @@ -390,3 +409,84 @@ impl<T> ops::Deref for Ref<'_, T> { self.inner.deref() } } + +#[cfg(all(test, loom))] +mod tests { + use futures::future::FutureExt; + use loom::thread; + + // test for https://github.com/tokio-rs/tokio/issues/3168 + #[test] + fn watch_spurious_wakeup() { + loom::model(|| { + let (send, mut recv) = crate::sync::watch::channel(0i32); + + send.send(1).unwrap(); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let mut recv = recv_thread.join().unwrap(); + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + }); + + recv.changed().now_or_never(); + + send_thread.join().unwrap(); + }); + } + + #[test] + fn watch_borrow() { + loom::model(|| { + let (send, mut recv) = crate::sync::watch::channel(0i32); + + assert!(send.borrow().eq(&0)); + assert!(recv.borrow().eq(&0)); + + send.send(1).unwrap(); + assert!(send.borrow().eq(&1)); + + let send_thread = thread::spawn(move || { + send.send(2).unwrap(); + send + }); + + recv.changed().now_or_never(); + + let send = send_thread.join().unwrap(); + let recv_thread = thread::spawn(move || { + recv.changed().now_or_never(); + recv.changed().now_or_never(); + recv + }); + + send.send(3).unwrap(); + + let recv = recv_thread.join().unwrap(); + assert!(recv.borrow().eq(&3)); + assert!(send.borrow().eq(&3)); + + send.send(2).unwrap(); + + thread::spawn(move || { + assert!(recv.borrow().eq(&2)); + }); + assert!(send.borrow().eq(&2)); + }); + } +} diff --git a/src/task/blocking.rs b/src/task/blocking.rs index 36bc457..28bbcdb 100644 --- a/src/task/blocking.rs +++ b/src/task/blocking.rs @@ -51,64 +51,66 @@ cfg_rt_multi_thread! { } } -/// Runs the provided closure on a thread where blocking is acceptable. -/// -/// In general, issuing a blocking call or performing a lot of compute in a -/// future without yielding is problematic, as it may prevent the executor from -/// driving other futures forward. This function runs the provided closure on a -/// thread dedicated to blocking operations. See the [CPU-bound tasks and -/// blocking code][blocking] section for more information. -/// -/// Tokio will spawn more blocking threads when they are requested through this -/// function until the upper limit configured on the [`Builder`] is reached. -/// This limit is very large by default, because `spawn_blocking` is often used -/// for various kinds of IO operations that cannot be performed asynchronously. -/// When you run CPU-bound code using `spawn_blocking`, you should keep this -/// large upper limit in mind. When running many CPU-bound computations, a -/// semaphore or some other synchronization primitive should be used to limit -/// the number of computation executed in parallel. Specialized CPU-bound -/// executors, such as [rayon], may also be a good fit. -/// -/// This function is intended for non-async operations that eventually finish on -/// their own. If you want to spawn an ordinary thread, you should use -/// [`thread::spawn`] instead. -/// -/// Closures spawned using `spawn_blocking` cannot be cancelled. When you shut -/// down the executor, it will wait indefinitely for all blocking operations to -/// finish. You can use [`shutdown_timeout`] to stop waiting for them after a -/// certain timeout. Be aware that this will still not cancel the tasks — they -/// are simply allowed to keep running after the method returns. -/// -/// Note that if you are using the single threaded runtime, this function will -/// still spawn additional threads for blocking operations. The basic -/// scheduler's single thread is only used for asynchronous code. -/// -/// [`Builder`]: struct@crate::runtime::Builder -/// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code -/// [rayon]: https://docs.rs/rayon -/// [`thread::spawn`]: fn@std::thread::spawn -/// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout -/// -/// # Examples -/// -/// ``` -/// use tokio::task; -/// -/// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ -/// let res = task::spawn_blocking(move || { -/// // do some compute-heavy work or call synchronous code -/// "done computing" -/// }).await?; -/// -/// assert_eq!(res, "done computing"); -/// # Ok(()) -/// # } -/// ``` -#[cfg_attr(tokio_track_caller, track_caller)] -pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - crate::runtime::spawn_blocking(f) +cfg_rt! { + /// Runs the provided closure on a thread where blocking is acceptable. + /// + /// In general, issuing a blocking call or performing a lot of compute in a + /// future without yielding is problematic, as it may prevent the executor from + /// driving other futures forward. This function runs the provided closure on a + /// thread dedicated to blocking operations. See the [CPU-bound tasks and + /// blocking code][blocking] section for more information. + /// + /// Tokio will spawn more blocking threads when they are requested through this + /// function until the upper limit configured on the [`Builder`] is reached. + /// This limit is very large by default, because `spawn_blocking` is often used + /// for various kinds of IO operations that cannot be performed asynchronously. + /// When you run CPU-bound code using `spawn_blocking`, you should keep this + /// large upper limit in mind. When running many CPU-bound computations, a + /// semaphore or some other synchronization primitive should be used to limit + /// the number of computation executed in parallel. Specialized CPU-bound + /// executors, such as [rayon], may also be a good fit. + /// + /// This function is intended for non-async operations that eventually finish on + /// their own. If you want to spawn an ordinary thread, you should use + /// [`thread::spawn`] instead. + /// + /// Closures spawned using `spawn_blocking` cannot be cancelled. When you shut + /// down the executor, it will wait indefinitely for all blocking operations to + /// finish. You can use [`shutdown_timeout`] to stop waiting for them after a + /// certain timeout. Be aware that this will still not cancel the tasks — they + /// are simply allowed to keep running after the method returns. + /// + /// Note that if you are using the single threaded runtime, this function will + /// still spawn additional threads for blocking operations. The basic + /// scheduler's single thread is only used for asynchronous code. + /// + /// [`Builder`]: struct@crate::runtime::Builder + /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code + /// [rayon]: https://docs.rs/rayon + /// [`thread::spawn`]: fn@std::thread::spawn + /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// + /// # Examples + /// + /// ``` + /// use tokio::task; + /// + /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ + /// let res = task::spawn_blocking(move || { + /// // do some compute-heavy work or call synchronous code + /// "done computing" + /// }).await?; + /// + /// assert_eq!(res, "done computing"); + /// # Ok(()) + /// # } + /// ``` + #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + crate::runtime::spawn_blocking(f) + } } diff --git a/src/task/task_local.rs b/src/task/task_local.rs index 1679ee3..bc2e54a 100644 --- a/src/task/task_local.rs +++ b/src/task/task_local.rs @@ -31,6 +31,7 @@ use std::{fmt, thread}; /// /// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey #[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] macro_rules! task_local { // empty (base case for the recursion) () => {}; @@ -90,6 +91,7 @@ macro_rules! __task_local_inner { /// # } /// ``` /// [`std::thread::LocalKey`]: struct@std::thread::LocalKey +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub struct LocalKey<T: 'static> { #[doc(hidden)] pub inner: thread::LocalKey<RefCell<Option<T>>>, diff --git a/src/time/clock.rs b/src/time/clock.rs index fab7eca..a62fbe3 100644 --- a/src/time/clock.rs +++ b/src/time/clock.rs @@ -17,7 +17,7 @@ cfg_not_test_util! { } impl Clock { - pub(crate) fn new() -> Clock { + pub(crate) fn new(_enable_pausing: bool) -> Clock { Clock {} } @@ -59,6 +59,9 @@ cfg_test_util! { #[derive(Debug)] struct Inner { + /// True if the ability to pause time is enabled. + enable_pausing: bool, + /// Instant to use as the clock's base instant. base: std::time::Instant, @@ -69,14 +72,18 @@ cfg_test_util! { /// Pause time /// /// The current value of `Instant::now()` is saved and all subsequent calls - /// to `Instant::now()` until the timer wheel is checked again will return the saved value. - /// Once the timer wheel is checked, time will immediately advance to the next registered - /// `Sleep`. This is useful for running tests that depend on time. + /// to `Instant::now()` until the timer wheel is checked again will return + /// the saved value. Once the timer wheel is checked, time will immediately + /// advance to the next registered `Sleep`. This is useful for running tests + /// that depend on time. + /// + /// Pausing time requires the `current_thread` Tokio runtime. This is the + /// default runtime used by `#[tokio::test]` /// /// # Panics /// - /// Panics if time is already frozen or if called from outside of the Tokio - /// runtime. + /// Panics if time is already frozen or if called from outside of a + /// `current_thread` Tokio runtime. pub fn pause() { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.pause(); @@ -142,11 +149,12 @@ cfg_test_util! { impl Clock { /// Return a new `Clock` instance that uses the current execution context's /// source of time. - pub(crate) fn new() -> Clock { + pub(crate) fn new(enable_pausing: bool) -> Clock { let now = std::time::Instant::now(); Clock { inner: Arc::new(Mutex::new(Inner { + enable_pausing, base: now, unfrozen: Some(now), })), @@ -156,6 +164,12 @@ cfg_test_util! { pub(crate) fn pause(&self) { let mut inner = self.inner.lock().unwrap(); + if !inner.enable_pausing { + drop(inner); // avoid poisoning the lock + panic!("`time::pause()` requires the `current_thread` Tokio runtime. \ + This is the default Runtime used by `#[tokio::test]."); + } + let elapsed = inner.unfrozen.as_ref().expect("time is already frozen").elapsed(); inner.base += elapsed; inner.unfrozen = None; diff --git a/src/time/driver/atomic_stack.rs b/src/time/driver/atomic_stack.rs deleted file mode 100644 index 5dcc472..0000000 --- a/src/time/driver/atomic_stack.rs +++ /dev/null @@ -1,124 +0,0 @@ -use crate::time::driver::Entry; -use crate::time::error::Error; - -use std::ptr; -use std::sync::atomic::AtomicPtr; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::Arc; - -/// A stack of `Entry` nodes -#[derive(Debug)] -pub(crate) struct AtomicStack { - /// Stack head - head: AtomicPtr<Entry>, -} - -/// Entries that were removed from the stack -#[derive(Debug)] -pub(crate) struct AtomicStackEntries { - ptr: *mut Entry, -} - -/// Used to indicate that the timer has shutdown. -const SHUTDOWN: *mut Entry = 1 as *mut _; - -impl AtomicStack { - pub(crate) fn new() -> AtomicStack { - AtomicStack { - head: AtomicPtr::new(ptr::null_mut()), - } - } - - /// Pushes an entry onto the stack. - /// - /// Returns `true` if the entry was pushed, `false` if the entry is already - /// on the stack, `Err` if the timer is shutdown. - pub(crate) fn push(&self, entry: &Arc<Entry>) -> Result<bool, Error> { - // First, set the queued bit on the entry - let queued = entry.queued.fetch_or(true, SeqCst); - - if queued { - // Already queued, nothing more to do - return Ok(false); - } - - let ptr = Arc::into_raw(entry.clone()) as *mut _; - - let mut curr = self.head.load(SeqCst); - - loop { - if curr == SHUTDOWN { - // Don't leak the entry node - let _ = unsafe { Arc::from_raw(ptr) }; - - return Err(Error::shutdown()); - } - - // Update the `next` pointer. This is safe because setting the queued - // bit is a "lock" on this field. - unsafe { - *(entry.next_atomic.get()) = curr; - } - - let actual = self.head.compare_and_swap(curr, ptr, SeqCst); - - if actual == curr { - break; - } - - curr = actual; - } - - Ok(true) - } - - /// Takes all entries from the stack - pub(crate) fn take(&self) -> AtomicStackEntries { - let ptr = self.head.swap(ptr::null_mut(), SeqCst); - AtomicStackEntries { ptr } - } - - /// Drains all remaining nodes in the stack and prevent any new nodes from - /// being pushed onto the stack. - pub(crate) fn shutdown(&self) { - // Shutdown the processing queue - let ptr = self.head.swap(SHUTDOWN, SeqCst); - - // Let the drop fn of `AtomicStackEntries` handle draining the stack - drop(AtomicStackEntries { ptr }); - } -} - -// ===== impl AtomicStackEntries ===== - -impl Iterator for AtomicStackEntries { - type Item = Arc<Entry>; - - fn next(&mut self) -> Option<Self::Item> { - if self.ptr.is_null() || self.ptr == SHUTDOWN { - return None; - } - - // Convert the pointer to an `Arc<Entry>` - let entry = unsafe { Arc::from_raw(self.ptr) }; - - // Update `self.ptr` to point to the next element of the stack - self.ptr = unsafe { *entry.next_atomic.get() }; - - // Unset the queued flag - let res = entry.queued.fetch_and(false, SeqCst); - debug_assert!(res); - - // Return the entry - Some(entry) - } -} - -impl Drop for AtomicStackEntries { - fn drop(&mut self) { - for entry in self { - // Flag the entry as errored - entry.error(Error::shutdown()); - } - } -} diff --git a/src/time/driver/entry.rs b/src/time/driver/entry.rs index b40cae7..bcad988 100644 --- a/src/time/driver/entry.rs +++ b/src/time/driver/entry.rs @@ -1,362 +1,697 @@ -use crate::loom::sync::atomic::AtomicU64; -use crate::sync::AtomicWaker; -use crate::time::driver::{Handle, Inner}; -use crate::time::{error::Error, Duration, Instant}; - -use std::cell::UnsafeCell; -use std::ptr; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::atomic::{AtomicBool, AtomicU8}; -use std::sync::{Arc, Weak}; -use std::task::{self, Poll}; -use std::u64; - -/// Internal state shared between a `Sleep` instance and the timer. -/// -/// This struct is used as a node in two intrusive data structures: -/// -/// * An atomic stack used to signal to the timer thread that the entry state -/// has changed. The timer thread will observe the entry on this stack and -/// perform any actions as necessary. -/// -/// * A doubly linked list used **only** by the timer thread. Each slot in the -/// timer wheel is a head pointer to the list of entries that must be -/// processed during that timer tick. -#[derive(Debug)] -pub(crate) struct Entry { - /// Only accessed from `Registration`. - time: CachePadded<UnsafeCell<Time>>, - - /// Timer internals. Using a weak pointer allows the timer to shutdown - /// without all `Sleep` instances having completed. - /// - /// When empty, it means that the entry has not yet been linked with a - /// timer instance. - inner: Weak<Inner>, - - /// Tracks the entry state. This value contains the following information: - /// - /// * The deadline at which the entry must be "fired". - /// * A flag indicating if the entry has already been fired. - /// * Whether or not the entry transitioned to the error state. - /// - /// When an `Entry` is created, `state` is initialized to the instant at - /// which the entry must be fired. When a timer is reset to a different - /// instant, this value is changed. - state: AtomicU64, +//! Timer state structures. +//! +//! This module contains the heart of the intrusive timer implementation, and as +//! such the structures inside are full of tricky concurrency and unsafe code. +//! +//! # Ground rules +//! +//! The heart of the timer implementation here is the `TimerShared` structure, +//! shared between the `TimerEntry` and the driver. Generally, we permit access +//! to `TimerShared` ONLY via either 1) a mutable reference to `TimerEntry` or +//! 2) a held driver lock. +//! +//! It follows from this that any changes made while holding BOTH 1 and 2 will +//! be reliably visible, regardless of ordering. This is because of the acq/rel +//! fences on the driver lock ensuring ordering with 2, and rust mutable +//! reference rules for 1 (a mutable reference to an object can't be passed +//! between threads without an acq/rel barrier, and same-thread we have local +//! happens-before ordering). +//! +//! # State field +//! +//! Each timer has a state field associated with it. This field contains either +//! the current scheduled time, or a special flag value indicating its state. +//! This state can either indicate that the timer is on the 'pending' queue (and +//! thus will be fired with an `Ok(())` result soon) or that it has already been +//! fired/deregistered. +//! +//! This single state field allows for code that is firing the timer to +//! synchronize with any racing `reset` calls reliably. +//! +//! # Cached vs true timeouts +//! +//! To allow for the use case of a timeout that is periodically reset before +//! expiration to be as lightweight as possible, we support optimistically +//! lock-free timer resets, in the case where a timer is rescheduled to a later +//! point than it was originally scheduled for. +//! +//! This is accomplished by lazily rescheduling timers. That is, we update the +//! state field field with the true expiration of the timer from the holder of +//! the [`TimerEntry`]. When the driver services timers (ie, whenever it's +//! walking lists of timers), it checks this "true when" value, and reschedules +//! based on it. +//! +//! We do, however, also need to track what the expiration time was when we +//! originally registered the timer; this is used to locate the right linked +//! list when the timer is being cancelled. This is referred to as the "cached +//! when" internally. +//! +//! There is of course a race condition between timer reset and timer +//! expiration. If the driver fails to observe the updated expiration time, it +//! could trigger expiration of the timer too early. However, because +//! `mark_pending` performs a compare-and-swap, it will identify this race and +//! refuse to mark the timer as pending. + +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::atomic::Ordering; - /// Stores the actual error. If `state` indicates that an error occurred, - /// this is guaranteed to be a non-zero value representing the first error - /// that occurred. Otherwise its value is undefined. - error: AtomicU8, - - /// Task to notify once the deadline is reached. - waker: AtomicWaker, +use crate::sync::AtomicWaker; +use crate::time::Instant; +use crate::util::linked_list; - /// True when the entry is queued in the "process" stack. This value - /// is set before pushing the value and unset after popping the value. - /// - /// TODO: This could possibly be rolled up into `state`. - pub(super) queued: AtomicBool, +use super::Handle; - /// Next entry in the "process" linked list. - /// - /// Access to this field is coordinated by the `queued` flag. - /// - /// Represents a strong Arc ref. - pub(super) next_atomic: UnsafeCell<*mut Entry>, +use std::cell::UnsafeCell as StdUnsafeCell; +use std::task::{Context, Poll, Waker}; +use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; - /// When the entry expires, relative to the `start` of the timer - /// (Inner::start). This is only used by the timer. - /// - /// A `Sleep` instance can be reset to a different deadline by the thread - /// that owns the `Sleep` instance. In this case, the timer thread will not - /// immediately know that this has happened. The timer thread must know the - /// last deadline that it saw as it uses this value to locate the entry in - /// its wheel. - /// - /// Once the timer thread observes that the instant has changed, it updates - /// the wheel and sets this value. The idea is that this value eventually - /// converges to the value of `state` as the timer thread makes updates. - when: UnsafeCell<Option<u64>>, +type TimerResult = Result<(), crate::time::error::Error>; - /// Next entry in the State's linked list. - /// - /// This is only accessed by the timer - pub(crate) next_stack: UnsafeCell<Option<Arc<Entry>>>, +const STATE_DEREGISTERED: u64 = u64::max_value(); +const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; +const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; - /// Previous entry in the State's linked list. - /// - /// This is only accessed by the timer and is used to unlink a canceled - /// entry. - /// - /// This is a weak reference. - pub(crate) prev_stack: UnsafeCell<*const Entry>, -} - -/// Stores the info for `Sleep`. +/// Not all platforms support 64-bit compare-and-swap. This hack replaces the +/// AtomicU64 with a mutex around a u64 on platforms that don't. This is slow, +/// unfortunately, but 32-bit platforms are a bit niche so it'll do for now. +/// +/// Note: We use "x86 or 64-bit pointers" as the condition here because +/// target_has_atomic is not stable. +#[cfg(all( + not(tokio_force_time_entry_locked), + any(target_arch = "x86", target_pointer_width = "64") +))] +type AtomicU64 = crate::loom::sync::atomic::AtomicU64; + +#[cfg(not(all( + not(tokio_force_time_entry_locked), + any(target_arch = "x86", target_pointer_width = "64") +)))] #[derive(Debug)] -pub(crate) struct Time { - pub(crate) deadline: Instant, - pub(crate) duration: Duration, +struct AtomicU64 { + inner: crate::loom::sync::Mutex<u64>, } -/// Flag indicating a timer entry has elapsed -const ELAPSED: u64 = 1 << 63; - -/// Flag indicating a timer entry has reached an error state -const ERROR: u64 = u64::MAX; +#[cfg(not(all( + not(tokio_force_time_entry_locked), + any(target_arch = "x86", target_pointer_width = "64") +)))] +impl AtomicU64 { + fn new(v: u64) -> Self { + Self { + inner: crate::loom::sync::Mutex::new(v), + } + } -// ===== impl Entry ===== + fn load(&self, _order: Ordering) -> u64 { + debug_assert_ne!(_order, Ordering::SeqCst); // we only provide AcqRel with the lock + *self.inner.lock() + } -impl Entry { - pub(crate) fn new(handle: &Handle, deadline: Instant, duration: Duration) -> Arc<Entry> { - let inner = handle.inner().unwrap(); + fn store(&self, v: u64, _order: Ordering) { + debug_assert_ne!(_order, Ordering::SeqCst); // we only provide AcqRel with the lock + *self.inner.lock() = v; + } - // Attempt to increment the number of active timeouts - let entry = if let Err(err) = inner.increment() { - let entry = Entry::new2(deadline, duration, Weak::new(), ERROR); - entry.error(err); - entry + fn compare_exchange( + &self, + current: u64, + new: u64, + _success: Ordering, + _failure: Ordering, + ) -> Result<u64, u64> { + debug_assert_ne!(_success, Ordering::SeqCst); // we only provide AcqRel with the lock + debug_assert_ne!(_failure, Ordering::SeqCst); + + let mut lock = self.inner.lock(); + + if *lock == current { + *lock = new; + Ok(current) } else { - let when = inner.normalize_deadline(deadline); - let state = if when <= inner.elapsed() { - ELAPSED - } else { - when - }; - Entry::new2(deadline, duration, Arc::downgrade(&inner), state) - }; - - let entry = Arc::new(entry); - if let Err(err) = inner.queue(&entry) { - entry.error(err); + Err(*lock) } - - entry } - /// Only called by `Registration` - pub(crate) fn time_ref(&self) -> &Time { - unsafe { &*self.time.0.get() } + fn compare_exchange_weak( + &self, + current: u64, + new: u64, + success: Ordering, + failure: Ordering, + ) -> Result<u64, u64> { + self.compare_exchange(current, new, success, failure) } +} - /// Only called by `Registration` - #[allow(clippy::mut_from_ref)] // https://github.com/rust-lang/rust-clippy/issues/4281 - pub(crate) unsafe fn time_mut(&self) -> &mut Time { - &mut *self.time.0.get() - } +/// This structure holds the current shared state of the timer - its scheduled +/// time (if registered), or otherwise the result of the timer completing, as +/// well as the registered waker. +/// +/// Generally, the StateCell is only permitted to be accessed from two contexts: +/// Either a thread holding the corresponding &mut TimerEntry, or a thread +/// holding the timer driver lock. The write actions on the StateCell amount to +/// passing "ownership" of the StateCell between these contexts; moving a timer +/// from the TimerEntry to the driver requires _both_ holding the &mut +/// TimerEntry and the driver lock, while moving it back (firing the timer) +/// requires only the driver lock. +pub(super) struct StateCell { + /// Holds either the scheduled expiration time for this timer, or (if the + /// timer has been fired and is unregistered), [`u64::max_value()`]. + state: AtomicU64, + /// If the timer is fired (an Acquire order read on state shows + /// `u64::max_value()`), holds the result that should be returned from + /// polling the timer. Otherwise, the contents are unspecified and reading + /// without holding the driver lock is undefined behavior. + result: UnsafeCell<TimerResult>, + /// The currently-registered waker + waker: CachePadded<AtomicWaker>, +} - pub(crate) fn when(&self) -> u64 { - self.when_internal().expect("invalid internal state") +impl Default for StateCell { + fn default() -> Self { + Self::new() } +} - /// The current entry state as known by the timer. This is not the value of - /// `state`, but lets the timer know how to converge its state to `state`. - pub(crate) fn when_internal(&self) -> Option<u64> { - unsafe { *self.when.get() } +impl std::fmt::Debug for StateCell { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "StateCell({:?})", self.read_state()) } +} - pub(crate) fn set_when_internal(&self, when: Option<u64>) { - unsafe { - *self.when.get() = when; +impl StateCell { + fn new() -> Self { + Self { + state: AtomicU64::new(STATE_DEREGISTERED), + result: UnsafeCell::new(Ok(())), + waker: CachePadded(AtomicWaker::new()), } } - /// Called by `Timer` to load the current value of `state` for processing - pub(crate) fn load_state(&self) -> Option<u64> { - let state = self.state.load(SeqCst); + fn is_pending(&self) -> bool { + self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE + } + + /// Returns the current expiration time, or None if not currently scheduled. + fn when(&self) -> Option<u64> { + let cur_state = self.state.load(Ordering::Relaxed); - if is_elapsed(state) { + if cur_state == u64::max_value() { None } else { - Some(state) + Some(cur_state) } } - pub(crate) fn is_elapsed(&self) -> bool { - let state = self.state.load(SeqCst); - is_elapsed(state) + /// If the timer is completed, returns the result of the timer. Otherwise, + /// returns None and registers the waker. + fn poll(&self, waker: &Waker) -> Poll<TimerResult> { + // We must register first. This ensures that either `fire` will + // observe the new waker, or we will observe a racing fire to have set + // the state, or both. + self.waker.0.register_by_ref(waker); + + self.read_state() + } + + fn read_state(&self) -> Poll<TimerResult> { + let cur_state = self.state.load(Ordering::Acquire); + + if cur_state == STATE_DEREGISTERED { + // SAFETY: The driver has fired this timer; this involves writing + // the result, and then writing (with release ordering) the state + // field. + Poll::Ready(unsafe { self.result.with(|p| *p) }) + } else { + Poll::Pending + } } - pub(crate) fn fire(&self, when: u64) { - let mut curr = self.state.load(SeqCst); + /// Marks this timer as being moved to the pending list, if its scheduled + /// time is not after `not_after`. + /// + /// If the timer is scheduled for a time after not_after, returns an Err + /// containing the current scheduled time. + /// + /// SAFETY: Must hold the driver lock. + unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + // Quick initial debug check to see if the timer is already fired. Since + // firing the timer can only happen with the driver lock held, we know + // we shouldn't be able to "miss" a transition to a fired state, even + // with relaxed ordering. + let mut cur_state = self.state.load(Ordering::Relaxed); loop { - if is_elapsed(curr) || curr > when { - return; - } + debug_assert!(cur_state < STATE_MIN_VALUE); - let next = ELAPSED | curr; - let actual = self.state.compare_and_swap(curr, next, SeqCst); + if cur_state > not_after { + break Err(cur_state); + } - if curr == actual { - break; + match self.state.compare_exchange( + cur_state, + STATE_PENDING_FIRE, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + break Ok(()); + } + Err(actual_state) => { + cur_state = actual_state; + } } + } + } - curr = actual; + /// Fires the timer, setting the result to the provided result. + /// + /// Returns: + /// * `Some(waker) - if fired and a waker needs to be invoked once the + /// driver lock is released + /// * `None` - if fired and a waker does not need to be invoked, or if + /// already fired + /// + /// SAFETY: The driver lock must be held. + unsafe fn fire(&self, result: TimerResult) -> Option<Waker> { + // Quick initial check to see if the timer is already fired. Since + // firing the timer can only happen with the driver lock held, we know + // we shouldn't be able to "miss" a transition to a fired state, even + // with relaxed ordering. + let cur_state = self.state.load(Ordering::Relaxed); + if cur_state == STATE_DEREGISTERED { + return None; } - self.waker.wake(); - } + // SAFETY: We assume the driver lock is held and the timer is not + // fired, so only the driver is accessing this field. + // + // We perform a release-ordered store to state below, to ensure this + // write is visible before the state update is visible. + unsafe { self.result.with_mut(|p| *p = result) }; - pub(crate) fn error(&self, error: Error) { - // Record the precise nature of the error, if there isn't already an - // error present. If we don't actually transition to the error state - // below, that's fine, as the error details we set here will be ignored. - self.error.compare_and_swap(0, error.as_u8(), SeqCst); + self.state.store(STATE_DEREGISTERED, Ordering::Release); + + self.waker.0.take_waker() + } - // Only transition to the error state if not currently elapsed - let mut curr = self.state.load(SeqCst); + /// Marks the timer as registered (poll will return None) and sets the + /// expiration time. + /// + /// While this function is memory-safe, it should only be called from a + /// context holding both `&mut TimerEntry` and the driver lock. + fn set_expiration(&self, timestamp: u64) { + debug_assert!(timestamp < STATE_MIN_VALUE); + + // We can use relaxed ordering because we hold the driver lock and will + // fence when we release the lock. + self.state.store(timestamp, Ordering::Relaxed); + } + /// Attempts to adjust the timer to a new timestamp. + /// + /// If the timer has already been fired, is pending firing, or the new + /// timestamp is earlier than the old timestamp, (or occasionally + /// spuriously) returns Err without changing the timer's state. In this + /// case, the timer must be deregistered and re-registered. + fn extend_expiration(&self, new_timestamp: u64) -> Result<(), ()> { + let mut prior = self.state.load(Ordering::Relaxed); loop { - if is_elapsed(curr) { - return; + if new_timestamp < prior || prior >= STATE_MIN_VALUE { + return Err(()); } - let next = ERROR; + match self.state.compare_exchange_weak( + prior, + new_timestamp, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + return Ok(()); + } + Err(true_prior) => { + prior = true_prior; + } + } + } + } - let actual = self.state.compare_and_swap(curr, next, SeqCst); + /// Returns true if the state of this timer indicates that the timer might + /// be registered with the driver. This check is performed with relaxed + /// ordering, but is conservative - if it returns false, the timer is + /// definitely _not_ registered. + pub(super) fn might_be_registered(&self) -> bool { + self.state.load(Ordering::Relaxed) != u64::max_value() + } +} - if curr == actual { - break; - } +/// A timer entry. +/// +/// This is the handle to a timer that is controlled by the requester of the +/// timer. As this participates in intrusive data structures, it must be pinned +/// before polling. +#[derive(Debug)] +pub(super) struct TimerEntry { + /// Arc reference to the driver. We can only free the driver after + /// deregistering everything from their respective timer wheels. + driver: Handle, + /// Shared inner structure; this is part of an intrusive linked list, and + /// therefore other references can exist to it while mutable references to + /// Entry exist. + /// + /// This is manipulated only under the inner mutex. TODO: Can we use loom + /// cells for this? + inner: StdUnsafeCell<TimerShared>, + /// Initial deadline for the timer. This is used to register on the first + /// poll, as we can't register prior to being pinned. + initial_deadline: Option<Instant>, + /// Ensure the type is !Unpin + _m: std::marker::PhantomPinned, +} + +unsafe impl Send for TimerEntry {} +unsafe impl Sync for TimerEntry {} + +/// An TimerHandle is the (non-enforced) "unique" pointer from the driver to the +/// timer entry. Generally, at most one TimerHandle exists for a timer at a time +/// (enforced by the timer state machine). +/// +/// SAFETY: An TimerHandle is essentially a raw pointer, and the usual caveats +/// of pointer safety apply. In particular, TimerHandle does not itself enforce +/// that the timer does still exist; however, normally an TimerHandle is created +/// immediately before registering the timer, and is consumed when firing the +/// timer, to help minimize mistakes. Still, because TimerHandle cannot enforce +/// memory safety, all operations are unsafe. +#[derive(Debug)] +pub(crate) struct TimerHandle { + inner: NonNull<TimerShared>, +} - curr = actual; +pub(super) type EntryList = crate::util::linked_list::LinkedList<TimerShared, TimerShared>; + +/// The shared state structure of a timer. This structure is shared between the +/// frontend (`Entry`) and driver backend. +/// +/// Note that this structure is located inside the `TimerEntry` structure. +#[derive(Debug)] +pub(crate) struct TimerShared { + /// Current state. This records whether the timer entry is currently under + /// the ownership of the driver, and if not, its current state (not + /// complete, fired, error, etc). + state: StateCell, + + /// Data manipulated by the driver thread itself, only. + driver_state: CachePadded<TimerSharedPadded>, + + _p: PhantomPinned, +} + +impl TimerShared { + pub(super) fn new() -> Self { + Self { + state: StateCell::default(), + driver_state: CachePadded(TimerSharedPadded::new()), + _p: PhantomPinned, } + } - self.waker.wake(); + /// Gets the cached time-of-expiration value + pub(super) fn cached_when(&self) -> u64 { + // Cached-when is only accessed under the driver lock, so we can use relaxed + self.driver_state.0.cached_when.load(Ordering::Relaxed) } - pub(crate) fn cancel(entry: &Arc<Entry>) { - let state = entry.state.fetch_or(ELAPSED, SeqCst); + /// Gets the true time-of-expiration value, and copies it into the cached + /// time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn sync_when(&self) -> u64 { + let true_when = self.true_when(); - if is_elapsed(state) { - // Nothing more to do - return; - } + self.driver_state + .0 + .cached_when + .store(true_when, Ordering::Relaxed); - // If registered with a timer instance, try to upgrade the Arc. - let inner = match entry.upgrade_inner() { - Some(inner) => inner, - None => return, - }; + true_when + } - let _ = inner.queue(entry); + /// Sets the cached time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + unsafe fn set_cached_when(&self, when: u64) { + self.driver_state + .0 + .cached_when + .store(when, Ordering::Relaxed); } - pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { - let mut curr = self.state.load(SeqCst); + /// Returns the true time-of-expiration value, with relaxed memory ordering. + pub(super) fn true_when(&self) -> u64 { + self.state.when().expect("Timer already fired") + } - if is_elapsed(curr) { - return Poll::Ready(if curr == ERROR { - Err(Error::from_u8(self.error.load(SeqCst))) - } else { - Ok(()) - }); + /// Sets the true time-of-expiration value, even if it is less than the + /// current expiration or the timer is deregistered. + /// + /// SAFETY: Must only be called with the driver lock held and the entry not + /// in the timer wheel. + pub(super) unsafe fn set_expiration(&self, t: u64) { + self.state.set_expiration(t); + self.driver_state.0.cached_when.store(t, Ordering::Relaxed); + } + + /// Sets the true time-of-expiration only if it is after the current. + pub(super) fn extend_expiration(&self, t: u64) -> Result<(), ()> { + self.state.extend_expiration(t) + } + + /// Returns a TimerHandle for this timer. + pub(super) fn handle(&self) -> TimerHandle { + TimerHandle { + inner: NonNull::from(self), } + } - self.waker.register_by_ref(cx.waker()); + /// Returns true if the state of this timer indicates that the timer might + /// be registered with the driver. This check is performed with relaxed + /// ordering, but is conservative - if it returns false, the timer is + /// definitely _not_ registered. + pub(super) fn might_be_registered(&self) -> bool { + self.state.might_be_registered() + } +} - curr = self.state.load(SeqCst); +/// Additional shared state between the driver and the timer which is cache +/// padded. This contains the information that the driver thread accesses most +/// frequently to minimize contention. In particular, we move it away from the +/// waker, as the waker is updated on every poll. +struct TimerSharedPadded { + /// The expiration time for which this entry is currently registered. + /// Generally owned by the driver, but is accessed by the entry when not + /// registered. + cached_when: AtomicU64, + + /// The true expiration time. Set by the timer future, read by the driver. + true_when: AtomicU64, + + /// A link within the doubly-linked list of timers on a particular level and + /// slot. Valid only if state is equal to Registered. + /// + /// Only accessed under the entry lock. + pointers: StdUnsafeCell<linked_list::Pointers<TimerShared>>, +} - if is_elapsed(curr) { - return Poll::Ready(if curr == ERROR { - Err(Error::from_u8(self.error.load(SeqCst))) - } else { - Ok(()) - }); - } +impl std::fmt::Debug for TimerSharedPadded { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TimerSharedPadded") + .field("when", &self.true_when.load(Ordering::Relaxed)) + .field("cached_when", &self.cached_when.load(Ordering::Relaxed)) + .finish() + } +} - Poll::Pending +impl TimerSharedPadded { + fn new() -> Self { + Self { + cached_when: AtomicU64::new(0), + true_when: AtomicU64::new(0), + pointers: StdUnsafeCell::new(linked_list::Pointers::new()), + } } +} - /// Only called by `Registration` - pub(crate) fn reset(entry: &mut Arc<Entry>) { - let inner = match entry.upgrade_inner() { - Some(inner) => inner, - None => return, - }; +unsafe impl Send for TimerShared {} +unsafe impl Sync for TimerShared {} - let deadline = entry.time_ref().deadline; - let when = inner.normalize_deadline(deadline); - let elapsed = inner.elapsed(); +unsafe impl linked_list::Link for TimerShared { + type Handle = TimerHandle; - let next = if when <= elapsed { ELAPSED } else { when }; + type Target = TimerShared; - let mut curr = entry.state.load(SeqCst); + fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target> { + handle.inner + } - loop { - // In these two cases, there is no work to do when resetting the - // timer. If the `Entry` is in an error state, then it cannot be - // used anymore. If resetting the entry to the current value, then - // the reset is a noop. - if curr == ERROR || curr == when { - return; - } + unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle { + TimerHandle { inner: ptr } + } - let actual = entry.state.compare_and_swap(curr, next, SeqCst); + unsafe fn pointers( + target: NonNull<Self::Target>, + ) -> NonNull<linked_list::Pointers<Self::Target>> { + unsafe { NonNull::new(target.as_ref().driver_state.0.pointers.get()).unwrap() } + } +} - if curr == actual { - break; - } +// ===== impl Entry ===== + +impl TimerEntry { + pub(crate) fn new(handle: &Handle, deadline: Instant) -> Self { + let driver = handle.clone(); - curr = actual; + Self { + driver, + inner: StdUnsafeCell::new(TimerShared::new()), + initial_deadline: Some(deadline), + _m: std::marker::PhantomPinned, } + } + + fn inner(&self) -> &TimerShared { + unsafe { &*self.inner.get() } + } + + pub(crate) fn is_elapsed(&self) -> bool { + !self.inner().state.might_be_registered() && self.initial_deadline.is_none() + } + + /// Cancels and deregisters the timer. This operation is irreversible. + pub(crate) fn cancel(self: Pin<&mut Self>) { + // We need to perform an acq/rel fence with the driver thread, and the + // simplest way to do so is to grab the driver lock. + // + // Why is this necessary? We're about to release this timer's memory for + // some other non-timer use. However, we've been doing a bunch of + // relaxed (or even non-atomic) writes from the driver thread, and we'll + // be doing more from _this thread_ (as this memory is interpreted as + // something else). + // + // It is critical to ensure that, from the point of view of the driver, + // those future non-timer writes happen-after the timer is fully fired, + // and from the purpose of this thread, the driver's writes all + // happen-before we drop the timer. This in turn requires us to perform + // an acquire-release barrier in _both_ directions between the driver + // and dropping thread. + // + // The lock acquisition in clear_entry serves this purpose. All of the + // driver manipulations happen with the lock held, so we can just take + // the lock and be sure that this drop happens-after everything the + // driver did so far and happens-before everything the driver does in + // the future. While we have the lock held, we also go ahead and + // deregister the entry if necessary. + unsafe { self.driver.clear_entry(NonNull::from(self.inner())) }; + } + + pub(crate) fn reset(mut self: Pin<&mut Self>, new_time: Instant) { + unsafe { self.as_mut().get_unchecked_mut() }.initial_deadline = None; - // If the state has transitioned to 'elapsed' then wake the task as - // this entry is ready to be polled. - if !is_elapsed(curr) && is_elapsed(next) { - entry.waker.wake(); + let tick = self.driver.time_source().deadline_to_tick(new_time); + + if self.inner().extend_expiration(tick).is_ok() { + return; } - // The driver tracks all non-elapsed entries; notify the driver that it - // should update its state for this entry unless the entry had already - // elapsed and remains elapsed. - if !is_elapsed(curr) || !is_elapsed(next) { - let _ = inner.queue(entry); + unsafe { + self.driver.reregister(tick, self.inner().into()); } } - fn new2(deadline: Instant, duration: Duration, inner: Weak<Inner>, state: u64) -> Self { - Self { - time: CachePadded(UnsafeCell::new(Time { deadline, duration })), - inner, - waker: AtomicWaker::new(), - state: AtomicU64::new(state), - queued: AtomicBool::new(false), - error: AtomicU8::new(0), - next_atomic: UnsafeCell::new(ptr::null_mut()), - when: UnsafeCell::new(None), - next_stack: UnsafeCell::new(None), - prev_stack: UnsafeCell::new(ptr::null_mut()), + pub(crate) fn poll_elapsed( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), super::Error>> { + if let Some(deadline) = self.initial_deadline { + self.as_mut().reset(deadline); } - } - fn upgrade_inner(&self) -> Option<Arc<Inner>> { - self.inner.upgrade() + let this = unsafe { self.get_unchecked_mut() }; + + this.inner().state.poll(cx.waker()) } } -fn is_elapsed(state: u64) -> bool { - state & ELAPSED == ELAPSED -} +impl TimerHandle { + pub(super) unsafe fn cached_when(&self) -> u64 { + unsafe { self.inner.as_ref().cached_when() } + } -impl Drop for Entry { - fn drop(&mut self) { - let inner = match self.upgrade_inner() { - Some(inner) => inner, - None => return, - }; + pub(super) unsafe fn sync_when(&self) -> u64 { + unsafe { self.inner.as_ref().sync_when() } + } - inner.decrement(); + pub(super) unsafe fn is_pending(&self) -> bool { + unsafe { self.inner.as_ref().state.is_pending() } + } + + /// Forcibly sets the true and cached expiration times to the given tick. + /// + /// SAFETY: The caller must ensure that the handle remains valid, the driver + /// lock is held, and that the timer is not in any wheel linked lists. + pub(super) unsafe fn set_expiration(&self, tick: u64) { + self.inner.as_ref().set_expiration(tick); + } + + /// Attempts to mark this entry as pending. If the expiration time is after + /// `not_after`, however, returns an Err with the current expiration time. + /// + /// If an `Err` is returned, the `cached_when` value will be updated to this + /// new expiration time. + /// + /// SAFETY: The caller must ensure that the handle remains valid, the driver + /// lock is held, and that the timer is not in any wheel linked lists. + /// After returning Ok, the entry must be added to the pending list. + pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + match self.inner.as_ref().state.mark_pending(not_after) { + Ok(()) => { + // mark this as being on the pending queue in cached_when + self.inner.as_ref().set_cached_when(u64::max_value()); + Ok(()) + } + Err(tick) => { + self.inner.as_ref().set_cached_when(tick); + Err(tick) + } + } + } + + /// Attempts to transition to a terminal state. If the state is already a + /// terminal state, does nothing. + /// + /// Because the entry might be dropped after the state is moved to a + /// terminal state, this function consumes the handle to ensure we don't + /// access the entry afterwards. + /// + /// Returns the last-registered waker, if any. + /// + /// SAFETY: The driver lock must be held while invoking this function, and + /// the entry must not be in any wheel linked lists. + pub(super) unsafe fn fire(self, completed_state: TimerResult) -> Option<Waker> { + self.inner.as_ref().state.fire(completed_state) } } -unsafe impl Send for Entry {} -unsafe impl Sync for Entry {} +impl Drop for TimerEntry { + fn drop(&mut self) { + unsafe { Pin::new_unchecked(self) }.as_mut().cancel() + } +} #[cfg_attr(target_arch = "x86_64", repr(align(128)))] #[cfg_attr(not(target_arch = "x86_64"), repr(align(64)))] -#[derive(Debug)] +#[derive(Debug, Default)] struct CachePadded<T>(T); diff --git a/src/time/driver/handle.rs b/src/time/driver/handle.rs index 54b8a8b..bfc49fb 100644 --- a/src/time/driver/handle.rs +++ b/src/time/driver/handle.rs @@ -1,22 +1,29 @@ -use crate::time::driver::Inner; +use crate::loom::sync::{Arc, Mutex}; +use crate::time::driver::ClockTime; use std::fmt; -use std::sync::{Arc, Weak}; /// Handle to time driver instance. #[derive(Clone)] pub(crate) struct Handle { - inner: Weak<Inner>, + time_source: ClockTime, + inner: Arc<Mutex<super::Inner>>, } impl Handle { /// Creates a new timer `Handle` from a shared `Inner` timer state. - pub(crate) fn new(inner: Weak<Inner>) -> Self { - Handle { inner } + pub(super) fn new(inner: Arc<Mutex<super::Inner>>) -> Self { + let time_source = inner.lock().time_source.clone(); + Handle { time_source, inner } } - /// Tries to return a strong ref to the inner - pub(crate) fn inner(&self) -> Option<Arc<Inner>> { - self.inner.upgrade() + /// Returns the time source associated with this handle + pub(super) fn time_source(&self) -> &ClockTime { + &self.time_source + } + + /// Locks the driver's inner structure + pub(super) fn lock(&self) -> crate::loom::sync::MutexGuard<'_, super::Inner> { + self.inner.lock() } } @@ -31,12 +38,12 @@ cfg_rt! { /// It can be triggered when `Builder::enable_time()` or /// `Builder::enable_all()` are not included in the builder. /// - /// It can also panic whenever a timer is created outside of a Tokio - /// runtime. That is why `rt.block_on(delay_for(...))` will panic, + /// It can also panic whenever a timer is created outside of a + /// Tokio runtime. That is why `rt.block_on(delay_for(...))` will panic, /// since the function is executed outside of the runtime. - /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't - /// panic. And this is because wrapping the function on an async makes it - /// lazy, and so gets executed inside the runtime successfuly without + /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't panic. + /// And this is because wrapping the function on an async makes it lazy, + /// and so gets executed inside the runtime successfuly without /// panicking. pub(crate) fn current() -> Self { crate::runtime::context::time_handle() @@ -61,7 +68,7 @@ cfg_not_rt! { /// since the function is executed outside of the runtime. /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't /// panic. And this is because wrapping the function on an async makes it - /// lazy, and so gets executed inside the runtime successfuly without + /// lazy, and so outside executed inside the runtime successfuly without /// panicking. pub(crate) fn current() -> Self { panic!("there is no timer running, must be called from the context of Tokio runtime or \ diff --git a/src/time/driver/mod.rs b/src/time/driver/mod.rs index 8532c55..9fbc0b3 100644 --- a/src/time/driver/mod.rs +++ b/src/time/driver/mod.rs @@ -1,26 +1,29 @@ +// Currently, rust warns when an unsafe fn contains an unsafe {} block. However, +// in the future, this will change to the reverse. For now, suppress this +// warning and generally stick with being explicit about unsafety. +#![allow(unused_unsafe)] #![cfg_attr(not(feature = "rt"), allow(dead_code))] //! Time driver -mod atomic_stack; -use self::atomic_stack::AtomicStack; - mod entry; -pub(super) use self::entry::Entry; +pub(self) use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared}; mod handle; pub(crate) use self::handle::Handle; -use crate::loom::sync::atomic::{AtomicU64, AtomicUsize}; +mod wheel; + +pub(super) mod sleep; + +use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; -use crate::time::{error::Error, wheel}; +use crate::time::error::Error; use crate::time::{Clock, Duration, Instant}; -use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst}; - -use std::sync::Arc; -use std::usize; -use std::{cmp, fmt}; +use std::convert::TryInto; +use std::fmt; +use std::{num::NonZeroU64, ptr::NonNull, task::Waker}; /// Time implementation that drives [`Sleep`][sleep], [`Interval`][interval], and [`Timeout`][timeout]. /// @@ -78,63 +81,96 @@ use std::{cmp, fmt}; /// [timeout]: crate::time::Timeout /// [interval]: crate::time::Interval #[derive(Debug)] -pub(crate) struct Driver<T: Park> { +pub(crate) struct Driver<P: Park + 'static> { + /// Timing backend in use + time_source: ClockTime, + /// Shared state - inner: Arc<Inner>, + inner: Handle, - /// Timer wheel - wheel: wheel::Wheel, + /// Parker to delegate to + park: P, +} + +/// A structure which handles conversion from Instants to u64 timestamps. +#[derive(Debug, Clone)] +pub(self) struct ClockTime { + clock: super::clock::Clock, + start_time: Instant, +} - /// Thread parker. The `Driver` park implementation delegates to this. - park: T, +impl ClockTime { + pub(self) fn new(clock: Clock) -> Self { + Self { + clock, + start_time: super::clock::now(), + } + } - /// Source of "now" instances - clock: Clock, + pub(self) fn deadline_to_tick(&self, t: Instant) -> u64 { + // Round up to the end of a ms + self.instant_to_tick(t + Duration::from_nanos(999_999)) + } - /// True if the driver is being shutdown - is_shutdown: bool, + pub(self) fn instant_to_tick(&self, t: Instant) -> u64 { + // round up + let dur: Duration = t + .checked_duration_since(self.start_time) + .unwrap_or_else(|| Duration::from_secs(0)); + let ms = dur.as_millis(); + + ms.try_into().expect("Duration too far into the future") + } + + pub(self) fn tick_to_duration(&self, t: u64) -> Duration { + Duration::from_millis(t) + } + + pub(self) fn now(&self) -> u64 { + self.instant_to_tick(self.clock.now()) + } } /// Timer state shared between `Driver`, `Handle`, and `Registration`. -pub(crate) struct Inner { - /// The instant at which the timer started running. - start: Instant, +pub(self) struct Inner { + /// Timing backend in use + time_source: ClockTime, /// The last published timer `elapsed` value. - elapsed: AtomicU64, + elapsed: u64, - /// Number of active timeouts - num: AtomicUsize, + /// The earliest time at which we promise to wake up without unparking + next_wake: Option<NonZeroU64>, - /// Head of the "process" linked list. - process: AtomicStack, + /// Timer wheel + wheel: wheel::Wheel, + + /// True if the driver is being shutdown + is_shutdown: bool, - /// Unparks the timer thread. + /// Unparker that can be used to wake the time driver unpark: Box<dyn Unpark>, } -/// Maximum number of timeouts the system can handle concurrently. -const MAX_TIMEOUTS: usize = usize::MAX >> 1; - // ===== impl Driver ===== -impl<T> Driver<T> +impl<P> Driver<P> where - T: Park, + P: Park + 'static, { /// Creates a new `Driver` instance that uses `park` to block the current - /// thread and `clock` to get the current `Instant`. + /// thread and `time_source` to get the current time and convert to ticks. /// /// Specifying the source of time is useful when testing. - pub(crate) fn new(park: T, clock: Clock) -> Driver<T> { - let unpark = Box::new(park.unpark()); + pub(crate) fn new(park: P, clock: Clock) -> Driver<P> { + let time_source = ClockTime::new(clock); + + let inner = Inner::new(time_source.clone(), Box::new(park.unpark())); Driver { - inner: Arc::new(Inner::new(clock.now(), unpark)), - wheel: wheel::Wheel::new(), + time_source, + inner: Handle::new(Arc::new(Mutex::new(inner))), park, - clock, - is_shutdown: false, } } @@ -145,189 +181,242 @@ where /// `with_default`, setting the timer as the default timer for the execution /// context. pub(crate) fn handle(&self) -> Handle { - Handle::new(Arc::downgrade(&self.inner)) + self.inner.clone() } - /// Converts an `Expiration` to an `Instant`. - fn expiration_instant(&self, when: u64) -> Instant { - self.inner.start + Duration::from_millis(when) - } + fn park_internal(&mut self, limit: Option<Duration>) -> Result<(), P::Error> { + let clock = &self.time_source.clock; - /// Runs timer related logic - fn process(&mut self) { - let now = crate::time::ms( - self.clock.now() - self.inner.start, - crate::time::Round::Down, - ); + let mut lock = self.inner.lock(); - while let Some(entry) = self.wheel.poll(now) { - let when = entry.when_internal().expect("invalid internal entry state"); + assert!(!lock.is_shutdown); - // Fire the entry - entry.fire(when); + let next_wake = lock.wheel.next_expiration_time(); + lock.next_wake = + next_wake.map(|t| NonZeroU64::new(t).unwrap_or_else(|| NonZeroU64::new(1).unwrap())); - // Track that the entry has been fired - entry.set_when_internal(None); - } + drop(lock); - // Update the elapsed cache - self.inner.elapsed.store(self.wheel.elapsed(), SeqCst); - } + match next_wake { + Some(when) => { + let now = self.time_source.now(); + // Note that we effectively round up to 1ms here - this avoids + // very short-duration microsecond-resolution sleeps that the OS + // might treat as zero-length. + let mut duration = self.time_source.tick_to_duration(when.saturating_sub(now)); + + if duration > Duration::from_millis(0) { + if let Some(limit) = limit { + duration = std::cmp::min(limit, duration); + } - /// Processes the entry queue - /// - /// This handles adding and canceling timeouts. - fn process_queue(&mut self) { - for entry in self.inner.process.take() { - match (entry.when_internal(), entry.load_state()) { - (None, None) => { - // Nothing to do - } - (Some(_), None) => { - // Remove the entry - self.clear_entry(&entry); - } - (None, Some(when)) => { - // Add the entry to the timer wheel - self.add_entry(entry, when); + if clock.is_paused() { + self.park.park_timeout(Duration::from_secs(0))?; + + // Simulate advancing time + clock.advance(duration); + } else { + self.park.park_timeout(duration)?; + } + } else { + self.park.park_timeout(Duration::from_secs(0))?; } - (Some(_), Some(next)) => { - self.clear_entry(&entry); - self.add_entry(entry, next); + } + None => { + if let Some(duration) = limit { + if clock.is_paused() { + self.park.park_timeout(Duration::from_secs(0))?; + clock.advance(duration); + } else { + self.park.park_timeout(duration)?; + } + } else { + self.park.park()?; } } } - } - - fn clear_entry(&mut self, entry: &Arc<Entry>) { - self.wheel.remove(entry); - entry.set_when_internal(None); - } - - /// Fires the entry if it needs to, otherwise queue it to be processed later. - fn add_entry(&mut self, entry: Arc<Entry>, when: u64) { - use crate::time::error::InsertError; - entry.set_when_internal(Some(when)); + // Process pending timers after waking up + self.inner.process(); - match self.wheel.insert(when, entry) { - Ok(_) => {} - Err((entry, InsertError::Elapsed)) => { - // The entry's deadline has elapsed, so fire it and update the - // internal state accordingly. - entry.set_when_internal(None); - entry.fire(when); - } - Err((entry, InsertError::Invalid)) => { - // The entry's deadline is invalid, so error it and update the - // internal state accordingly. - entry.set_when_internal(None); - entry.error(Error::invalid()); - } - } + Ok(()) } } -impl<T> Park for Driver<T> -where - T: Park, -{ - type Unpark = T::Unpark; - type Error = T::Error; +impl Handle { + /// Runs timer related logic, and returns the next wakeup time + pub(self) fn process(&self) { + let now = self.time_source().now(); - fn unpark(&self) -> Self::Unpark { - self.park.unpark() + self.process_at_time(now) } - fn park(&mut self) -> Result<(), Self::Error> { - self.process_queue(); + pub(self) fn process_at_time(&self, now: u64) { + let mut waker_list: [Option<Waker>; 32] = Default::default(); + let mut waker_idx = 0; - match self.wheel.poll_at() { - Some(when) => { - let now = self.clock.now(); - let deadline = self.expiration_instant(when); + let mut lock = self.lock(); - if deadline > now { - let dur = deadline - now; + assert!(now >= lock.elapsed); - if self.clock.is_paused() { - self.park.park_timeout(Duration::from_secs(0))?; - self.clock.advance(dur); - } else { - self.park.park_timeout(dur)?; + while let Some(entry) = lock.wheel.poll(now) { + debug_assert!(unsafe { entry.is_pending() }); + + // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. + if let Some(waker) = unsafe { entry.fire(Ok(())) } { + waker_list[waker_idx] = Some(waker); + + waker_idx += 1; + + if waker_idx == waker_list.len() { + // Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped. + drop(lock); + + for waker in waker_list.iter_mut() { + waker.take().unwrap().wake(); } - } else { - self.park.park_timeout(Duration::from_secs(0))?; + + waker_idx = 0; + + lock = self.lock(); } } - None => { - self.park.park()?; - } } - self.process(); + // Update the elapsed cache + lock.elapsed = lock.wheel.elapsed(); + lock.next_wake = lock + .wheel + .poll_at() + .map(|t| NonZeroU64::new(t).unwrap_or_else(|| NonZeroU64::new(1).unwrap())); - Ok(()) + drop(lock); + + for waker in waker_list[0..waker_idx].iter_mut() { + waker.take().unwrap().wake(); + } } - fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { - self.process_queue(); + /// Removes a registered timer from the driver. + /// + /// The timer will be moved to the cancelled state. Wakers will _not_ be + /// invoked. If the timer is already completed, this function is a no-op. + /// + /// This function always acquires the driver lock, even if the entry does + /// not appear to be registered. + /// + /// SAFETY: The timer must not be registered with some other driver, and + /// `add_entry` must not be called concurrently. + pub(self) unsafe fn clear_entry(&self, entry: NonNull<TimerShared>) { + unsafe { + let mut lock = self.lock(); + + if entry.as_ref().might_be_registered() { + lock.wheel.remove(entry); + } - match self.wheel.poll_at() { - Some(when) => { - let now = self.clock.now(); - let deadline = self.expiration_instant(when); + entry.as_ref().handle().fire(Ok(())); + } + } - if deadline > now { - let duration = cmp::min(deadline - now, duration); + /// Removes and re-adds an entry to the driver. + /// + /// SAFETY: The timer must be either unregistered, or registered with this + /// driver. No other threads are allowed to concurrently manipulate the + /// timer at all (the current thread should hold an exclusive reference to + /// the `TimerEntry`) + pub(self) unsafe fn reregister(&self, new_tick: u64, entry: NonNull<TimerShared>) { + let waker = unsafe { + let mut lock = self.lock(); + + // We may have raced with a firing/deregistration, so check before + // deregistering. + if unsafe { entry.as_ref().might_be_registered() } { + lock.wheel.remove(entry); + } - if self.clock.is_paused() { - self.park.park_timeout(Duration::from_secs(0))?; - self.clock.advance(duration); - } else { - self.park.park_timeout(duration)?; + // Now that we have exclusive control of this entry, mint a handle to reinsert it. + let entry = entry.as_ref().handle(); + + if lock.is_shutdown { + unsafe { entry.fire(Err(crate::time::error::Error::shutdown())) } + } else { + entry.set_expiration(new_tick); + + // Note: We don't have to worry about racing with some other resetting + // thread, because add_entry and reregister require exclusive control of + // the timer entry. + match unsafe { lock.wheel.insert(entry) } { + Ok(when) => { + if lock + .next_wake + .map(|next_wake| when < next_wake.get()) + .unwrap_or(true) + { + lock.unpark.unpark(); + } + + None } - } else { - self.park.park_timeout(Duration::from_secs(0))?; + Err((entry, super::error::InsertError::Elapsed)) => unsafe { + entry.fire(Ok(())) + }, } } - None => { - self.park.park_timeout(duration)?; - } + + // Must release lock before invoking waker to avoid the risk of deadlock. + }; + + // The timer was fired synchronously as a result of the reregistration. + // Wake the waker; this is needed because we might reset _after_ a poll, + // and otherwise the task won't be awoken to poll again. + if let Some(waker) = waker { + waker.wake(); } + } +} - self.process(); +impl<P> Park for Driver<P> +where + P: Park + 'static, +{ + type Unpark = P::Unpark; + type Error = P::Error; - Ok(()) + fn unpark(&self) -> Self::Unpark { + self.park.unpark() + } + + fn park(&mut self) -> Result<(), Self::Error> { + self.park_internal(None) + } + + fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { + self.park_internal(Some(duration)) } fn shutdown(&mut self) { - if self.is_shutdown { + let mut lock = self.inner.lock(); + + if lock.is_shutdown { return; } - use std::u64; + lock.is_shutdown = true; - // Shutdown the stack of entries to process, preventing any new entries - // from being pushed. - self.inner.process.shutdown(); + drop(lock); - // Clear the wheel, using u64::MAX allows us to drain everything - let end_of_time = u64::MAX; + // Advance time forward to the end of time. - while let Some(entry) = self.wheel.poll(end_of_time) { - entry.error(Error::shutdown()); - } + self.inner.process_at_time(u64::MAX); self.park.shutdown(); - - self.is_shutdown = true; } } -impl<T> Drop for Driver<T> +impl<P> Drop for Driver<P> where - T: Park, + P: Park + 'static, { fn drop(&mut self) { self.shutdown(); @@ -337,69 +426,16 @@ where // ===== impl Inner ===== impl Inner { - fn new(start: Instant, unpark: Box<dyn Unpark>) -> Inner { + pub(self) fn new(time_source: ClockTime, unpark: Box<dyn Unpark>) -> Self { Inner { - num: AtomicUsize::new(0), - elapsed: AtomicU64::new(0), - process: AtomicStack::new(), - start, + time_source, + elapsed: 0, + next_wake: None, unpark, + wheel: wheel::Wheel::new(), + is_shutdown: false, } } - - fn elapsed(&self) -> u64 { - self.elapsed.load(SeqCst) - } - - #[cfg(all(test, loom))] - fn num(&self, ordering: std::sync::atomic::Ordering) -> usize { - self.num.load(ordering) - } - - /// Increments the number of active timeouts - fn increment(&self) -> Result<(), Error> { - let mut curr = self.num.load(Relaxed); - loop { - if curr == MAX_TIMEOUTS { - return Err(Error::at_capacity()); - } - - match self - .num - .compare_exchange_weak(curr, curr + 1, Release, Relaxed) - { - Ok(_) => return Ok(()), - Err(next) => curr = next, - } - } - } - - /// Decrements the number of active timeouts - fn decrement(&self) { - let prev = self.num.fetch_sub(1, Acquire); - debug_assert!(prev <= MAX_TIMEOUTS); - } - - /// add the entry to the "process queue". entries are not immediately - /// pushed into the timer wheel but are instead pushed into the - /// process queue and then moved from the process queue into the timer - /// wheel on next `process` - fn queue(&self, entry: &Arc<Entry>) -> Result<(), Error> { - if self.process.push(entry)? { - // The timer is notified so that it can process the timeout - self.unpark.unpark(); - } - - Ok(()) - } - - fn normalize_deadline(&self, deadline: Instant) -> u64 { - if deadline < self.start { - return 0; - } - - crate::time::ms(deadline - self.start, crate::time::Round::Up) - } } impl fmt::Debug for Inner { @@ -408,5 +444,5 @@ impl fmt::Debug for Inner { } } -#[cfg(all(test, loom))] +#[cfg(test)] mod tests; diff --git a/src/time/sleep.rs b/src/time/driver/sleep.rs index 2bd4eb1..69a6e6d 100644 --- a/src/time/sleep.rs +++ b/src/time/driver/sleep.rs @@ -1,9 +1,9 @@ -use crate::time::driver::{Entry, Handle}; +use crate::time::driver::{Handle, TimerEntry}; use crate::time::{error::Error, Duration, Instant}; +use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; use std::task::{self, Poll}; /// Waits until `deadline` is reached. @@ -17,7 +17,7 @@ use std::task::{self, Poll}; /// Canceling a sleep instance is done by dropping the returned future. No additional /// cleanup work is required. pub fn sleep_until(deadline: Instant) -> Sleep { - Sleep::new_timeout(deadline, Duration::from_millis(0)) + Sleep::new_timeout(deadline) } /// Waits until `duration` has elapsed. @@ -57,28 +57,31 @@ pub fn sleep(duration: Duration) -> Sleep { sleep_until(Instant::now() + duration) } -/// Future returned by [`sleep`](sleep) and -/// [`sleep_until`](sleep_until). -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Sleep { - /// The link between the `Sleep` instance and the timer that drives it. - /// - /// This also stores the `deadline` value. - entry: Arc<Entry>, +pin_project! { + /// Future returned by [`sleep`](sleep) and + /// [`sleep_until`](sleep_until). + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Sleep { + deadline: Instant, + + // The link between the `Sleep` instance and the timer that drives it. + #[pin] + entry: TimerEntry, + } } impl Sleep { - pub(crate) fn new_timeout(deadline: Instant, duration: Duration) -> Sleep { + pub(crate) fn new_timeout(deadline: Instant) -> Sleep { let handle = Handle::current(); - let entry = Entry::new(&handle, deadline, duration); + let entry = TimerEntry::new(&handle, deadline); - Sleep { entry } + Sleep { deadline, entry } } /// Returns the instant at which the future will complete. pub fn deadline(&self) -> Instant { - self.entry.time_ref().deadline + self.deadline } /// Returns `true` if `Sleep` has elapsed. @@ -95,19 +98,19 @@ impl Sleep { /// /// This function can be called both before and after the future has /// completed. - pub fn reset(&mut self, deadline: Instant) { - unsafe { - self.entry.time_mut().deadline = deadline; - } - - Entry::reset(&mut self.entry); + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + let me = self.project(); + me.entry.reset(deadline); + *me.deadline = deadline; } - fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { + fn poll_elapsed(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { + let me = self.project(); + // Keep track of task budget let coop = ready!(crate::coop::poll_proceed(cx)); - self.entry.poll_elapsed(cx).map(move |r| { + me.entry.poll_elapsed(cx).map(move |r| { coop.made_progress(); r }) @@ -117,7 +120,7 @@ impl Sleep { impl Future for Sleep { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { // `poll_elapsed` can return an error in two cases: // // - AtCapacity: this is a pathological case where far too many @@ -127,15 +130,9 @@ impl Future for Sleep { // Both cases are extremely rare, and pretty accurately fit into // "logic errors", so we just panic in this case. A user couldn't // really do much better if we passed the error onwards. - match ready!(self.poll_elapsed(cx)) { + match ready!(self.as_mut().poll_elapsed(cx)) { Ok(()) => Poll::Ready(()), Err(e) => panic!("timer error: {}", e), } } } - -impl Drop for Sleep { - fn drop(&mut self) { - Entry::cancel(&self.entry); - } -} diff --git a/src/time/driver/tests/mod.rs b/src/time/driver/tests/mod.rs index 88ff552..cfefed3 100644 --- a/src/time/driver/tests/mod.rs +++ b/src/time/driver/tests/mod.rs @@ -1,18 +1,249 @@ -use crate::park::Unpark; -use crate::time::driver::Inner; -use crate::time::Instant; +use std::{task::Context, time::Duration}; -use loom::thread; +#[cfg(not(loom))] +use futures::task::noop_waker_ref; -use std::sync::atomic::Ordering; -use std::sync::Arc; +use crate::loom::sync::{Arc, Mutex}; +use crate::loom::thread; +use crate::{ + loom::sync::atomic::{AtomicBool, Ordering}, + park::Unpark, +}; -struct MockUnpark; +use super::{Handle, TimerEntry}; +struct MockUnpark {} impl Unpark for MockUnpark { fn unpark(&self) {} } +impl MockUnpark { + fn mock() -> Box<dyn Unpark> { + Box::new(Self {}) + } +} + +fn block_on<T>(f: impl std::future::Future<Output = T>) -> T { + #[cfg(loom)] + return loom::future::block_on(f); + + #[cfg(not(loom))] + return futures::executor::block_on(f); +} + +fn model(f: impl Fn() + Send + Sync + 'static) { + #[cfg(loom)] + loom::model(f); + + #[cfg(not(loom))] + f(); +} + +#[test] +fn single_timer() { + model(|| { + let clock = crate::time::clock::Clock::new(true); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(Mutex::new(inner))); + + let handle_ = handle.clone(); + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, clock.now() + Duration::from_secs(1)); + pin!(entry); + + block_on(futures::future::poll_fn(|cx| { + entry.as_mut().poll_elapsed(cx) + })) + .unwrap(); + }); + + thread::yield_now(); + + // This may or may not return Some (depending on how it races with the + // thread). If it does return None, however, the timer should complete + // synchronously. + handle.process_at_time(time_source.now() + 2_000_000_000); + + jh.join().unwrap(); + }) +} + +#[test] +fn drop_timer() { + model(|| { + let clock = crate::time::clock::Clock::new(true); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(Mutex::new(inner))); + + let handle_ = handle.clone(); + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, clock.now() + Duration::from_secs(1)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + }); + + thread::yield_now(); + + // advance 2s in the future. + handle.process_at_time(time_source.now() + 2_000_000_000); + + jh.join().unwrap(); + }) +} + +#[test] +fn change_waker() { + model(|| { + let clock = crate::time::clock::Clock::new(true); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(Mutex::new(inner))); + + let handle_ = handle.clone(); + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, clock.now() + Duration::from_secs(1)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + block_on(futures::future::poll_fn(|cx| { + entry.as_mut().poll_elapsed(cx) + })) + .unwrap(); + }); + + thread::yield_now(); + + // advance 2s + handle.process_at_time(time_source.now() + 2_000_000_000); + + jh.join().unwrap(); + }) +} + +#[test] +fn reset_future() { + model(|| { + let finished_early = Arc::new(AtomicBool::new(false)); + + let clock = crate::time::clock::Clock::new(true); + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source.clone(), MockUnpark::mock()); + let handle = Handle::new(Arc::new(Mutex::new(inner))); + + let handle_ = handle.clone(); + let finished_early_ = finished_early.clone(); + let start = clock.now(); + + let jh = thread::spawn(move || { + let entry = TimerEntry::new(&handle_, start + Duration::from_secs(1)); + pin!(entry); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(futures::task::noop_waker_ref())); + + entry.as_mut().reset(start + Duration::from_secs(2)); + + // shouldn't complete before 2s + block_on(futures::future::poll_fn(|cx| { + entry.as_mut().poll_elapsed(cx) + })) + .unwrap(); + + finished_early_.store(true, Ordering::Relaxed); + }); + + thread::yield_now(); + + // This may or may not return a wakeup time. + handle.process_at_time(time_source.instant_to_tick(start + Duration::from_millis(1500))); + + assert!(!finished_early.load(Ordering::Relaxed)); + + handle.process_at_time(time_source.instant_to_tick(start + Duration::from_millis(2500))); + + jh.join().unwrap(); + + assert!(finished_early.load(Ordering::Relaxed)); + }) +} + +#[test] +#[cfg(not(loom))] +fn poll_process_levels() { + let clock = crate::time::clock::Clock::new(true); + clock.pause(); + + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source, MockUnpark::mock()); + let handle = Handle::new(Arc::new(Mutex::new(inner))); + + let mut entries = vec![]; + + for i in 0..1024 { + let mut entry = Box::pin(TimerEntry::new( + &handle, + clock.now() + Duration::from_millis(i), + )); + + let _ = entry + .as_mut() + .poll_elapsed(&mut Context::from_waker(noop_waker_ref())); + + entries.push(entry); + } + + for t in 1..1024 { + handle.process_at_time(t as u64); + for (deadline, future) in entries.iter_mut().enumerate() { + let mut context = Context::from_waker(noop_waker_ref()); + if deadline <= t { + assert!(future.as_mut().poll_elapsed(&mut context).is_ready()); + } else { + assert!(future.as_mut().poll_elapsed(&mut context).is_pending()); + } + } + } +} + +#[test] +#[cfg(not(loom))] +fn poll_process_levels_targeted() { + let mut context = Context::from_waker(noop_waker_ref()); + + let clock = crate::time::clock::Clock::new(true); + clock.pause(); + + let time_source = super::ClockTime::new(clock.clone()); + + let inner = super::Inner::new(time_source, MockUnpark::mock()); + let handle = Handle::new(Arc::new(Mutex::new(inner))); + + let e1 = TimerEntry::new(&handle, clock.now() + Duration::from_millis(193)); + pin!(e1); + + handle.process_at_time(62); + assert!(e1.as_mut().poll_elapsed(&mut context).is_pending()); + handle.process_at_time(192); + handle.process_at_time(192); +} +/* #[test] fn balanced_incr_and_decr() { const OPS: usize = 5; @@ -53,3 +284,4 @@ fn balanced_incr_and_decr() { assert_eq!(inner.num(Ordering::SeqCst), 0); }) } +*/ diff --git a/src/time/wheel/level.rs b/src/time/driver/wheel/level.rs index d51d26a..58280b1 100644 --- a/src/time/wheel/level.rs +++ b/src/time/driver/wheel/level.rs @@ -1,7 +1,8 @@ -use super::{Item, OwnedItem}; -use crate::time::wheel::Stack; +use crate::time::driver::TimerHandle; -use std::fmt; +use crate::time::driver::{EntryList, TimerShared}; + +use std::{fmt, ptr::NonNull}; /// Wheel for a single level in the timer. This wheel contains 64 slots. pub(crate) struct Level { @@ -16,8 +17,8 @@ pub(crate) struct Level { /// The least-significant bit represents slot zero. occupied: u64, - /// Slots - slot: [Stack; LEVEL_MULT], + /// Slots. We access these via the EntryInner `current_list` as well, so this needs to be an UnsafeCell. + slot: [EntryList; LEVEL_MULT], } /// Indicates when a slot must be processed next. @@ -52,7 +53,7 @@ impl Level { // However, that is only supported for arrays of size // 32 or fewer. So in our case we have to explicitly // invoke the constructor for each array element. - let ctor = Stack::default; + let ctor = || EntryList::default(); Level { level, @@ -144,14 +145,38 @@ impl Level { // TODO: This can probably be simplified w/ power of 2 math let level_start = now - (now % level_range); - let deadline = level_start + slot as u64 * slot_range; + let mut deadline = level_start + slot as u64 * slot_range; + + if deadline <= now { + // A timer is in a slot "prior" to the current time. This can occur + // because we do not have an infinite hierarchy of timer levels, and + // eventually a timer scheduled for a very distant time might end up + // being placed in a slot that is beyond the end of all of the + // arrays. + // + // To deal with this, we first limit timers to being scheduled no + // more than MAX_DURATION ticks in the future; that is, they're at + // most one rotation of the top level away. Then, we force timers + // that logically would go into the top+1 level, to instead go into + // the top level's slots. + // + // What this means is that the top level's slots act as a + // pseudo-ring buffer, and we rotate around them indefinitely. If we + // compute a deadline before now, and it's the top level, it + // therefore means we're actually looking at a slot in the future. + debug_assert_eq!(self.level, super::NUM_LEVELS - 1); + + deadline += level_range; + } debug_assert!( deadline >= now, - "deadline={}; now={}; level={}; slot={}; occupied={:b}", + "deadline={:016X}; now={:016X}; level={}; lr={:016X}, sr={:016X}, slot={}; occupied={:b}", deadline, now, self.level, + level_range, + slot_range, slot, self.occupied ); @@ -177,18 +202,18 @@ impl Level { Some(slot) } - pub(crate) fn add_entry(&mut self, when: u64, item: OwnedItem) { - let slot = slot_for(when, self.level); + pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { + let slot = slot_for(item.cached_when(), self.level); + + self.slot[slot].push_front(item); - self.slot[slot].push(item); self.occupied |= occupied_bit(slot); } - pub(crate) fn remove_entry(&mut self, when: u64, item: &Item) { - let slot = slot_for(when, self.level); - - self.slot[slot].remove(item); + pub(crate) unsafe fn remove_entry(&mut self, item: NonNull<TimerShared>) { + let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level); + unsafe { self.slot[slot].remove(item) }; if self.slot[slot].is_empty() { // The bit is currently set debug_assert!(self.occupied & occupied_bit(slot) != 0); @@ -198,17 +223,10 @@ impl Level { } } - pub(crate) fn pop_entry_slot(&mut self, slot: usize) -> Option<OwnedItem> { - let ret = self.slot[slot].pop(); - - if ret.is_some() && self.slot[slot].is_empty() { - // The bit is currently set - debug_assert!(self.occupied & occupied_bit(slot) != 0); - - self.occupied ^= occupied_bit(slot); - } + pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList { + self.occupied &= !occupied_bit(slot); - ret + std::mem::take(&mut self.slot[slot]) } } diff --git a/src/time/wheel/mod.rs b/src/time/driver/wheel/mod.rs index 85ed2f1..164cac4 100644 --- a/src/time/wheel/mod.rs +++ b/src/time/driver/wheel/mod.rs @@ -1,17 +1,13 @@ -use crate::time::{driver::Entry, error::InsertError}; +use crate::time::driver::{TimerHandle, TimerShared}; +use crate::time::error::InsertError; mod level; pub(crate) use self::level::Expiration; use self::level::Level; -mod stack; -pub(crate) use self::stack::Stack; +use std::ptr::NonNull; -use std::sync::Arc; -use std::usize; - -pub(super) type Item = Entry; -pub(super) type OwnedItem = Arc<Item>; +use super::EntryList; /// Timing wheel implementation. /// @@ -40,6 +36,9 @@ pub(crate) struct Wheel { /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range levels: Vec<Level>, + + /// Entries queued for firing + pending: EntryList, } /// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots @@ -48,14 +47,18 @@ pub(crate) struct Wheel { const NUM_LEVELS: usize = 6; /// The maximum duration of a `Sleep` -const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; +pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; impl Wheel { /// Create a new timing wheel pub(crate) fn new() -> Wheel { let levels = (0..NUM_LEVELS).map(Level::new).collect(); - Wheel { elapsed: 0, levels } + Wheel { + elapsed: 0, + levels, + pending: EntryList::new(), + } } /// Return the number of milliseconds that have elapsed since the timing @@ -68,14 +71,8 @@ impl Wheel { /// /// # Arguments /// - /// * `when`: is the instant at which the entry should be fired. It is - /// represented as the number of milliseconds since the creation - /// of the timing wheel. - /// /// * `item`: The item to insert into the wheel. /// - /// * `store`: The slab or `()` when using heap storage. - /// /// # Return /// /// Returns `Ok` when the item is successfully inserted, `Err` otherwise. @@ -85,21 +82,28 @@ impl Wheel { /// immediately. /// /// `Err(Invalid)` indicates an invalid `when` argument as been supplied. - pub(crate) fn insert( + /// + /// # Safety + /// + /// This function registers item into an intrusive linked list. The caller + /// must ensure that `item` is pinned and will not be dropped without first + /// being deregistered. + pub(crate) unsafe fn insert( &mut self, - when: u64, - item: OwnedItem, - ) -> Result<(), (OwnedItem, InsertError)> { + item: TimerHandle, + ) -> Result<u64, (TimerHandle, InsertError)> { + let when = item.sync_when(); + if when <= self.elapsed { return Err((item, InsertError::Elapsed)); - } else if when - self.elapsed > MAX_DURATION { - return Err((item, InsertError::Invalid)); } // Get the level at which the entry should be stored let level = self.level_for(when); - self.levels[level].add_entry(when, item); + unsafe { + self.levels[level].add_entry(item); + } debug_assert!({ self.levels[level] @@ -108,15 +112,21 @@ impl Wheel { .unwrap_or(true) }); - Ok(()) + Ok(when) } - /// Remove `item` from thee timing wheel. - pub(crate) fn remove(&mut self, item: &Item) { - let when = item.when(); - let level = self.level_for(when); + /// Remove `item` from the timing wheel. + pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) { + unsafe { + let when = item.as_ref().cached_when(); + if when == u64::max_value() { + self.pending.remove(item); + } else { + let level = self.level_for(when); - self.levels[level].remove_entry(when, item); + self.levels[level].remove_entry(item); + } + } } /// Instant at which to poll @@ -125,8 +135,12 @@ impl Wheel { } /// Advances the timer up to the instant represented by `now`. - pub(crate) fn poll(&mut self, now: u64) -> Option<OwnedItem> { + pub(crate) fn poll(&mut self, now: u64) -> Option<TimerHandle> { loop { + if let Some(handle) = self.pending.pop_back() { + return Some(handle); + } + // under what circumstances is poll.expiration Some vs. None? let expiration = self.next_expiration().and_then(|expiration| { if expiration.deadline > now { @@ -137,10 +151,9 @@ impl Wheel { }); match expiration { + Some(ref expiration) if expiration.deadline > now => return None, Some(ref expiration) => { - if let Some(item) = self.poll_expiration(expiration) { - return Some(item); - } + self.process_expiration(expiration); self.set_elapsed(expiration.deadline); } @@ -150,14 +163,25 @@ impl Wheel { // the current list of timers. advance to the poll's // current time and do nothing else. self.set_elapsed(now); - return None; + break; } } } + + self.pending.pop_back() } /// Returns the instant at which the next timeout expires. fn next_expiration(&self) -> Option<Expiration> { + if !self.pending.is_empty() { + // Expire immediately as we have things pending firing + return Some(Expiration { + level: 0, + slot: 0, + deadline: self.elapsed, + }); + } + // Check all levels for level in 0..NUM_LEVELS { if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) { @@ -172,6 +196,12 @@ impl Wheel { None } + /// Returns the tick at which this timer wheel next needs to perform some + /// processing, or None if there are no timers registered. + pub(super) fn next_expiration_time(&self) -> Option<u64> { + self.next_expiration().map(|ex| ex.deadline) + } + /// Used for debug assertions fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { let mut res = true; @@ -189,24 +219,41 @@ impl Wheel { /// iteratively find entries that are between the wheel's current /// time and the expiration time. for each in that population either - /// return it for notification (in the case of the last level) or tier + /// queue it for notification (in the case of the last level) or tier /// it down to the next level (in all other cases). - pub(crate) fn poll_expiration(&mut self, expiration: &Expiration) -> Option<OwnedItem> { - while let Some(item) = self.pop_entry(expiration) { + pub(crate) fn process_expiration(&mut self, expiration: &Expiration) { + // Note that we need to take _all_ of the entries off the list before + // processing any of them. This is important because it's possible that + // those entries might need to be reinserted into the same slot. + // + // This happens only on the highest level, when an entry is inserted + // more than MAX_DURATION into the future. When this happens, we wrap + // around, and process some entries a multiple of MAX_DURATION before + // they actually need to be dropped down a level. We then reinsert them + // back into the same position; we must make sure we don't then process + // those entries again or we'll end up in an infinite loop. + let mut entries = self.take_entries(expiration); + + while let Some(item) = entries.pop_back() { if expiration.level == 0 { - debug_assert_eq!(item.when(), expiration.deadline); - - return Some(item); - } else { - let when = item.when(); - - let next_level = expiration.level - 1; + debug_assert_eq!(unsafe { item.cached_when() }, expiration.deadline); + } - self.levels[next_level].add_entry(when, item); + // Try to expire the entry; this is cheap (doesn't synchronize) if + // the timer is not expired, and updates cached_when. + match unsafe { item.mark_pending(expiration.deadline) } { + Ok(()) => { + // Item was expired + self.pending.push_front(item); + } + Err(expiration_tick) => { + let level = level_for(expiration.deadline, expiration_tick); + unsafe { + self.levels[level].add_entry(item); + } + } } } - - None } fn set_elapsed(&mut self, when: u64) { @@ -222,8 +269,10 @@ impl Wheel { } } - fn pop_entry(&mut self, expiration: &Expiration) -> Option<OwnedItem> { - self.levels[expiration.level].pop_entry_slot(expiration.slot) + /// Obtains the list of entries that need processing for the given expiration. + /// + fn take_entries(&mut self, expiration: &Expiration) -> EntryList { + self.levels[expiration.level].take_slot(expiration.slot) } fn level_for(&self, when: u64) -> usize { @@ -232,12 +281,18 @@ impl Wheel { } fn level_for(elapsed: u64, when: u64) -> usize { - let masked = elapsed ^ when; + let mut masked = elapsed ^ when; + + if masked >= MAX_DURATION { + // Fudge the timer into the top level + masked = MAX_DURATION - 1; + } assert!(masked != 0, "elapsed={}; when={}", elapsed, when); let leading_zeros = masked.leading_zeros() as usize; let significant = 63 - leading_zeros; + significant / 6 } diff --git a/src/time/wheel/stack.rs b/src/time/driver/wheel/stack.rs index e7ed137..e7ed137 100644 --- a/src/time/wheel/stack.rs +++ b/src/time/driver/wheel/stack.rs diff --git a/src/time/error.rs b/src/time/error.rs index 24395c4..8674feb 100644 --- a/src/time/error.rs +++ b/src/time/error.rs @@ -23,17 +23,23 @@ use std::fmt; /// way to do this would be dropping the future that issued the timer operation. /// /// [shed load]: https://en.wikipedia.org/wiki/Load_Shedding -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] pub struct Error(Kind); -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] #[repr(u8)] -enum Kind { +pub(crate) enum Kind { Shutdown = 1, AtCapacity = 2, Invalid = 3, } +impl From<Kind> for Error { + fn from(k: Kind) -> Self { + Error(k) + } +} + /// Error returned by `Timeout`. #[derive(Debug, PartialEq)] pub struct Elapsed(()); @@ -41,7 +47,6 @@ pub struct Elapsed(()); #[derive(Debug)] pub(crate) enum InsertError { Elapsed, - Invalid, } // ===== impl Error ===== @@ -76,19 +81,6 @@ impl Error { pub fn is_invalid(&self) -> bool { matches!(self.0, Kind::Invalid) } - - pub(crate) fn as_u8(&self) -> u8 { - self.0 as u8 - } - - pub(crate) fn from_u8(n: u8) -> Self { - Error(match n { - 1 => Shutdown, - 2 => AtCapacity, - 3 => Invalid, - _ => panic!("u8 does not correspond to any time error variant"), - }) - } } impl error::Error for Error {} diff --git a/src/time/interval.rs b/src/time/interval.rs index c7c58e1..be93ba1 100644 --- a/src/time/interval.rs +++ b/src/time/interval.rs @@ -101,43 +101,22 @@ pub fn interval_at(start: Instant, period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); Interval { - delay: sleep_until(start), + delay: Box::pin(sleep_until(start)), period, } } /// Stream returned by [`interval`](interval) and [`interval_at`](interval_at). -/// -/// This type only implements the [`Stream`] trait if the "stream" feature is -/// enabled. -/// -/// [`Stream`]: trait@crate::stream::Stream #[derive(Debug)] pub struct Interval { /// Future that completes the next time the `Interval` yields a value. - delay: Sleep, + delay: Pin<Box<Sleep>>, /// The duration between values yielded by `Interval`. period: Duration, } impl Interval { - fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { - // Wait for the delay to be done - ready!(Pin::new(&mut self.delay).poll(cx)); - - // Get the `now` by looking at the `delay` deadline - let now = self.delay.deadline(); - - // The next interval value is `duration` after the one that just - // yielded. - let next = now + self.period; - self.delay.reset(next); - - // Return the current instant - Poll::Ready(now) - } - /// Completes when the next instant in the interval has been reached. /// /// # Examples @@ -161,13 +140,31 @@ impl Interval { pub async fn tick(&mut self) -> Instant { poll_fn(|cx| self.poll_tick(cx)).await } -} -#[cfg(feature = "stream")] -impl crate::stream::Stream for Interval { - type Item = Instant; + /// Poll for the next instant in the interval to be reached. + /// + /// This method can return the following values: + /// + /// * `Poll::Pending` if the next instant has not yet been reached. + /// * `Poll::Ready(instant)` if the next instant has been reached. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when the instant has elapsed. Note that on multiple + /// calls to `poll_tick`, only the `Waker` from the `Context` passed to the + /// most recent call is scheduled to receive a wakeup. + pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { + // Wait for the delay to be done + ready!(Pin::new(&mut self.delay).poll(cx)); - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Instant>> { - Poll::Ready(Some(ready!(self.poll_tick(cx)))) + // Get the `now` by looking at the `delay` deadline + let now = self.delay.deadline(); + + // The next interval value is `duration` after the one that just + // yielded. + let next = now + self.period; + self.delay.as_mut().reset(next); + + // Return the current instant + Poll::Ready(now) } } diff --git a/src/time/mod.rs b/src/time/mod.rs index 29af717..8aaf9c1 100644 --- a/src/time/mod.rs +++ b/src/time/mod.rs @@ -36,9 +36,7 @@ //! } //! ``` //! -//! Require that an operation takes no more than 300ms. Note that this uses the -//! `timeout` function on the `FutureExt` trait. This trait is included in the -//! prelude. +//! Require that an operation takes no more than 300ms. //! //! ``` //! use tokio::time::{timeout, Duration}; @@ -77,9 +75,11 @@ //! //! #[tokio::main] //! async fn main() { -//! let mut interval = time::interval(time::Duration::from_secs(2)); +//! let interval = time::interval(time::Duration::from_secs(2)); +//! tokio::pin!(interval); +//! //! for _i in 0..5 { -//! interval.tick().await; +//! interval.as_mut().tick().await; //! task_that_takes_a_second().await; //! } //! } @@ -93,11 +93,11 @@ pub(crate) use self::clock::Clock; #[cfg(feature = "test-util")] pub use clock::{advance, pause, resume}; -mod sleep; -pub use sleep::{sleep, sleep_until, Sleep}; - pub(crate) mod driver; +#[doc(inline)] +pub use driver::sleep::{sleep, sleep_until, Sleep}; + pub mod error; mod instant; @@ -110,8 +110,6 @@ mod timeout; #[doc(inline)] pub use timeout::{timeout, timeout_at, Timeout}; -mod wheel; - #[cfg(test)] #[cfg(not(loom))] mod tests; @@ -119,32 +117,3 @@ mod tests; // Re-export for convenience #[doc(no_inline)] pub use std::time::Duration; - -// ===== Internal utils ===== - -enum Round { - Up, - Down, -} - -/// Convert a `Duration` to milliseconds, rounding up and saturating at -/// `u64::MAX`. -/// -/// The saturating is fine because `u64::MAX` milliseconds are still many -/// million years. -#[inline] -fn ms(duration: Duration, round: Round) -> u64 { - const NANOS_PER_MILLI: u32 = 1_000_000; - const MILLIS_PER_SEC: u64 = 1_000; - - // Round up. - let millis = match round { - Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI, - Round::Down => duration.subsec_millis(), - }; - - duration - .as_secs() - .saturating_mul(MILLIS_PER_SEC) - .saturating_add(u64::from(millis)) -} diff --git a/src/time/tests/mod.rs b/src/time/tests/mod.rs index fae67da..35e1060 100644 --- a/src/time/tests/mod.rs +++ b/src/time/tests/mod.rs @@ -8,7 +8,7 @@ fn assert_sync<T: Sync>() {} #[test] fn registration_is_send_and_sync() { - use crate::time::sleep::Sleep; + use crate::time::Sleep; assert_send::<Sleep>(); assert_sync::<Sleep>(); diff --git a/src/time/tests/test_sleep.rs b/src/time/tests/test_sleep.rs index c8d931a..77ca07e 100644 --- a/src/time/tests/test_sleep.rs +++ b/src/time/tests/test_sleep.rs @@ -1,13 +1,6 @@ -use crate::park::{Park, Unpark}; -use crate::time::driver::{Driver, Entry, Handle}; -use crate::time::Clock; -use crate::time::{Duration, Instant}; - -use tokio_test::task; -use tokio_test::{assert_ok, assert_pending, assert_ready_ok}; - -use std::sync::Arc; +//use crate::time::driver::{Driver, Entry, Handle}; +/* macro_rules! poll { ($e:expr) => { $e.enter(|cx, e| e.poll_elapsed(cx)) @@ -447,3 +440,4 @@ impl Unpark for MockUnpark { fn ms(n: u64) -> Duration { Duration::from_millis(n) } +*/ diff --git a/src/time/timeout.rs b/src/time/timeout.rs index cf09b07..9d15a72 100644 --- a/src/time/timeout.rs +++ b/src/time/timeout.rs @@ -49,7 +49,7 @@ pub fn timeout<T>(duration: Duration, future: T) -> Timeout<T> where T: Future, { - let delay = Sleep::new_timeout(Instant::now() + duration, duration); + let delay = Sleep::new_timeout(Instant::now() + duration); Timeout::new_with_delay(future, delay) } diff --git a/src/util/mod.rs b/src/util/mod.rs index b2043dd..382bbb9 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -10,10 +10,11 @@ cfg_io_driver! { feature = "rt", feature = "sync", feature = "signal", + feature = "time", ))] pub(crate) mod linked_list; -#[cfg(any(feature = "rt-multi-thread", feature = "macros", feature = "stream"))] +#[cfg(any(feature = "rt-multi-thread", feature = "macros"))] mod rand; cfg_rt! { @@ -31,6 +32,6 @@ cfg_rt_multi_thread! { pub(crate) mod trace; -#[cfg(any(feature = "macros", feature = "stream"))] +#[cfg(any(feature = "macros"))] #[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] pub use rand::thread_rng_n; diff --git a/src/util/rand.rs b/src/util/rand.rs index 4b72b4b..5660103 100644 --- a/src/util/rand.rs +++ b/src/util/rand.rs @@ -52,7 +52,7 @@ impl FastRand { } // Used by the select macro and `StreamMap` -#[cfg(any(feature = "macros", feature = "stream"))] +#[cfg(any(feature = "macros"))] #[doc(hidden)] #[cfg_attr(not(feature = "macros"), allow(unreachable_pub))] pub fn thread_rng_n(n: u32) -> u32 { diff --git a/tests/async_send_sync.rs b/tests/async_send_sync.rs index 2ee3857..671fa4a 100644 --- a/tests/async_send_sync.rs +++ b/tests/async_send_sync.rs @@ -14,8 +14,7 @@ type BoxFutureSync<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + type BoxFutureSend<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send>>; #[allow(dead_code)] type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T>>>; -#[allow(dead_code)] -type BoxStream<T> = std::pin::Pin<Box<dyn tokio::stream::Stream<Item = T>>>; + #[allow(dead_code)] type BoxAsyncRead = std::pin::Pin<Box<dyn tokio::io::AsyncBufRead>>; #[allow(dead_code)] @@ -94,6 +93,14 @@ macro_rules! assert_value { AmbiguousIfSync::some_item(&f); }; }; + ($type:ty: Unpin) => { + #[allow(unreachable_code)] + #[allow(unused_variables)] + const _: fn() = || { + let f: $type = todo!(); + require_unpin(&f); + }; + }; } macro_rules! async_assert_fn { ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & Sync) => { @@ -222,10 +229,6 @@ async_assert_fn!(tokio::signal::ctrl_c(): Send & Sync); #[cfg(unix)] async_assert_fn!(tokio::signal::unix::Signal::recv(_): Send & Sync); -async_assert_fn!(tokio::stream::empty<Rc<u8>>(): Send & Sync); -async_assert_fn!(tokio::stream::pending<Rc<u8>>(): Send & Sync); -async_assert_fn!(tokio::stream::iter(std::vec::IntoIter<u8>): Send & Sync); - async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex<u8>::lock(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex<Cell<u8>>::lock(_): Send & Sync); @@ -285,13 +288,12 @@ async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sy async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync); async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync); -async_assert_fn!(tokio::stream::StreamExt::next(&mut BoxStream<()>): !Unpin); -async_assert_fn!(tokio::stream::StreamExt::try_next(&mut BoxStream<Result<(), ()>>): !Unpin); -async_assert_fn!(tokio::stream::StreamExt::all(&mut BoxStream<()>, fn(())->bool): !Unpin); -async_assert_fn!(tokio::stream::StreamExt::any(&mut BoxStream<()>, fn(())->bool): !Unpin); -async_assert_fn!(tokio::stream::StreamExt::fold(&mut BoxStream<()>, (), fn((), ())->()): !Unpin); -async_assert_fn!(tokio::stream::StreamExt::collect<Vec<()>>(&mut BoxStream<()>): !Unpin); - +assert_value!(tokio::time::Interval: Unpin); +async_assert_fn!(tokio::time::sleep(Duration): !Unpin); +async_assert_fn!(tokio::time::sleep_until(Instant): !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Unpin); +async_assert_fn!(tokio::time::Interval::tick(_): !Unpin); async_assert_fn!(tokio::io::AsyncBufReadExt::read_until(&mut BoxAsyncRead, u8, &mut Vec<u8>): !Unpin); async_assert_fn!(tokio::io::AsyncBufReadExt::read_line(&mut BoxAsyncRead, &mut String): !Unpin); async_assert_fn!(tokio::io::AsyncReadExt::read(&mut BoxAsyncRead, &mut [u8]): !Unpin); diff --git a/tests/buffered.rs b/tests/buffered.rs index 97ba00c..98b6d5f 100644 --- a/tests/buffered.rs +++ b/tests/buffered.rs @@ -2,7 +2,6 @@ #![cfg(feature = "full")] use tokio::net::TcpListener; -use tokio::prelude::*; use tokio_test::assert_ok; use std::io::prelude::*; @@ -41,7 +40,7 @@ async fn echo_server() { let (mut a, _) = assert_ok!(srv.accept().await); let (mut b, _) = assert_ok!(srv.accept().await); - let n = assert_ok!(io::copy(&mut a, &mut b).await); + let n = assert_ok!(tokio::io::copy(&mut a, &mut b).await); let (expected, t2) = t.join().unwrap(); let actual = t2.join().unwrap(); diff --git a/tests/fs_dir.rs b/tests/fs_dir.rs index 6355ef0..21efe8c 100644 --- a/tests/fs_dir.rs +++ b/tests/fs_dir.rs @@ -85,35 +85,3 @@ async fn read_inherent() { vec!["aa".to_string(), "bb".to_string(), "cc".to_string()] ); } - -#[tokio::test] -async fn read_stream() { - use tokio::stream::StreamExt; - - let base_dir = tempdir().unwrap(); - - let p = base_dir.path(); - std::fs::create_dir(p.join("aa")).unwrap(); - std::fs::create_dir(p.join("bb")).unwrap(); - std::fs::create_dir(p.join("cc")).unwrap(); - - let files = Arc::new(Mutex::new(Vec::new())); - - let f = files.clone(); - let p = p.to_path_buf(); - - let mut entries = fs::read_dir(p).await.unwrap(); - - while let Some(res) = entries.next().await { - let e = assert_ok!(res); - let s = e.file_name().to_str().unwrap().to_string(); - f.lock().unwrap().push(s); - } - - let mut files = files.lock().unwrap(); - files.sort(); // because the order is not guaranteed - assert_eq!( - *files, - vec!["aa".to_string(), "bb".to_string(), "cc".to_string()] - ); -} diff --git a/tests/fs_file.rs b/tests/fs_file.rs index d5b56e6..bf2f1d7 100644 --- a/tests/fs_file.rs +++ b/tests/fs_file.rs @@ -2,7 +2,7 @@ #![cfg(feature = "full")] use tokio::fs::File; -use tokio::prelude::*; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tokio_test::task; use std::io::prelude::*; diff --git a/tests/fs_file_mocked.rs b/tests/fs_file_mocked.rs index edb74a7..7771532 100644 --- a/tests/fs_file_mocked.rs +++ b/tests/fs_file_mocked.rs @@ -62,7 +62,7 @@ pub(crate) mod sync { } use fs::sys; -use tokio::prelude::*; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; use std::io::SeekFrom; diff --git a/tests/fs_link.rs b/tests/fs_link.rs index cbbe27e..2ef666f 100644 --- a/tests/fs_link.rs +++ b/tests/fs_link.rs @@ -48,9 +48,7 @@ async fn test_symlink() { let src_2 = src.clone(); let dst_2 = dst.clone(); - assert!(fs::os::unix::symlink(src_2.clone(), dst_2.clone()) - .await - .is_ok()); + assert!(fs::symlink(src_2.clone(), dst_2.clone()).await.is_ok()); let mut content = String::new(); diff --git a/tests/io_async_fd.rs b/tests/io_async_fd.rs index f8dc65f..d1586bb 100644 --- a/tests/io_async_fd.rs +++ b/tests/io_async_fd.rs @@ -201,7 +201,10 @@ async fn reset_readable() { let mut guard = readable.await.unwrap(); - guard.with_io(|| afd_a.get_ref().read(&mut [0])).unwrap(); + guard + .try_io(|_| afd_a.get_ref().read(&mut [0])) + .unwrap() + .unwrap(); // `a` is not readable, but the reactor still thinks it is // (because we have not observed a not-ready error yet) @@ -233,12 +236,10 @@ async fn reset_writable() { let mut guard = afd_a.writable().await.unwrap(); // Write until we get a WouldBlock. This also clears the ready state. - loop { - if let Err(e) = guard.with_io(|| afd_a.get_ref().write(&[0; 512][..])) { - assert_eq!(ErrorKind::WouldBlock, e.kind()); - break; - } - } + while guard + .try_io(|_| afd_a.get_ref().write(&[0; 512][..])) + .is_ok() + {} // Writable state should be cleared now. let writable = afd_a.writable(); @@ -313,9 +314,7 @@ async fn reregister() { } #[tokio::test] -async fn with_poll() { - use std::task::Poll; - +async fn try_io() { let (a, mut b) = socketpair(); b.write_all(b"0").unwrap(); @@ -327,13 +326,13 @@ async fn with_poll() { afd_a.get_ref().read_exact(&mut [0]).unwrap(); // Should not clear the readable state - let _ = guard.with_poll(|| Poll::Ready(())); + let _ = guard.try_io(|_| Ok(())); // Still readable... let _ = afd_a.readable().await.unwrap(); // Should clear the readable state - let _ = guard.with_poll(|| Poll::Pending::<()>); + let _ = guard.try_io(|_| io::Result::<()>::Err(ErrorKind::WouldBlock.into())); // Assert not readable let readable = afd_a.readable(); diff --git a/tests/io_lines.rs b/tests/io_lines.rs index 2f6b339..9996d81 100644 --- a/tests/io_lines.rs +++ b/tests/io_lines.rs @@ -17,19 +17,3 @@ async fn lines_inherent() { assert_eq!(b, ""); assert!(assert_ok!(st.next_line().await).is_none()); } - -#[tokio::test] -async fn lines_stream() { - use tokio::stream::StreamExt; - - let rd: &[u8] = b"hello\r\nworld\n\n"; - let mut st = rd.lines(); - - let b = assert_ok!(st.next().await.unwrap()); - assert_eq!(b, "hello"); - let b = assert_ok!(st.next().await.unwrap()); - assert_eq!(b, "world"); - let b = assert_ok!(st.next().await.unwrap()); - assert_eq!(b, ""); - assert!(st.next().await.is_none()); -} diff --git a/tests/macros_select.rs b/tests/macros_select.rs index cc214bb..3359849 100644 --- a/tests/macros_select.rs +++ b/tests/macros_select.rs @@ -359,12 +359,14 @@ async fn join_with_select() { async fn use_future_in_if_condition() { use tokio::time::{self, Duration}; - let mut sleep = time::sleep(Duration::from_millis(50)); + let sleep = time::sleep(Duration::from_millis(50)); + tokio::pin!(sleep); tokio::select! { - _ = &mut sleep, if !sleep.is_elapsed() => { + _ = time::sleep(Duration::from_millis(50)), if false => { + panic!("if condition ignored") } - _ = async { 1 } => { + _ = async { 1u32 } => { } } } diff --git a/tests/process_issue_2174.rs b/tests/process_issue_2174.rs index 6ee7d1a..5ee9dc0 100644 --- a/tests/process_issue_2174.rs +++ b/tests/process_issue_2174.rs @@ -11,7 +11,7 @@ use std::process::Stdio; use std::time::Duration; -use tokio::prelude::*; +use tokio::io::AsyncWriteExt; use tokio::process::Command; use tokio::time; use tokio_test::assert_err; diff --git a/tests/process_kill_on_drop.rs b/tests/process_kill_on_drop.rs index f67bb23..00f5c6d 100644 --- a/tests/process_kill_on_drop.rs +++ b/tests/process_kill_on_drop.rs @@ -10,7 +10,7 @@ use tokio_test::assert_ok; #[tokio::test] async fn kill_on_drop() { - let mut cmd = Command::new("sh"); + let mut cmd = Command::new("bash"); cmd.args(&[ "-c", " diff --git a/tests/rt_basic.rs b/tests/rt_basic.rs index 7b5b622..977a838 100644 --- a/tests/rt_basic.rs +++ b/tests/rt_basic.rs @@ -2,12 +2,16 @@ #![cfg(feature = "full")] use tokio::runtime::Runtime; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; use tokio_test::{assert_err, assert_ok}; use std::thread; use std::time::Duration; +mod support { + pub(crate) mod mpsc_stream; +} + #[test] fn spawned_task_does_not_progress_without_block_on() { let (tx, mut rx) = oneshot::channel(); @@ -36,7 +40,7 @@ fn no_extra_poll() { Arc, }; use std::task::{Context, Poll}; - use tokio::stream::{Stream, StreamExt}; + use tokio_stream::{Stream, StreamExt}; pin_project! { struct TrackPolls<S> { @@ -58,8 +62,8 @@ fn no_extra_poll() { } } - let (tx, rx) = mpsc::unbounded_channel(); - let mut rx = TrackPolls { + let (tx, rx) = support::mpsc_stream::unbounded_channel_stream::<()>(); + let rx = TrackPolls { npolls: Arc::new(AtomicUsize::new(0)), s: rx, }; @@ -67,6 +71,9 @@ fn no_extra_poll() { let rt = rt(); + // TODO: could probably avoid this, but why not. + let mut rx = Box::pin(rx); + rt.spawn(async move { while rx.next().await.is_some() {} }); rt.block_on(async { tokio::task::yield_now().await; diff --git a/tests/rt_common.rs b/tests/rt_common.rs index 74a94d5..66e6f2c 100644 --- a/tests/rt_common.rs +++ b/tests/rt_common.rs @@ -56,7 +56,7 @@ fn send_sync_bound() { rt_test! { use tokio::net::{TcpListener, TcpStream, UdpSocket}; - use tokio::prelude::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::runtime::Runtime; use tokio::sync::oneshot; use tokio::{task, time}; @@ -858,6 +858,21 @@ rt_test! { } #[test] + fn shutdown_timeout_0() { + let runtime = rt(); + + runtime.block_on(async move { + task::spawn_blocking(move || { + thread::sleep(Duration::from_secs(10_000)); + }); + }); + + let now = Instant::now(); + Arc::try_unwrap(runtime).unwrap().shutdown_timeout(Duration::from_nanos(0)); + assert!(now.elapsed().as_secs() < 1); + } + + #[test] fn shutdown_wakeup_time() { let runtime = rt(); diff --git a/tests/rt_threaded.rs b/tests/rt_threaded.rs index 90ebf6a..19b381c 100644 --- a/tests/rt_threaded.rs +++ b/tests/rt_threaded.rs @@ -331,7 +331,7 @@ fn coop_and_block_in_place() { // runtime worker yielded as part of `block_in_place` and guarantees the // same thread will reclaim the worker at the end of the // `block_in_place` call. - .max_threads(1) + .max_blocking_threads(1) .build() .unwrap(); @@ -375,13 +375,36 @@ fn coop_and_block_in_place() { // Testing this does not panic #[test] -fn max_threads() { +fn max_blocking_threads() { let _rt = tokio::runtime::Builder::new_multi_thread() - .max_threads(1) + .max_blocking_threads(1) .build() .unwrap(); } +#[test] +#[should_panic] +fn max_blocking_threads_set_to_zero() { + let _rt = tokio::runtime::Builder::new_multi_thread() + .max_blocking_threads(0) + .build() + .unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn hang_on_shutdown() { + let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>(); + tokio::spawn(async move { + tokio::task::block_in_place(|| sync_rx.recv().ok()); + }); + + tokio::spawn(async { + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + drop(sync_tx); + }); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; +} + fn rt() -> Runtime { Runtime::new().unwrap() } diff --git a/tests/stream_chain.rs b/tests/stream_chain.rs deleted file mode 100644 index 98461a8..0000000 --- a/tests/stream_chain.rs +++ /dev/null @@ -1,95 +0,0 @@ -use tokio::stream::{self, Stream, StreamExt}; -use tokio::sync::mpsc; -use tokio_test::{assert_pending, assert_ready, task}; - -#[tokio::test] -async fn basic_usage() { - let one = stream::iter(vec![1, 2, 3]); - let two = stream::iter(vec![4, 5, 6]); - - let mut stream = one.chain(two); - - assert_eq!(stream.size_hint(), (6, Some(6))); - assert_eq!(stream.next().await, Some(1)); - - assert_eq!(stream.size_hint(), (5, Some(5))); - assert_eq!(stream.next().await, Some(2)); - - assert_eq!(stream.size_hint(), (4, Some(4))); - assert_eq!(stream.next().await, Some(3)); - - assert_eq!(stream.size_hint(), (3, Some(3))); - assert_eq!(stream.next().await, Some(4)); - - assert_eq!(stream.size_hint(), (2, Some(2))); - assert_eq!(stream.next().await, Some(5)); - - assert_eq!(stream.size_hint(), (1, Some(1))); - assert_eq!(stream.next().await, Some(6)); - - assert_eq!(stream.size_hint(), (0, Some(0))); - assert_eq!(stream.next().await, None); - - assert_eq!(stream.size_hint(), (0, Some(0))); - assert_eq!(stream.next().await, None); -} - -#[tokio::test] -async fn pending_first() { - let (tx1, rx1) = mpsc::unbounded_channel(); - let (tx2, rx2) = mpsc::unbounded_channel(); - - let mut stream = task::spawn(rx1.chain(rx2)); - assert_eq!(stream.size_hint(), (0, None)); - - assert_pending!(stream.poll_next()); - - tx2.send(2).unwrap(); - assert!(!stream.is_woken()); - - assert_pending!(stream.poll_next()); - - tx1.send(1).unwrap(); - assert!(stream.is_woken()); - assert_eq!(Some(1), assert_ready!(stream.poll_next())); - - assert_pending!(stream.poll_next()); - - drop(tx1); - - assert_eq!(stream.size_hint(), (0, None)); - - assert!(stream.is_woken()); - assert_eq!(Some(2), assert_ready!(stream.poll_next())); - - assert_eq!(stream.size_hint(), (0, None)); - - drop(tx2); - - assert_eq!(stream.size_hint(), (0, None)); - assert_eq!(None, assert_ready!(stream.poll_next())); -} - -#[test] -fn size_overflow() { - struct Monster; - - impl tokio::stream::Stream for Monster { - type Item = (); - fn poll_next( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Option<()>> { - panic!() - } - - fn size_hint(&self) -> (usize, Option<usize>) { - (usize::max_value(), Some(usize::max_value())) - } - } - - let m1 = Monster; - let m2 = Monster; - let m = m1.chain(m2); - assert_eq!(m.size_hint(), (usize::max_value(), None)); -} diff --git a/tests/stream_collect.rs b/tests/stream_collect.rs deleted file mode 100644 index 7ab1a34..0000000 --- a/tests/stream_collect.rs +++ /dev/null @@ -1,137 +0,0 @@ -use tokio::stream::{self, StreamExt}; -use tokio::sync::mpsc; -use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; - -#[allow(clippy::let_unit_value)] -#[tokio::test] -async fn empty_unit() { - // Drains the stream. - let mut iter = vec![(), (), ()].into_iter(); - let _: () = stream::iter(&mut iter).collect().await; - assert!(iter.next().is_none()); -} - -#[tokio::test] -async fn empty_vec() { - let coll: Vec<u32> = stream::empty().collect().await; - assert!(coll.is_empty()); -} - -#[tokio::test] -async fn empty_box_slice() { - let coll: Box<[u32]> = stream::empty().collect().await; - assert!(coll.is_empty()); -} - -#[tokio::test] -async fn empty_string() { - let coll: String = stream::empty::<&str>().collect().await; - assert!(coll.is_empty()); -} - -#[tokio::test] -async fn empty_result() { - let coll: Result<Vec<u32>, &str> = stream::empty().collect().await; - assert_eq!(Ok(vec![]), coll); -} - -#[tokio::test] -async fn collect_vec_items() { - let (tx, rx) = mpsc::unbounded_channel(); - let mut fut = task::spawn(rx.collect::<Vec<i32>>()); - - assert_pending!(fut.poll()); - - tx.send(1).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - tx.send(2).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - drop(tx); - assert!(fut.is_woken()); - let coll = assert_ready!(fut.poll()); - assert_eq!(vec![1, 2], coll); -} - -#[tokio::test] -async fn collect_string_items() { - let (tx, rx) = mpsc::unbounded_channel(); - let mut fut = task::spawn(rx.collect::<String>()); - - assert_pending!(fut.poll()); - - tx.send("hello ".to_string()).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - tx.send("world".to_string()).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - drop(tx); - assert!(fut.is_woken()); - let coll = assert_ready!(fut.poll()); - assert_eq!("hello world", coll); -} - -#[tokio::test] -async fn collect_str_items() { - let (tx, rx) = mpsc::unbounded_channel(); - let mut fut = task::spawn(rx.collect::<String>()); - - assert_pending!(fut.poll()); - - tx.send("hello ").unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - tx.send("world").unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - drop(tx); - assert!(fut.is_woken()); - let coll = assert_ready!(fut.poll()); - assert_eq!("hello world", coll); -} - -#[tokio::test] -async fn collect_results_ok() { - let (tx, rx) = mpsc::unbounded_channel(); - let mut fut = task::spawn(rx.collect::<Result<String, &str>>()); - - assert_pending!(fut.poll()); - - tx.send(Ok("hello ")).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - tx.send(Ok("world")).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - drop(tx); - assert!(fut.is_woken()); - let coll = assert_ready_ok!(fut.poll()); - assert_eq!("hello world", coll); -} - -#[tokio::test] -async fn collect_results_err() { - let (tx, rx) = mpsc::unbounded_channel(); - let mut fut = task::spawn(rx.collect::<Result<String, &str>>()); - - assert_pending!(fut.poll()); - - tx.send(Ok("hello ")).unwrap(); - assert!(fut.is_woken()); - assert_pending!(fut.poll()); - - tx.send(Err("oh no")).unwrap(); - assert!(fut.is_woken()); - let err = assert_ready_err!(fut.poll()); - assert_eq!("oh no", err); -} diff --git a/tests/stream_empty.rs b/tests/stream_empty.rs deleted file mode 100644 index f278076..0000000 --- a/tests/stream_empty.rs +++ /dev/null @@ -1,11 +0,0 @@ -use tokio::stream::{self, Stream, StreamExt}; - -#[tokio::test] -async fn basic_usage() { - let mut stream = stream::empty::<i32>(); - - for _ in 0..2 { - assert_eq!(stream.size_hint(), (0, Some(0))); - assert_eq!(None, stream.next().await); - } -} diff --git a/tests/stream_fuse.rs b/tests/stream_fuse.rs deleted file mode 100644 index 9d7d969..0000000 --- a/tests/stream_fuse.rs +++ /dev/null @@ -1,50 +0,0 @@ -use tokio::stream::{Stream, StreamExt}; - -use std::pin::Pin; -use std::task::{Context, Poll}; - -// a stream which alternates between Some and None -struct Alternate { - state: i32, -} - -impl Stream for Alternate { - type Item = i32; - - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<i32>> { - let val = self.state; - self.state += 1; - - // if it's even, Some(i32), else None - if val % 2 == 0 { - Poll::Ready(Some(val)) - } else { - Poll::Ready(None) - } - } -} - -#[tokio::test] -async fn basic_usage() { - let mut stream = Alternate { state: 0 }; - - // the stream goes back and forth - assert_eq!(stream.next().await, Some(0)); - assert_eq!(stream.next().await, None); - assert_eq!(stream.next().await, Some(2)); - assert_eq!(stream.next().await, None); - - // however, once it is fused - let mut stream = stream.fuse(); - - assert_eq!(stream.size_hint(), (0, None)); - assert_eq!(stream.next().await, Some(4)); - - assert_eq!(stream.size_hint(), (0, None)); - assert_eq!(stream.next().await, None); - - // it will always return `None` after the first time. - assert_eq!(stream.size_hint(), (0, Some(0))); - assert_eq!(stream.next().await, None); - assert_eq!(stream.size_hint(), (0, Some(0))); -} diff --git a/tests/stream_iter.rs b/tests/stream_iter.rs deleted file mode 100644 index 45148a7..0000000 --- a/tests/stream_iter.rs +++ /dev/null @@ -1,18 +0,0 @@ -use tokio::stream; -use tokio_test::task; - -use std::iter; - -#[tokio::test] -async fn coop() { - let mut stream = task::spawn(stream::iter(iter::repeat(1))); - - for _ in 0..10_000 { - if stream.poll_next().is_pending() { - assert!(stream.is_woken()); - return; - } - } - - panic!("did not yield"); -} diff --git a/tests/stream_merge.rs b/tests/stream_merge.rs deleted file mode 100644 index 45ecdcb..0000000 --- a/tests/stream_merge.rs +++ /dev/null @@ -1,78 +0,0 @@ -use tokio::stream::{self, Stream, StreamExt}; -use tokio::sync::mpsc; -use tokio_test::task; -use tokio_test::{assert_pending, assert_ready}; - -#[tokio::test] -async fn merge_sync_streams() { - let mut s = stream::iter(vec![0, 2, 4, 6]).merge(stream::iter(vec![1, 3, 5])); - - for i in 0..7 { - let rem = 7 - i; - assert_eq!(s.size_hint(), (rem, Some(rem))); - assert_eq!(Some(i), s.next().await); - } - - assert!(s.next().await.is_none()); -} - -#[tokio::test] -async fn merge_async_streams() { - let (tx1, rx1) = mpsc::unbounded_channel(); - let (tx2, rx2) = mpsc::unbounded_channel(); - - let mut rx = task::spawn(rx1.merge(rx2)); - - assert_eq!(rx.size_hint(), (0, None)); - - assert_pending!(rx.poll_next()); - - tx1.send(1).unwrap(); - - assert!(rx.is_woken()); - assert_eq!(Some(1), assert_ready!(rx.poll_next())); - - assert_pending!(rx.poll_next()); - tx2.send(2).unwrap(); - - assert!(rx.is_woken()); - assert_eq!(Some(2), assert_ready!(rx.poll_next())); - assert_pending!(rx.poll_next()); - - drop(tx1); - assert!(rx.is_woken()); - assert_pending!(rx.poll_next()); - - tx2.send(3).unwrap(); - assert!(rx.is_woken()); - assert_eq!(Some(3), assert_ready!(rx.poll_next())); - assert_pending!(rx.poll_next()); - - drop(tx2); - assert!(rx.is_woken()); - assert_eq!(None, assert_ready!(rx.poll_next())); -} - -#[test] -fn size_overflow() { - struct Monster; - - impl tokio::stream::Stream for Monster { - type Item = (); - fn poll_next( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Option<()>> { - panic!() - } - - fn size_hint(&self) -> (usize, Option<usize>) { - (usize::max_value(), Some(usize::max_value())) - } - } - - let m1 = Monster; - let m2 = Monster; - let m = m1.merge(m2); - assert_eq!(m.size_hint(), (usize::max_value(), None)); -} diff --git a/tests/stream_once.rs b/tests/stream_once.rs deleted file mode 100644 index bb4635a..0000000 --- a/tests/stream_once.rs +++ /dev/null @@ -1,12 +0,0 @@ -use tokio::stream::{self, Stream, StreamExt}; - -#[tokio::test] -async fn basic_usage() { - let mut one = stream::once(1); - - assert_eq!(one.size_hint(), (1, Some(1))); - assert_eq!(Some(1), one.next().await); - - assert_eq!(one.size_hint(), (0, Some(0))); - assert_eq!(None, one.next().await); -} diff --git a/tests/stream_pending.rs b/tests/stream_pending.rs deleted file mode 100644 index f4d3080..0000000 --- a/tests/stream_pending.rs +++ /dev/null @@ -1,14 +0,0 @@ -use tokio::stream::{self, Stream, StreamExt}; -use tokio_test::{assert_pending, task}; - -#[tokio::test] -async fn basic_usage() { - let mut stream = stream::pending::<i32>(); - - for _ in 0..2 { - assert_eq!(stream.size_hint(), (0, None)); - - let mut next = task::spawn(async { stream.next().await }); - assert_pending!(next.poll()); - } -} diff --git a/tests/stream_stream_map.rs b/tests/stream_stream_map.rs deleted file mode 100644 index 38bb0c5..0000000 --- a/tests/stream_stream_map.rs +++ /dev/null @@ -1,372 +0,0 @@ -use tokio::stream::{self, pending, Stream, StreamExt, StreamMap}; -use tokio::sync::mpsc; -use tokio_test::{assert_ok, assert_pending, assert_ready, task}; - -use std::pin::Pin; - -macro_rules! assert_ready_some { - ($($t:tt)*) => { - match assert_ready!($($t)*) { - Some(v) => v, - None => panic!("expected `Some`, got `None`"), - } - }; -} - -macro_rules! assert_ready_none { - ($($t:tt)*) => { - match assert_ready!($($t)*) { - None => {} - Some(v) => panic!("expected `None`, got `Some({:?})`", v), - } - }; -} - -#[tokio::test] -async fn empty() { - let mut map = StreamMap::<&str, stream::Pending<()>>::new(); - - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - - assert!(map.next().await.is_none()); - assert!(map.next().await.is_none()); - - assert!(map.remove("foo").is_none()); -} - -#[tokio::test] -async fn single_entry() { - let mut map = task::spawn(StreamMap::new()); - let (tx, rx) = mpsc::unbounded_channel(); - - assert_ready_none!(map.poll_next()); - - assert!(map.insert("foo", rx).is_none()); - assert!(map.contains_key("foo")); - assert!(!map.contains_key("bar")); - - assert_eq!(map.len(), 1); - assert!(!map.is_empty()); - - assert_pending!(map.poll_next()); - - assert_ok!(tx.send(1)); - - assert!(map.is_woken()); - let (k, v) = assert_ready_some!(map.poll_next()); - assert_eq!(k, "foo"); - assert_eq!(v, 1); - - assert_pending!(map.poll_next()); - - assert_ok!(tx.send(2)); - - assert!(map.is_woken()); - let (k, v) = assert_ready_some!(map.poll_next()); - assert_eq!(k, "foo"); - assert_eq!(v, 2); - - assert_pending!(map.poll_next()); - drop(tx); - assert!(map.is_woken()); - assert_ready_none!(map.poll_next()); -} - -#[tokio::test] -async fn multiple_entries() { - let mut map = task::spawn(StreamMap::new()); - let (tx1, rx1) = mpsc::unbounded_channel(); - let (tx2, rx2) = mpsc::unbounded_channel(); - - map.insert("foo", rx1); - map.insert("bar", rx2); - - assert_pending!(map.poll_next()); - - assert_ok!(tx1.send(1)); - - assert!(map.is_woken()); - let (k, v) = assert_ready_some!(map.poll_next()); - assert_eq!(k, "foo"); - assert_eq!(v, 1); - - assert_pending!(map.poll_next()); - - assert_ok!(tx2.send(2)); - - assert!(map.is_woken()); - let (k, v) = assert_ready_some!(map.poll_next()); - assert_eq!(k, "bar"); - assert_eq!(v, 2); - - assert_pending!(map.poll_next()); - - assert_ok!(tx1.send(3)); - assert_ok!(tx2.send(4)); - - assert!(map.is_woken()); - - // Given the randomization, there is no guarantee what order the values will - // be received in. - let mut v = (0..2) - .map(|_| assert_ready_some!(map.poll_next())) - .collect::<Vec<_>>(); - - assert_pending!(map.poll_next()); - - v.sort_unstable(); - assert_eq!(v[0].0, "bar"); - assert_eq!(v[0].1, 4); - assert_eq!(v[1].0, "foo"); - assert_eq!(v[1].1, 3); - - drop(tx1); - assert!(map.is_woken()); - assert_pending!(map.poll_next()); - drop(tx2); - - assert_ready_none!(map.poll_next()); -} - -#[tokio::test] -async fn insert_remove() { - let mut map = task::spawn(StreamMap::new()); - let (tx, rx) = mpsc::unbounded_channel(); - - assert_ready_none!(map.poll_next()); - - assert!(map.insert("foo", rx).is_none()); - let rx = map.remove("foo").unwrap(); - - assert_ok!(tx.send(1)); - - assert!(!map.is_woken()); - assert_ready_none!(map.poll_next()); - - assert!(map.insert("bar", rx).is_none()); - - let v = assert_ready_some!(map.poll_next()); - assert_eq!(v.0, "bar"); - assert_eq!(v.1, 1); - - assert!(map.remove("bar").is_some()); - assert_ready_none!(map.poll_next()); - - assert!(map.is_empty()); - assert_eq!(0, map.len()); -} - -#[tokio::test] -async fn replace() { - let mut map = task::spawn(StreamMap::new()); - let (tx1, rx1) = mpsc::unbounded_channel(); - let (tx2, rx2) = mpsc::unbounded_channel(); - - assert!(map.insert("foo", rx1).is_none()); - - assert_pending!(map.poll_next()); - - let _rx1 = map.insert("foo", rx2).unwrap(); - - assert_pending!(map.poll_next()); - - tx1.send(1).unwrap(); - assert_pending!(map.poll_next()); - - tx2.send(2).unwrap(); - assert!(map.is_woken()); - let v = assert_ready_some!(map.poll_next()); - assert_eq!(v.0, "foo"); - assert_eq!(v.1, 2); -} - -#[test] -fn size_hint_with_upper() { - let mut map = StreamMap::new(); - - map.insert("a", stream::iter(vec![1])); - map.insert("b", stream::iter(vec![1, 2])); - map.insert("c", stream::iter(vec![1, 2, 3])); - - assert_eq!(3, map.len()); - assert!(!map.is_empty()); - - let size_hint = map.size_hint(); - assert_eq!(size_hint, (6, Some(6))); -} - -#[test] -fn size_hint_without_upper() { - let mut map = StreamMap::new(); - - map.insert("a", pin_box(stream::iter(vec![1]))); - map.insert("b", pin_box(stream::iter(vec![1, 2]))); - map.insert("c", pin_box(pending())); - - let size_hint = map.size_hint(); - assert_eq!(size_hint, (3, None)); -} - -#[test] -fn new_capacity_zero() { - let map = StreamMap::<&str, stream::Pending<()>>::new(); - assert_eq!(0, map.capacity()); - - assert!(map.keys().next().is_none()); -} - -#[test] -fn with_capacity() { - let map = StreamMap::<&str, stream::Pending<()>>::with_capacity(10); - assert!(10 <= map.capacity()); - - assert!(map.keys().next().is_none()); -} - -#[test] -fn iter_keys() { - let mut map = StreamMap::new(); - - map.insert("a", pending::<i32>()); - map.insert("b", pending()); - map.insert("c", pending()); - - let mut keys = map.keys().collect::<Vec<_>>(); - keys.sort_unstable(); - - assert_eq!(&keys[..], &[&"a", &"b", &"c"]); -} - -#[test] -fn iter_values() { - let mut map = StreamMap::new(); - - map.insert("a", stream::iter(vec![1])); - map.insert("b", stream::iter(vec![1, 2])); - map.insert("c", stream::iter(vec![1, 2, 3])); - - let mut size_hints = map.values().map(|s| s.size_hint().0).collect::<Vec<_>>(); - - size_hints.sort_unstable(); - - assert_eq!(&size_hints[..], &[1, 2, 3]); -} - -#[test] -fn iter_values_mut() { - let mut map = StreamMap::new(); - - map.insert("a", stream::iter(vec![1])); - map.insert("b", stream::iter(vec![1, 2])); - map.insert("c", stream::iter(vec![1, 2, 3])); - - let mut size_hints = map - .values_mut() - .map(|s: &mut _| s.size_hint().0) - .collect::<Vec<_>>(); - - size_hints.sort_unstable(); - - assert_eq!(&size_hints[..], &[1, 2, 3]); -} - -#[test] -fn clear() { - let mut map = task::spawn(StreamMap::new()); - - map.insert("a", stream::iter(vec![1])); - map.insert("b", stream::iter(vec![1, 2])); - map.insert("c", stream::iter(vec![1, 2, 3])); - - assert_ready_some!(map.poll_next()); - - map.clear(); - - assert_ready_none!(map.poll_next()); - assert!(map.is_empty()); -} - -#[test] -fn contains_key_borrow() { - let mut map = StreamMap::new(); - map.insert("foo".to_string(), pending::<()>()); - - assert!(map.contains_key("foo")); -} - -#[test] -fn one_ready_many_none() { - // Run a few times because of randomness - for _ in 0..100 { - let mut map = task::spawn(StreamMap::new()); - - map.insert(0, pin_box(stream::empty())); - map.insert(1, pin_box(stream::empty())); - map.insert(2, pin_box(stream::once("hello"))); - map.insert(3, pin_box(stream::pending())); - - let v = assert_ready_some!(map.poll_next()); - assert_eq!(v, (2, "hello")); - } -} - -proptest::proptest! { - #[test] - fn fuzz_pending_complete_mix(kinds: Vec<bool>) { - use std::task::{Context, Poll}; - - struct DidPoll<T> { - did_poll: bool, - inner: T, - } - - impl<T: Stream + Unpin> Stream for DidPoll<T> { - type Item = T::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll<Option<T::Item>> - { - self.did_poll = true; - Pin::new(&mut self.inner).poll_next(cx) - } - } - - for _ in 0..10 { - let mut map = task::spawn(StreamMap::new()); - let mut expect = 0; - - for (i, &is_empty) in kinds.iter().enumerate() { - let inner = if is_empty { - pin_box(stream::empty::<()>()) - } else { - expect += 1; - pin_box(stream::pending::<()>()) - }; - - let stream = DidPoll { - did_poll: false, - inner, - }; - - map.insert(i, stream); - } - - if expect == 0 { - assert_ready_none!(map.poll_next()); - } else { - assert_pending!(map.poll_next()); - - assert_eq!(expect, map.values().count()); - - for stream in map.values() { - assert!(stream.did_poll); - } - } - } - } -} - -fn pin_box<T: Stream<Item = U> + 'static, U>(s: T) -> Pin<Box<dyn Stream<Item = U>>> { - Box::pin(s) -} diff --git a/tests/stream_timeout.rs b/tests/stream_timeout.rs deleted file mode 100644 index a787bba..0000000 --- a/tests/stream_timeout.rs +++ /dev/null @@ -1,109 +0,0 @@ -#![cfg(feature = "full")] - -use tokio::stream::{self, StreamExt}; -use tokio::time::{self, sleep, Duration}; -use tokio_test::*; - -use futures::StreamExt as _; - -async fn maybe_sleep(idx: i32) -> i32 { - if idx % 2 == 0 { - sleep(ms(200)).await; - } - idx -} - -fn ms(n: u64) -> Duration { - Duration::from_millis(n) -} - -#[tokio::test] -async fn basic_usage() { - time::pause(); - - // Items 2 and 4 time out. If we run the stream until it completes, - // we end up with the following items: - // - // [Ok(1), Err(Elapsed), Ok(2), Ok(3), Err(Elapsed), Ok(4)] - - let stream = stream::iter(1..=4).then(maybe_sleep).timeout(ms(100)); - let mut stream = task::spawn(stream); - - // First item completes immediately - assert_ready_eq!(stream.poll_next(), Some(Ok(1))); - - // Second item is delayed 200ms, times out after 100ms - assert_pending!(stream.poll_next()); - - time::advance(ms(150)).await; - let v = assert_ready!(stream.poll_next()); - assert!(v.unwrap().is_err()); - - assert_pending!(stream.poll_next()); - - time::advance(ms(100)).await; - assert_ready_eq!(stream.poll_next(), Some(Ok(2))); - - // Third item is ready immediately - assert_ready_eq!(stream.poll_next(), Some(Ok(3))); - - // Fourth item is delayed 200ms, times out after 100ms - assert_pending!(stream.poll_next()); - - time::advance(ms(60)).await; - assert_pending!(stream.poll_next()); // nothing ready yet - - time::advance(ms(60)).await; - let v = assert_ready!(stream.poll_next()); - assert!(v.unwrap().is_err()); // timeout! - - time::advance(ms(120)).await; - assert_ready_eq!(stream.poll_next(), Some(Ok(4))); - - // Done. - assert_ready_eq!(stream.poll_next(), None); -} - -#[tokio::test] -async fn return_elapsed_errors_only_once() { - time::pause(); - - let stream = stream::iter(1..=3).then(maybe_sleep).timeout(ms(50)); - let mut stream = task::spawn(stream); - - // First item completes immediately - assert_ready_eq!(stream.poll_next(), Some(Ok(1))); - - // Second item is delayed 200ms, times out after 50ms. Only one `Elapsed` - // error is returned. - assert_pending!(stream.poll_next()); - // - time::advance(ms(50)).await; - let v = assert_ready!(stream.poll_next()); - assert!(v.unwrap().is_err()); // timeout! - - // deadline elapses again, but no error is returned - time::advance(ms(50)).await; - assert_pending!(stream.poll_next()); - - time::advance(ms(100)).await; - assert_ready_eq!(stream.poll_next(), Some(Ok(2))); - assert_ready_eq!(stream.poll_next(), Some(Ok(3))); - - // Done - assert_ready_eq!(stream.poll_next(), None); -} - -#[tokio::test] -async fn no_timeouts() { - let stream = stream::iter(vec![1, 3, 5]) - .then(maybe_sleep) - .timeout(ms(100)); - - let mut stream = task::spawn(stream); - - assert_ready_eq!(stream.poll_next(), Some(Ok(1))); - assert_ready_eq!(stream.poll_next(), Some(Ok(3))); - assert_ready_eq!(stream.poll_next(), Some(Ok(5))); - assert_ready_eq!(stream.poll_next(), None); -} diff --git a/tests/support/mpsc_stream.rs b/tests/support/mpsc_stream.rs new file mode 100644 index 0000000..aa385a3 --- /dev/null +++ b/tests/support/mpsc_stream.rs @@ -0,0 +1,42 @@ +#![allow(dead_code)] + +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio_stream::Stream; + +struct UnboundedStream<T> { + recv: UnboundedReceiver<T>, +} +impl<T> Stream for UnboundedStream<T> { + type Item = T; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { + Pin::into_inner(self).recv.poll_recv(cx) + } +} + +pub fn unbounded_channel_stream<T: Unpin>() -> (UnboundedSender<T>, impl Stream<Item = T>) { + let (tx, rx) = mpsc::unbounded_channel(); + + let stream = UnboundedStream { recv: rx }; + + (tx, stream) +} + +struct BoundedStream<T> { + recv: Receiver<T>, +} +impl<T> Stream for BoundedStream<T> { + type Item = T; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> { + Pin::into_inner(self).recv.poll_recv(cx) + } +} + +pub fn channel_stream<T: Unpin>(size: usize) -> (Sender<T>, impl Stream<Item = T>) { + let (tx, rx) = mpsc::channel(size); + + let stream = BoundedStream { recv: rx }; + + (tx, stream) +} diff --git a/tests/sync_broadcast.rs b/tests/sync_broadcast.rs index 84c77a7..5f79800 100644 --- a/tests/sync_broadcast.rs +++ b/tests/sync_broadcast.rs @@ -89,46 +89,6 @@ fn send_two_recv() { assert_empty!(rx2); } -#[tokio::test] -async fn send_recv_into_stream_ready() { - use tokio::stream::StreamExt; - - let (tx, rx) = broadcast::channel::<i32>(8); - tokio::pin! { - let rx = rx.into_stream(); - } - - assert_ok!(tx.send(1)); - assert_ok!(tx.send(2)); - - assert_eq!(Some(Ok(1)), rx.next().await); - assert_eq!(Some(Ok(2)), rx.next().await); - - drop(tx); - - assert_eq!(None, rx.next().await); -} - -#[tokio::test] -async fn send_recv_into_stream_pending() { - use tokio::stream::StreamExt; - - let (tx, rx) = broadcast::channel::<i32>(8); - - tokio::pin! { - let rx = rx.into_stream(); - } - - let mut recv = task::spawn(rx.next()); - assert_pending!(recv.poll()); - - assert_ok!(tx.send(1)); - - assert!(recv.is_woken()); - let val = assert_ready!(recv.poll()); - assert_eq!(val, Some(Ok(1))); -} - #[test] fn send_recv_bounded() { let (tx, mut rx) = broadcast::channel(16); diff --git a/tests/sync_mpsc.rs b/tests/sync_mpsc.rs index adefcb1..b378e6b 100644 --- a/tests/sync_mpsc.rs +++ b/tests/sync_mpsc.rs @@ -5,7 +5,7 @@ use std::thread; use tokio::runtime::Runtime; use tokio::sync::mpsc; -use tokio::sync::mpsc::error::{TryRecvError, TrySendError}; +use tokio::sync::mpsc::error::TrySendError; use tokio_test::task; use tokio_test::{ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, @@ -13,6 +13,10 @@ use tokio_test::{ use std::sync::Arc; +mod support { + pub(crate) mod mpsc_stream; +} + trait AssertSend: Send {} impl AssertSend for mpsc::Sender<i32> {} impl AssertSend for mpsc::Receiver<i32> {} @@ -80,9 +84,10 @@ async fn reserve_disarm() { #[tokio::test] async fn send_recv_stream_with_buffer() { - use tokio::stream::StreamExt; + use tokio_stream::StreamExt; - let (tx, mut rx) = mpsc::channel::<i32>(16); + let (tx, rx) = support::mpsc_stream::channel_stream::<i32>(16); + let mut rx = Box::pin(rx); tokio::spawn(async move { assert_ok!(tx.send(1).await); @@ -178,9 +183,11 @@ async fn async_send_recv_unbounded() { #[tokio::test] async fn send_recv_stream_unbounded() { - use tokio::stream::StreamExt; + use tokio_stream::StreamExt; - let (tx, mut rx) = mpsc::unbounded_channel::<i32>(); + let (tx, rx) = support::mpsc_stream::unbounded_channel_stream::<i32>(); + + let mut rx = Box::pin(rx); tokio::spawn(async move { assert_ok!(tx.send(1)); @@ -386,44 +393,6 @@ fn unconsumed_messages_are_dropped() { } #[test] -fn try_recv() { - let (tx, mut rx) = mpsc::channel(1); - match rx.try_recv() { - Err(TryRecvError::Empty) => {} - _ => panic!(), - } - tx.try_send(42).unwrap(); - match rx.try_recv() { - Ok(42) => {} - _ => panic!(), - } - drop(tx); - match rx.try_recv() { - Err(TryRecvError::Closed) => {} - _ => panic!(), - } -} - -#[test] -fn try_recv_unbounded() { - let (tx, mut rx) = mpsc::unbounded_channel(); - match rx.try_recv() { - Err(TryRecvError::Empty) => {} - _ => panic!(), - } - tx.send(42).unwrap(); - match rx.try_recv() { - Ok(42) => {} - _ => panic!(), - } - drop(tx); - match rx.try_recv() { - Err(TryRecvError::Closed) => {} - _ => panic!(), - } -} - -#[test] fn blocking_recv() { let (tx, mut rx) = mpsc::channel::<u8>(1); @@ -483,3 +452,22 @@ async fn ready_close_cancel_bounded() { let val = assert_ready!(recv.poll()); assert!(val.is_none()); } + +#[tokio::test] +async fn permit_available_not_acquired_close() { + let (tx1, mut rx) = mpsc::channel::<()>(1); + let tx2 = tx1.clone(); + + let permit1 = assert_ok!(tx1.reserve().await); + + let mut permit2 = task::spawn(tx2.reserve()); + assert_pending!(permit2.poll()); + + rx.close(); + + drop(permit1); + assert!(permit2.is_woken()); + + drop(permit2); + assert!(rx.recv().await.is_none()); +} diff --git a/tests/sync_mutex.rs b/tests/sync_mutex.rs index 96194b3..0ddb203 100644 --- a/tests/sync_mutex.rs +++ b/tests/sync_mutex.rs @@ -91,10 +91,11 @@ async fn aborted_future_1() { let m2 = m1.clone(); // Try to lock mutex in a future that is aborted prematurely timeout(Duration::from_millis(1u64), async move { - let mut iv = interval(Duration::from_millis(1000)); + let iv = interval(Duration::from_millis(1000)); + tokio::pin!(iv); m2.lock().await; - iv.tick().await; - iv.tick().await; + iv.as_mut().tick().await; + iv.as_mut().tick().await; }) .await .unwrap_err(); diff --git a/tests/sync_mutex_owned.rs b/tests/sync_mutex_owned.rs index 394a670..0f1399c 100644 --- a/tests/sync_mutex_owned.rs +++ b/tests/sync_mutex_owned.rs @@ -58,10 +58,11 @@ async fn aborted_future_1() { let m2 = m1.clone(); // Try to lock mutex in a future that is aborted prematurely timeout(Duration::from_millis(1u64), async move { - let mut iv = interval(Duration::from_millis(1000)); + let iv = interval(Duration::from_millis(1000)); + tokio::pin!(iv); m2.lock_owned().await; - iv.tick().await; - iv.tick().await; + iv.as_mut().tick().await; + iv.as_mut().tick().await; }) .await .unwrap_err(); diff --git a/tests/sync_semaphore.rs b/tests/sync_semaphore.rs index 1cb0c74..a33b878 100644 --- a/tests/sync_semaphore.rs +++ b/tests/sync_semaphore.rs @@ -79,3 +79,17 @@ async fn stresstest() { let _p5 = sem.try_acquire().unwrap(); assert!(sem.try_acquire().is_err()); } + +#[test] +fn add_max_amount_permits() { + let s = tokio::sync::Semaphore::new(0); + s.add_permits(usize::MAX >> 3); + assert_eq!(s.available_permits(), usize::MAX >> 3); +} + +#[test] +#[should_panic] +fn add_more_than_max_amount_permits() { + let s = tokio::sync::Semaphore::new(1); + s.add_permits(usize::MAX >> 3); +} diff --git a/tests/task_abort.rs b/tests/task_abort.rs new file mode 100644 index 0000000..e84f19c --- /dev/null +++ b/tests/task_abort.rs @@ -0,0 +1,26 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +/// Checks that a suspended task can be aborted without panicking as reported in +/// issue #3157: <https://github.com/tokio-rs/tokio/issues/3157>. +#[test] +fn test_abort_without_panic_3157() { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_time() + .worker_threads(1) + .build() + .unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + println!("task started"); + tokio::time::sleep(std::time::Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(std::time::Duration::new(1, 0)).await; + + handle.abort(); + let _ = handle.await; + }); +} diff --git a/tests/task_blocking.rs b/tests/task_blocking.rs index eec19cc..82bef8a 100644 --- a/tests/task_blocking.rs +++ b/tests/task_blocking.rs @@ -7,6 +7,10 @@ use tokio_test::assert_ok; use std::thread; use std::time::Duration; +mod support { + pub(crate) mod mpsc_stream; +} + #[tokio::test] async fn basic_blocking() { // Run a few times @@ -165,7 +169,8 @@ fn coop_disabled_in_block_in_place() { .build() .unwrap(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = support::mpsc_stream::unbounded_channel_stream(); + for i in 0..200 { tx.send(i).unwrap(); } @@ -175,7 +180,7 @@ fn coop_disabled_in_block_in_place() { let jh = tokio::spawn(async move { tokio::task::block_in_place(move || { futures::executor::block_on(async move { - use tokio::stream::StreamExt; + use tokio_stream::StreamExt; assert_eq!(rx.fold(0, |n, _| n + 1).await, 200); }) }) @@ -195,7 +200,8 @@ fn coop_disabled_in_block_in_place_in_block_on() { thread::spawn(move || { let outer = tokio::runtime::Runtime::new().unwrap(); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = support::mpsc_stream::unbounded_channel_stream(); + for i in 0..200 { tx.send(i).unwrap(); } @@ -204,7 +210,7 @@ fn coop_disabled_in_block_in_place_in_block_on() { outer.block_on(async move { tokio::task::block_in_place(move || { futures::executor::block_on(async move { - use tokio::stream::StreamExt; + use tokio_stream::StreamExt; assert_eq!(rx.fold(0, |n, _| n + 1).await, 200); }) }) diff --git a/tests/tcp_accept.rs b/tests/tcp_accept.rs index 4c0d682..5ffb946 100644 --- a/tests/tcp_accept.rs +++ b/tests/tcp_accept.rs @@ -46,7 +46,7 @@ use std::sync::{ Arc, }; use std::task::{Context, Poll}; -use tokio::stream::{Stream, StreamExt}; +use tokio_stream::{Stream, StreamExt}; struct TrackPolls<'a> { npolls: Arc<AtomicUsize>, @@ -88,7 +88,7 @@ async fn no_extra_poll() { assert_eq!(npolls.load(SeqCst), 1); let _ = assert_ok!(TcpStream::connect(&addr).await); - accepted_rx.next().await.unwrap(); + accepted_rx.recv().await.unwrap(); // should have been polled twice more: once to yield Some(), then once to yield Pending assert_eq!(npolls.load(SeqCst), 1 + 2); diff --git a/tests/tcp_connect.rs b/tests/tcp_connect.rs index 44942c4..cbe68fa 100644 --- a/tests/tcp_connect.rs +++ b/tests/tcp_connect.rs @@ -169,7 +169,7 @@ async fn connect_addr_host_str_port_tuple() { #[cfg(target_os = "linux")] mod linux { use tokio::net::{TcpListener, TcpStream}; - use tokio::prelude::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_test::assert_ok; use mio::unix::UnixReady; diff --git a/tests/tcp_echo.rs b/tests/tcp_echo.rs index d9cb456..5bb7ff0 100644 --- a/tests/tcp_echo.rs +++ b/tests/tcp_echo.rs @@ -1,8 +1,8 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::prelude::*; use tokio::sync::oneshot; use tokio_test::assert_ok; diff --git a/tests/tcp_into_std.rs b/tests/tcp_into_std.rs new file mode 100644 index 0000000..a46aace --- /dev/null +++ b/tests/tcp_into_std.rs @@ -0,0 +1,44 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::Read; +use std::io::Result; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; + +#[tokio::test] +async fn tcp_into_std() -> Result<()> { + let mut data = [0u8; 12]; + let listener = TcpListener::bind("127.0.0.1:34254").await?; + + let handle = tokio::spawn(async { + let stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap(); + stream + }); + + let (tokio_tcp_stream, _) = listener.accept().await?; + let mut std_tcp_stream = tokio_tcp_stream.into_std()?; + std_tcp_stream + .set_nonblocking(false) + .expect("set_nonblocking call failed"); + + let mut client = handle.await.expect("The task being joined has panicked"); + client.write_all(b"Hello world!").await?; + + std_tcp_stream + .read_exact(&mut data) + .expect("std TcpStream read failed!"); + assert_eq!(b"Hello world!", &data); + + // test back to tokio stream + std_tcp_stream + .set_nonblocking(true) + .expect("set_nonblocking call failed"); + let mut tokio_tcp_stream = TcpStream::from_std(std_tcp_stream)?; + client.write_all(b"Hello tokio!").await?; + let _size = tokio_tcp_stream.read_exact(&mut data).await?; + assert_eq!(b"Hello tokio!", &data); + + Ok(()) +} diff --git a/tests/tcp_shutdown.rs b/tests/tcp_shutdown.rs index 615855f..536a161 100644 --- a/tests/tcp_shutdown.rs +++ b/tests/tcp_shutdown.rs @@ -1,9 +1,8 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{self, AsyncWriteExt}; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::prelude::*; use tokio_test::assert_ok; #[tokio::test] @@ -16,7 +15,7 @@ async fn shutdown() { assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); - let mut buf = [0; 1]; + let mut buf = [0u8; 1]; let n = assert_ok!(stream.read(&mut buf).await); assert_eq!(n, 0); }); diff --git a/tests/time_interval.rs b/tests/time_interval.rs index 5ac6ae6..a3c7f08 100644 --- a/tests/time_interval.rs +++ b/tests/time_interval.rs @@ -44,20 +44,6 @@ async fn usage() { assert_pending!(poll_next(&mut i)); } -#[tokio::test] -async fn usage_stream() { - use tokio::stream::StreamExt; - - let start = Instant::now(); - let mut interval = time::interval(ms(10)); - - for _ in 0..3 { - interval.next().await.unwrap(); - } - - assert!(start.elapsed() > ms(20)); -} - fn poll_next(interval: &mut task::Spawn<time::Interval>) -> Poll<Instant> { interval.enter(|cx, mut interval| { tokio::pin! { diff --git a/tests/time_pause.rs b/tests/time_pause.rs new file mode 100644 index 0000000..49a7677 --- /dev/null +++ b/tests/time_pause.rs @@ -0,0 +1,33 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio_test::assert_err; + +#[tokio::test] +async fn pause_time_in_main() { + tokio::time::pause(); +} + +#[tokio::test] +async fn pause_time_in_task() { + let t = tokio::spawn(async { + tokio::time::pause(); + }); + + t.await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[should_panic] +async fn pause_time_in_main_threads() { + tokio::time::pause(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn pause_time_in_spawn_threads() { + let t = tokio::spawn(async { + tokio::time::pause(); + }); + + assert_err!(t.await); +} diff --git a/tests/time_rt.rs b/tests/time_rt.rs index 85db78d..0775343 100644 --- a/tests/time_rt.rs +++ b/tests/time_rt.rs @@ -68,7 +68,7 @@ async fn starving() { } let when = Instant::now() + Duration::from_millis(20); - let starve = Starve(sleep_until(when), 0); + let starve = Starve(Box::pin(sleep_until(when)), 0); starve.await; assert!(Instant::now() >= when); diff --git a/tests/time_sleep.rs b/tests/time_sleep.rs index 955d833..2736258 100644 --- a/tests/time_sleep.rs +++ b/tests/time_sleep.rs @@ -1,6 +1,11 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use std::future::Future; +use std::task::Context; + +use futures::task::noop_waker_ref; + use tokio::time::{self, Duration, Instant}; use tokio_test::{assert_pending, assert_ready, task}; @@ -31,6 +36,25 @@ async fn immediate_sleep() { } #[tokio::test] +async fn is_elapsed() { + time::pause(); + + let sleep = time::sleep(Duration::from_millis(50)); + + tokio::pin!(sleep); + + assert!(!sleep.is_elapsed()); + + assert!(futures::poll!(sleep.as_mut()).is_pending()); + + assert!(!sleep.is_elapsed()); + + sleep.as_mut().await; + + assert!(sleep.is_elapsed()); +} + +#[tokio::test] async fn delayed_sleep_level_0() { time::pause(); @@ -75,12 +99,12 @@ async fn reset_future_sleep_before_fire() { let now = Instant::now(); - let mut sleep = task::spawn(time::sleep_until(now + ms(100))); + let mut sleep = task::spawn(Box::pin(time::sleep_until(now + ms(100)))); assert_pending!(sleep.poll()); let mut sleep = sleep.into_inner(); - sleep.reset(Instant::now() + ms(200)); + sleep.as_mut().reset(Instant::now() + ms(200)); sleep.await; assert_elapsed!(now, 200); @@ -92,12 +116,12 @@ async fn reset_past_sleep_before_turn() { let now = Instant::now(); - let mut sleep = task::spawn(time::sleep_until(now + ms(100))); + let mut sleep = task::spawn(Box::pin(time::sleep_until(now + ms(100)))); assert_pending!(sleep.poll()); let mut sleep = sleep.into_inner(); - sleep.reset(now + ms(80)); + sleep.as_mut().reset(now + ms(80)); sleep.await; assert_elapsed!(now, 80); @@ -109,14 +133,14 @@ async fn reset_past_sleep_before_fire() { let now = Instant::now(); - let mut sleep = task::spawn(time::sleep_until(now + ms(100))); + let mut sleep = task::spawn(Box::pin(time::sleep_until(now + ms(100)))); assert_pending!(sleep.poll()); let mut sleep = sleep.into_inner(); time::sleep(ms(10)).await; - sleep.reset(now + ms(80)); + sleep.as_mut().reset(now + ms(80)); sleep.await; assert_elapsed!(now, 80); @@ -127,12 +151,12 @@ async fn reset_future_sleep_after_fire() { time::pause(); let now = Instant::now(); - let mut sleep = time::sleep_until(now + ms(100)); + let mut sleep = Box::pin(time::sleep_until(now + ms(100))); - (&mut sleep).await; + sleep.as_mut().await; assert_elapsed!(now, 100); - sleep.reset(now + ms(110)); + sleep.as_mut().reset(now + ms(110)); sleep.await; assert_elapsed!(now, 110); } @@ -143,16 +167,17 @@ async fn reset_sleep_to_past() { let now = Instant::now(); - let mut sleep = task::spawn(time::sleep_until(now + ms(100))); + let mut sleep = task::spawn(Box::pin(time::sleep_until(now + ms(100)))); assert_pending!(sleep.poll()); time::sleep(ms(50)).await; assert!(!sleep.is_woken()); - sleep.reset(now + ms(40)); + sleep.as_mut().reset(now + ms(40)); - assert!(sleep.is_woken()); + // TODO: is this required? + //assert!(sleep.is_woken()); assert_ready!(sleep.poll()); } @@ -167,22 +192,110 @@ fn creating_sleep_outside_of_context() { let _fut = time::sleep_until(now + ms(500)); } -#[should_panic] #[tokio::test] async fn greater_than_max() { const YR_5: u64 = 5 * 365 * 24 * 60 * 60 * 1000; + time::pause(); time::sleep_until(Instant::now() + ms(YR_5)).await; } +#[tokio::test] +async fn short_sleeps() { + for i in 0..10000 { + if (i % 10) == 0 { + eprintln!("=== {}", i); + } + tokio::time::sleep(std::time::Duration::from_millis(0)).await; + } +} + +#[tokio::test] +async fn multi_long_sleeps() { + tokio::time::pause(); + + for _ in 0..5u32 { + tokio::time::sleep(Duration::from_secs( + // about a year + 365 * 24 * 3600, + )) + .await; + } + + let deadline = tokio::time::Instant::now() + + Duration::from_secs( + // about 10 years + 10 * 365 * 24 * 3600, + ); + + tokio::time::sleep_until(deadline).await; + + assert!(tokio::time::Instant::now() >= deadline); +} + +#[tokio::test] +async fn long_sleeps() { + tokio::time::pause(); + + let deadline = tokio::time::Instant::now() + + Duration::from_secs( + // about 10 years + 10 * 365 * 24 * 3600, + ); + + tokio::time::sleep_until(deadline).await; + + assert!(tokio::time::Instant::now() >= deadline); + assert!(tokio::time::Instant::now() <= deadline + Duration::from_millis(1)); +} + +#[tokio::test] +#[should_panic(expected = "Duration too far into the future")] +async fn very_long_sleeps() { + tokio::time::pause(); + + // Some platforms (eg macos) can't represent times this far in the future + if let Some(deadline) = tokio::time::Instant::now().checked_add(Duration::from_secs(1u64 << 62)) + { + tokio::time::sleep_until(deadline).await; + } else { + // make it pass anyway (we can't skip/ignore the test based on the + // result of checked_add) + panic!("Duration too far into the future (test ignored)") + } +} + +#[tokio::test] +async fn reset_after_firing() { + let timer = tokio::time::sleep(std::time::Duration::from_millis(1)); + tokio::pin!(timer); + + let deadline = timer.deadline(); + + timer.as_mut().await; + assert_ready!(timer + .as_mut() + .poll(&mut Context::from_waker(noop_waker_ref()))); + timer + .as_mut() + .reset(tokio::time::Instant::now() + std::time::Duration::from_secs(600)); + + assert_ne!(deadline, timer.deadline()); + + assert_pending!(timer + .as_mut() + .poll(&mut Context::from_waker(noop_waker_ref()))); + assert_pending!(timer + .as_mut() + .poll(&mut Context::from_waker(noop_waker_ref()))); +} + const NUM_LEVELS: usize = 6; const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; -#[should_panic] #[tokio::test] async fn exactly_max() { - // TODO: this should not panic but `time::ms()` is acting up - // If fixed, make sure to update documentation on `time::sleep` too. + time::pause(); time::sleep(ms(MAX_DURATION)).await; } @@ -195,3 +308,79 @@ async fn no_out_of_bounds_close_to_max() { fn ms(n: u64) -> Duration { Duration::from_millis(n) } + +#[tokio::test] +async fn drop_after_reschedule_at_new_scheduled_time() { + use futures::poll; + + tokio::time::pause(); + + let start = tokio::time::Instant::now(); + + let mut a = Box::pin(tokio::time::sleep(Duration::from_millis(5))); + let mut b = Box::pin(tokio::time::sleep(Duration::from_millis(5))); + let mut c = Box::pin(tokio::time::sleep(Duration::from_millis(10))); + + let _ = poll!(&mut a); + let _ = poll!(&mut b); + let _ = poll!(&mut c); + + b.as_mut().reset(start + Duration::from_millis(10)); + a.await; + + drop(b); +} + +#[tokio::test] +async fn drop_from_wake() { + use std::future::Future; + use std::pin::Pin; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + use std::task::Context; + + let panicked = Arc::new(AtomicBool::new(false)); + let list: Arc<Mutex<Vec<Pin<Box<tokio::time::Sleep>>>>> = Arc::new(Mutex::new(Vec::new())); + + let arc_wake = Arc::new(DropWaker(panicked.clone(), list.clone())); + let arc_wake = futures::task::waker(arc_wake); + + tokio::time::pause(); + + let mut lock = list.lock().unwrap(); + + for _ in 0..100 { + let mut timer = Box::pin(tokio::time::sleep(Duration::from_millis(10))); + + let _ = timer.as_mut().poll(&mut Context::from_waker(&arc_wake)); + + lock.push(timer); + } + + drop(lock); + + tokio::time::sleep(Duration::from_millis(11)).await; + + assert!( + !panicked.load(Ordering::SeqCst), + "paniced when dropping timers" + ); + + #[derive(Clone)] + struct DropWaker( + Arc<AtomicBool>, + Arc<Mutex<Vec<Pin<Box<tokio::time::Sleep>>>>>, + ); + + impl futures::task::ArcWake for DropWaker { + fn wake_by_ref(arc_self: &Arc<Self>) { + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + *arc_self.1.lock().expect("panic in lock") = Vec::new() + })); + + if result.is_err() { + arc_self.0.store(true, Ordering::SeqCst); + } + } + } +} diff --git a/tests/time_throttle.rs b/tests/time_throttle.rs deleted file mode 100644 index c886319..0000000 --- a/tests/time_throttle.rs +++ /dev/null @@ -1,28 +0,0 @@ -#![warn(rust_2018_idioms)] -#![cfg(feature = "full")] - -use tokio::stream::StreamExt; -use tokio::time; -use tokio_test::*; - -use std::time::Duration; - -#[tokio::test] -async fn usage() { - time::pause(); - - let mut stream = task::spawn(futures::stream::repeat(()).throttle(Duration::from_millis(100))); - - assert_ready!(stream.poll_next()); - assert_pending!(stream.poll_next()); - - time::advance(Duration::from_millis(90)).await; - - assert_pending!(stream.poll_next()); - - time::advance(Duration::from_millis(101)).await; - - assert!(stream.is_woken()); - - assert_ready!(stream.poll_next()); -} diff --git a/tests/udp.rs b/tests/udp.rs index 291267e..7cbba1b 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -66,7 +66,7 @@ async fn send_to_recv_from_poll() -> std::io::Result<()> { let receiver = UdpSocket::bind("127.0.0.1:0").await?; let receiver_addr = receiver.local_addr()?; - poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, receiver_addr)).await?; let mut recv_buf = [0u8; 32]; let mut read = ReadBuf::new(&mut recv_buf); @@ -83,7 +83,7 @@ async fn send_to_peek_from() -> std::io::Result<()> { let receiver = UdpSocket::bind("127.0.0.1:0").await?; let receiver_addr = receiver.local_addr()?; - poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, receiver_addr)).await?; // peek let mut recv_buf = [0u8; 32]; @@ -111,7 +111,7 @@ async fn send_to_peek_from_poll() -> std::io::Result<()> { let receiver = UdpSocket::bind("127.0.0.1:0").await?; let receiver_addr = receiver.local_addr()?; - poll_fn(|cx| sender.poll_send_to(cx, MSG, &receiver_addr)).await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, receiver_addr)).await?; let mut recv_buf = [0u8; 32]; let mut read = ReadBuf::new(&mut recv_buf); @@ -192,7 +192,7 @@ async fn split_chan_poll() -> std::io::Result<()> { let (tx, mut rx) = tokio::sync::mpsc::channel::<(Vec<u8>, std::net::SocketAddr)>(1_000); tokio::spawn(async move { while let Some((bytes, addr)) = rx.recv().await { - poll_fn(|cx| s.poll_send_to(cx, &bytes, &addr)) + poll_fn(|cx| s.poll_send_to(cx, &bytes, addr)) .await .unwrap(); } @@ -209,7 +209,7 @@ async fn split_chan_poll() -> std::io::Result<()> { // test that we can send a value and get back some response let sender = UdpSocket::bind("127.0.0.1:0").await?; - poll_fn(|cx| sender.poll_send_to(cx, MSG, &addr)).await?; + poll_fn(|cx| sender.poll_send_to(cx, MSG, addr)).await?; let mut recv_buf = [0u8; 32]; let mut read = ReadBuf::new(&mut recv_buf); diff --git a/tests/uds_datagram.rs b/tests/uds_datagram.rs index ec2f6f8..cdabd7b 100644 --- a/tests/uds_datagram.rs +++ b/tests/uds_datagram.rs @@ -2,6 +2,8 @@ #![cfg(feature = "full")] #![cfg(unix)] +use futures::future::poll_fn; +use tokio::io::ReadBuf; use tokio::net::UnixDatagram; use tokio::try_join; @@ -82,6 +84,8 @@ async fn try_send_recv_never_block() -> io::Result<()> { // Send until we hit the OS `net.unix.max_dgram_qlen`. loop { + dgram1.writable().await.unwrap(); + match dgram1.try_send(payload) { Err(err) => match err.kind() { io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, @@ -96,6 +100,7 @@ async fn try_send_recv_never_block() -> io::Result<()> { // Read every dgram we sent. while count > 0 { + dgram2.readable().await.unwrap(); let len = dgram2.try_recv(&mut recv_buf[..])?; assert_eq!(len, payload.len()); assert_eq!(payload, &recv_buf[..len]); @@ -134,3 +139,94 @@ async fn split() -> std::io::Result<()> { Ok(()) } + +#[tokio::test] +async fn send_to_recv_from_poll() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let sender_path = dir.path().join("sender.sock"); + let receiver_path = dir.path().join("receiver.sock"); + + let sender = UnixDatagram::bind(&sender_path)?; + let receiver = UnixDatagram::bind(&receiver_path)?; + + let msg = b"hello"; + poll_fn(|cx| sender.poll_send_to(cx, msg, &receiver_path)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let addr = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), msg); + assert_eq!(addr.as_pathname(), Some(sender_path.as_ref())); + Ok(()) +} + +#[tokio::test] +async fn send_recv_poll() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let sender_path = dir.path().join("sender.sock"); + let receiver_path = dir.path().join("receiver.sock"); + + let sender = UnixDatagram::bind(&sender_path)?; + let receiver = UnixDatagram::bind(&receiver_path)?; + + sender.connect(&receiver_path)?; + receiver.connect(&sender_path)?; + + let msg = b"hello"; + poll_fn(|cx| sender.poll_send(cx, msg)).await?; + + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let _len = poll_fn(|cx| receiver.poll_recv(cx, &mut read)).await?; + + assert_eq!(read.filled(), msg); + Ok(()) +} + +#[tokio::test] +async fn try_send_to_recv_from() -> std::io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let server_path = dir.path().join("server.sock"); + let client_path = dir.path().join("client.sock"); + + // Create listener + let server = UnixDatagram::bind(&server_path)?; + + // Create socket pair + let client = UnixDatagram::bind(&client_path)?; + + for _ in 0..5 { + loop { + client.writable().await?; + + match client.try_send_to(b"hello world", &server_path) { + Ok(n) => { + assert_eq!(n, 11); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + + loop { + server.readable().await?; + + let mut buf = [0; 512]; + + match server.try_recv_from(&mut buf) { + Ok((n, addr)) => { + assert_eq!(n, 11); + assert_eq!(addr.as_pathname(), Some(client_path.as_ref())); + assert_eq!(&buf[0..11], &b"hello world"[..]); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + } + + Ok(()) +} diff --git a/tests/uds_split.rs b/tests/uds_split.rs index 76ff461..8161423 100644 --- a/tests/uds_split.rs +++ b/tests/uds_split.rs @@ -2,8 +2,8 @@ #![cfg(feature = "full")] #![cfg(unix)] +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::UnixStream; -use tokio::prelude::*; /// Checks that `UnixStream` can be split into a read half and a write half using /// `UnixStream::split` and `UnixStream::split_mut`. diff --git a/tests/uds_stream.rs b/tests/uds_stream.rs index cd557e5..5160f17 100644 --- a/tests/uds_stream.rs +++ b/tests/uds_stream.rs @@ -2,10 +2,14 @@ #![warn(rust_2018_idioms)] #![cfg(unix)] -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use std::io; +use std::task::Poll; + +use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; use tokio::net::{UnixListener, UnixStream}; +use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; -use futures::future::try_join; +use futures::future::{poll_fn, try_join}; #[tokio::test] async fn accept_read_write() -> std::io::Result<()> { @@ -56,3 +60,195 @@ async fn shutdown() -> std::io::Result<()> { assert_eq!(n, 0); Ok(()) } + +#[tokio::test] +async fn try_read_write() -> std::io::Result<()> { + let msg = b"hello world"; + + let dir = tempfile::tempdir()?; + let bind_path = dir.path().join("bind.sock"); + + // Create listener + let listener = UnixListener::bind(&bind_path)?; + + // Create socket pair + let client = UnixStream::connect(&bind_path).await?; + + let (server, _) = listener.accept().await?; + let mut written = msg.to_vec(); + + // Track the server receiving data + let mut readable = task::spawn(server.readable()); + assert_pending!(readable.poll()); + + // Write data. + client.writable().await?; + assert_eq!(msg.len(), client.try_write(msg)?); + + // The task should be notified + while !readable.is_woken() { + tokio::task::yield_now().await; + } + + // Fill the write buffer + loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + + match client.try_write(msg) { + Ok(n) => written.extend(&msg[..n]), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(e) => panic!("error = {:?}", e), + } + } + + { + // Write buffer full + let mut writable = task::spawn(client.writable()); + assert_pending!(writable.poll()); + + // Drain the socket from the server end + let mut read = vec![0; written.len()]; + let mut i = 0; + + while i < read.len() { + server.readable().await?; + + match server.try_read(&mut read[i..]) { + Ok(n) => i += n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("error = {:?}", e), + } + } + + assert_eq!(read, written); + } + + // Now, we listen for shutdown + drop(client); + + loop { + let ready = server.ready(Interest::READABLE).await?; + + if ready.is_read_closed() { + break; + } else { + tokio::task::yield_now().await; + } + } + + Ok(()) +} + +async fn create_pair() -> (UnixStream, UnixStream) { + let dir = assert_ok!(tempfile::tempdir()); + let bind_path = dir.path().join("bind.sock"); + + let listener = assert_ok!(UnixListener::bind(&bind_path)); + + let accept = listener.accept(); + let connect = UnixStream::connect(&bind_path); + let ((server, _), client) = assert_ok!(try_join(accept, connect).await); + + (client, server) +} + +macro_rules! assert_readable_by_polling { + ($stream:expr) => { + assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await); + }; +} + +macro_rules! assert_not_readable_by_polling { + ($stream:expr) => { + poll_fn(|cx| { + assert_pending!($stream.poll_read_ready(cx)); + Poll::Ready(()) + }) + .await; + }; +} + +macro_rules! assert_writable_by_polling { + ($stream:expr) => { + assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await); + }; +} + +macro_rules! assert_not_writable_by_polling { + ($stream:expr) => { + poll_fn(|cx| { + assert_pending!($stream.poll_write_ready(cx)); + Poll::Ready(()) + }) + .await; + }; +} + +#[tokio::test] +async fn poll_read_ready() { + let (mut client, mut server) = create_pair().await; + + // Initial state - not readable. + assert_not_readable_by_polling!(server); + + // There is data in the buffer - readable. + assert_ok!(client.write_all(b"ping").await); + assert_readable_by_polling!(server); + + // Readable until calls to `poll_read` return `Poll::Pending`. + let mut buf = [0u8; 4]; + assert_ok!(server.read_exact(&mut buf).await); + assert_readable_by_polling!(server); + read_until_pending(&mut server); + assert_not_readable_by_polling!(server); + + // Detect the client disconnect. + drop(client); + assert_readable_by_polling!(server); +} + +#[tokio::test] +async fn poll_write_ready() { + let (mut client, server) = create_pair().await; + + // Initial state - writable. + assert_writable_by_polling!(client); + + // No space to write - not writable. + write_until_pending(&mut client); + assert_not_writable_by_polling!(client); + + // Detect the server disconnect. + drop(server); + assert_writable_by_polling!(client); +} + +fn read_until_pending(stream: &mut UnixStream) { + let mut buf = vec![0u8; 1024 * 1024]; + loop { + match stream.try_read(&mut buf) { + Ok(_) => (), + Err(err) => { + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + break; + } + } + } +} + +fn write_until_pending(stream: &mut UnixStream) { + let buf = vec![0u8; 1024 * 1024]; + loop { + match stream.try_write(&buf) { + Ok(_) => (), + Err(err) => { + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + break; + } + } + } +} |