aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Vander Stoep <jeffv@google.com>2024-02-05 13:21:41 +0100
committerJeff Vander Stoep <jeffv@google.com>2024-02-05 13:21:42 +0100
commita7879fe08376d6c078eed44ffd47a14fb340af59 (patch)
tree3044f4002ed695e6540a4969306b4304349073ce
parent95367cb21fc8db087932e69f763904e32f80c6e6 (diff)
downloadtokio-util-emu-34-3-release.tar.gz
Upgrade tokio-util to 0.7.10emu-34-3-release
This project was upgraded with external_updater. Usage: tools/external_updater/updater.sh update external/rust/crates/tokio-util For more info, check https://cs.android.com/android/platform/superproject/+/main:tools/external_updater/README.md Test: TreeHugger Change-Id: Id389cb27cded67965187740c44c6710c94cd42ae
-rw-r--r--.cargo_vcs_info.json6
-rw-r--r--Android.bp4
-rw-r--r--CHANGELOG.md72
-rw-r--r--Cargo.toml19
-rw-r--r--Cargo.toml.orig13
-rw-r--r--METADATA25
-rw-r--r--src/codec/lines_codec.rs2
-rw-r--r--src/compat.rs6
-rw-r--r--src/either.rs2
-rw-r--r--src/io/copy_to_bytes.rs8
-rw-r--r--src/io/inspect.rs46
-rw-r--r--src/io/sink_writer.rs22
-rw-r--r--src/io/stream_reader.rs22
-rw-r--r--src/io/sync_bridge.rs17
-rw-r--r--src/lib.rs149
-rw-r--r--src/sync/cancellation_token.rs35
-rw-r--r--src/sync/cancellation_token/tree_node.rs80
-rw-r--r--src/sync/mpsc.rs59
-rw-r--r--src/sync/poll_semaphore.rs2
-rw-r--r--src/sync/reusable_box.rs18
-rw-r--r--src/task/join_map.rs97
-rw-r--r--src/task/mod.rs5
-rw-r--r--src/task/task_tracker.rs719
-rw-r--r--src/time/delay_queue.rs41
-rw-r--r--src/time/wheel/level.rs30
-rw-r--r--src/time/wheel/mod.rs17
-rw-r--r--src/time/wheel/stack.rs3
-rw-r--r--src/util/maybe_dangling.rs67
-rw-r--r--src/util/mod.rs8
-rw-r--r--src/util/poll_buf.rs145
-rw-r--r--tests/compat.rs43
-rw-r--r--tests/io_sync_bridge.rs12
-rw-r--r--tests/length_delimited.rs91
-rw-r--r--tests/mpsc.rs23
-rw-r--r--tests/task_join_map.rs24
-rw-r--r--tests/task_tracker.rs178
-rw-r--r--tests/time_delay_queue.rs63
-rw-r--r--tests/udp.rs7
38 files changed, 1862 insertions, 318 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
new file mode 100644
index 0000000..a9990be
--- /dev/null
+++ b/.cargo_vcs_info.json
@@ -0,0 +1,6 @@
+{
+ "git": {
+ "sha1": "503fad79087ed5791c7a018e07621689ea5e4676"
+ },
+ "path_in_vcs": "tokio-util"
+} \ No newline at end of file
diff --git a/Android.bp b/Android.bp
index ef96022..4d19f54 100644
--- a/Android.bp
+++ b/Android.bp
@@ -23,9 +23,9 @@ rust_library {
host_supported: true,
crate_name: "tokio_util",
cargo_env_compat: true,
- cargo_pkg_version: "0.7.7",
+ cargo_pkg_version: "0.7.10",
srcs: ["src/lib.rs"],
- edition: "2018",
+ edition: "2021",
features: [
"codec",
"compat",
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0c11b21..b98092c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,75 @@
+# 0.7.10 (October 24th, 2023)
+
+### Added
+
+- task: add `TaskTracker` ([#6033])
+- task: add `JoinMap::keys` ([#6046])
+- io: implement `Seek` for `SyncIoBridge` ([#6058])
+
+### Changed
+
+- deps: update hashbrown to 0.14 ([#6102])
+
+[#6033]: https://github.com/tokio-rs/tokio/pull/6033
+[#6046]: https://github.com/tokio-rs/tokio/pull/6046
+[#6058]: https://github.com/tokio-rs/tokio/pull/6058
+[#6102]: https://github.com/tokio-rs/tokio/pull/6102
+
+# 0.7.9 (September 20th, 2023)
+
+### Added
+
+- io: add passthrough `AsyncRead`/`AsyncWrite` to `InspectWriter`/`InspectReader` ([#5739])
+- task: add spawn blocking methods to `JoinMap` ([#5797])
+- io: pass through traits for `StreamReader` and `SinkWriter` ([#5941])
+- io: add `SyncIoBridge::into_inner` ([#5971])
+
+### Fixed
+
+- sync: handle possibly dangling reference safely ([#5812])
+- util: fix broken intra-doc link ([#5849])
+- compat: fix clippy warnings ([#5891])
+
+### Documented
+
+- codec: Specify the line ending of `LinesCodec` ([#5982])
+
+[#5739]: https://github.com/tokio-rs/tokio/pull/5739
+[#5797]: https://github.com/tokio-rs/tokio/pull/5797
+[#5941]: https://github.com/tokio-rs/tokio/pull/5941
+[#5971]: https://github.com/tokio-rs/tokio/pull/5971
+[#5812]: https://github.com/tokio-rs/tokio/pull/5812
+[#5849]: https://github.com/tokio-rs/tokio/pull/5849
+[#5891]: https://github.com/tokio-rs/tokio/pull/5891
+[#5982]: https://github.com/tokio-rs/tokio/pull/5982
+
+# 0.7.8 (April 25th, 2023)
+
+This release bumps the MSRV of tokio-util to 1.56.
+
+### Added
+
+- time: add `DelayQueue::peek` ([#5569])
+
+### Changed
+
+This release contains one performance improvement:
+
+- sync: try to lock the parent first in `CancellationToken` ([#5561])
+
+### Fixed
+
+- time: fix panic in `DelayQueue` ([#5630])
+
+### Documented
+
+- sync: improve `CancellationToken` doc on child tokens ([#5632])
+
+[#5561]: https://github.com/tokio-rs/tokio/pull/5561
+[#5569]: https://github.com/tokio-rs/tokio/pull/5569
+[#5630]: https://github.com/tokio-rs/tokio/pull/5630
+[#5632]: https://github.com/tokio-rs/tokio/pull/5632
+
# 0.7.7 (February 12, 2023)
This release reverts the removal of the `Encoder` bound on the `FramedParts`
diff --git a/Cargo.toml b/Cargo.toml
index 221f2e7..51e9a96 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,10 +10,10 @@
# See Cargo.toml.orig for the original contents.
[package]
-edition = "2018"
-rust-version = "1.49"
+edition = "2021"
+rust-version = "1.56"
name = "tokio-util"
-version = "0.7.7"
+version = "0.7.10"
authors = ["Tokio Contributors <team@tokio.rs>"]
description = """
Additional utilities for working with Tokio.
@@ -26,13 +26,13 @@ repository = "https://github.com/tokio-rs/tokio"
[package.metadata.docs.rs]
all-features = true
-rustdoc-args = [
+rustc-args = [
"--cfg",
"docsrs",
"--cfg",
"tokio_unstable",
]
-rustc-args = [
+rustdoc-args = [
"--cfg",
"docsrs",
"--cfg",
@@ -57,14 +57,14 @@ version = "0.3.0"
optional = true
[dependencies.pin-project-lite]
-version = "0.2.0"
+version = "0.2.11"
[dependencies.slab]
version = "0.4.4"
optional = true
[dependencies.tokio]
-version = "1.22.0"
+version = "1.28.0"
features = ["sync"]
[dependencies.tracing]
@@ -85,6 +85,9 @@ version = "0.3.5"
[dev-dependencies.parking_lot]
version = "0.12.0"
+[dev-dependencies.tempfile]
+version = "3.1.0"
+
[dev-dependencies.tokio]
version = "1.0.0"
features = ["full"]
@@ -127,5 +130,5 @@ time = [
]
[target."cfg(tokio_unstable)".dependencies.hashbrown]
-version = "0.12.0"
+version = "0.14.0"
optional = true
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 267662b..437dc5a 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -4,9 +4,9 @@ name = "tokio-util"
# - Remove path dependencies
# - Update CHANGELOG.md.
# - Create "tokio-util-0.7.x" git tag.
-version = "0.7.7"
-edition = "2018"
-rust-version = "1.49"
+version = "0.7.10"
+edition = "2021"
+rust-version = "1.56"
authors = ["Tokio Contributors <team@tokio.rs>"]
license = "MIT"
repository = "https://github.com/tokio-rs/tokio"
@@ -34,18 +34,18 @@ rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"]
__docs_rs = ["futures-util"]
[dependencies]
-tokio = { version = "1.22.0", path = "../tokio", features = ["sync"] }
+tokio = { version = "1.28.0", path = "../tokio", features = ["sync"] }
bytes = "1.0.0"
futures-core = "0.3.0"
futures-sink = "0.3.0"
futures-io = { version = "0.3.0", optional = true }
futures-util = { version = "0.3.0", optional = true }
-pin-project-lite = "0.2.0"
+pin-project-lite = "0.2.11"
slab = { version = "0.4.4", optional = true } # Backs `DelayQueue`
tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true }
[target.'cfg(tokio_unstable)'.dependencies]
-hashbrown = { version = "0.12.0", optional = true }
+hashbrown = { version = "0.14.0", optional = true }
[dev-dependencies]
tokio = { version = "1.0.0", path = "../tokio", features = ["full"] }
@@ -56,6 +56,7 @@ async-stream = "0.3.0"
futures = "0.3.0"
futures-test = "0.3.5"
parking_lot = "0.12.0"
+tempfile = "3.1.0"
[package.metadata.docs.rs]
all-features = true
diff --git a/METADATA b/METADATA
index 28dbb73..ffaab87 100644
--- a/METADATA
+++ b/METADATA
@@ -1,19 +1,20 @@
+# This project was upgraded with external_updater.
+# Usage: tools/external_updater/updater.sh update external/rust/crates/tokio-util
+# For more info, check https://cs.android.com/android/platform/superproject/+/main:tools/external_updater/README.md
+
name: "tokio-util"
description: "Utilities for working with Tokio."
third_party {
- url {
- type: HOMEPAGE
- value: "https://crates.io/crates/tokio-util"
- }
- url {
- type: ARCHIVE
- value: "https://static.crates.io/crates/tokio-util/tokio-util-0.7.7.crate"
- }
- version: "0.7.7"
license_type: NOTICE
last_upgrade_date {
- year: 2023
- month: 3
- day: 3
+ year: 2024
+ month: 2
+ day: 5
+ }
+ homepage: "https://crates.io/crates/tokio-util"
+ identifier {
+ type: "Archive"
+ value: "https://static.crates.io/crates/tokio-util/tokio-util-0.7.10.crate"
+ version: "0.7.10"
}
}
diff --git a/src/codec/lines_codec.rs b/src/codec/lines_codec.rs
index 7a0a8f0..5a6035d 100644
--- a/src/codec/lines_codec.rs
+++ b/src/codec/lines_codec.rs
@@ -6,6 +6,8 @@ use std::{cmp, fmt, io, str, usize};
/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines.
///
+/// This uses the `\n` character as the line ending on all platforms.
+///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
diff --git a/src/compat.rs b/src/compat.rs
index 6a8802d..423bd95 100644
--- a/src/compat.rs
+++ b/src/compat.rs
@@ -227,12 +227,14 @@ impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
pos: io::SeekFrom,
) -> Poll<io::Result<u64>> {
if self.seek_pos != Some(pos) {
+ // Ensure previous seeks have finished before starting a new one
+ ready!(self.as_mut().project().inner.poll_complete(cx))?;
self.as_mut().project().inner.start_seek(pos)?;
*self.as_mut().project().seek_pos = Some(pos);
}
let res = ready!(self.as_mut().project().inner.poll_complete(cx));
*self.as_mut().project().seek_pos = None;
- Poll::Ready(res.map(|p| p as u64))
+ Poll::Ready(res)
}
}
@@ -255,7 +257,7 @@ impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
};
let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
*self.as_mut().project().seek_pos = None;
- Poll::Ready(res.map(|p| p as u64))
+ Poll::Ready(res)
}
}
diff --git a/src/either.rs b/src/either.rs
index 9225e53..8a02398 100644
--- a/src/either.rs
+++ b/src/either.rs
@@ -116,7 +116,7 @@ where
}
fn consume(self: Pin<&mut Self>, amt: usize) {
- delegate_call!(self.consume(amt))
+ delegate_call!(self.consume(amt));
}
}
diff --git a/src/io/copy_to_bytes.rs b/src/io/copy_to_bytes.rs
index 9509e71..f0b5c35 100644
--- a/src/io/copy_to_bytes.rs
+++ b/src/io/copy_to_bytes.rs
@@ -1,4 +1,5 @@
use bytes::Bytes;
+use futures_core::stream::Stream;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::pin::Pin;
@@ -66,3 +67,10 @@ where
self.project().inner.poll_close(cx)
}
}
+
+impl<S: Stream> Stream for CopyToBytes<S> {
+ type Item = S::Item;
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.poll_next(cx)
+ }
+}
diff --git a/src/io/inspect.rs b/src/io/inspect.rs
index ec5bb97..c860b80 100644
--- a/src/io/inspect.rs
+++ b/src/io/inspect.rs
@@ -52,6 +52,42 @@ impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
}
}
+impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::result::Result<usize, std::io::Error>> {
+ self.project().reader.poll_write(cx, buf)
+ }
+
+ fn poll_flush(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<std::result::Result<(), std::io::Error>> {
+ self.project().reader.poll_flush(cx)
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<std::result::Result<(), std::io::Error>> {
+ self.project().reader.poll_shutdown(cx)
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[IoSlice<'_>],
+ ) -> Poll<Result<usize>> {
+ self.project().reader.poll_write_vectored(cx, bufs)
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ self.reader.is_write_vectored()
+ }
+}
+
pin_project! {
/// An adapter that lets you inspect the data that's being written.
///
@@ -132,3 +168,13 @@ impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
self.writer.is_write_vectored()
}
}
+
+impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ self.project().writer.poll_read(cx, buf)
+ }
+}
diff --git a/src/io/sink_writer.rs b/src/io/sink_writer.rs
index f2af262..e078952 100644
--- a/src/io/sink_writer.rs
+++ b/src/io/sink_writer.rs
@@ -1,11 +1,12 @@
use futures_core::ready;
use futures_sink::Sink;
+use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
-use tokio::io::AsyncWrite;
+use tokio::io::{AsyncRead, AsyncWrite};
pin_project! {
/// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`].
@@ -59,7 +60,7 @@ pin_project! {
/// [`CopyToBytes`]: crate::io::CopyToBytes
/// [`Encoder`]: crate::codec::Encoder
/// [`Sink`]: futures_sink::Sink
- /// [`codec`]: tokio_util::codec
+ /// [`codec`]: crate::codec
#[derive(Debug)]
pub struct SinkWriter<S> {
#[pin]
@@ -115,3 +116,20 @@ where
self.project().inner.poll_close(cx).map_err(Into::into)
}
}
+
+impl<S: Stream> Stream for SinkWriter<S> {
+ type Item = S::Item;
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.poll_next(cx)
+ }
+}
+
+impl<S: AsyncRead> AsyncRead for SinkWriter<S> {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ self.project().inner.poll_read(cx, buf)
+ }
+}
diff --git a/src/io/stream_reader.rs b/src/io/stream_reader.rs
index 3353722..6ecf8ec 100644
--- a/src/io/stream_reader.rs
+++ b/src/io/stream_reader.rs
@@ -1,5 +1,6 @@
use bytes::Buf;
use futures_core::stream::Stream;
+use futures_sink::Sink;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -165,7 +166,7 @@ where
B: Buf,
E: Into<std::io::Error>,
{
- /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead).
+ /// Convert a stream of byte chunks into an [`AsyncRead`].
///
/// The item should be a [`Result`] with the ok variant being something that
/// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
@@ -324,3 +325,22 @@ impl<S, B> StreamReader<S, B> {
}
}
}
+
+impl<S: Sink<T, Error = E>, E, T> Sink<T> for StreamReader<S, E> {
+ type Error = E;
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_close(cx)
+ }
+}
diff --git a/src/io/sync_bridge.rs b/src/io/sync_bridge.rs
index f87bfbb..2402207 100644
--- a/src/io/sync_bridge.rs
+++ b/src/io/sync_bridge.rs
@@ -1,6 +1,7 @@
-use std::io::{BufRead, Read, Write};
+use std::io::{BufRead, Read, Seek, Write};
use tokio::io::{
- AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
+ AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite,
+ AsyncWriteExt,
};
/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
@@ -79,6 +80,13 @@ impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> {
}
}
+impl<T: AsyncSeek + Unpin> Seek for SyncIoBridge<T> {
+ fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
+ let src = &mut self.src;
+ self.rt.block_on(AsyncSeekExt::seek(src, pos))
+ }
+}
+
// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time
// of this writing still unstable, we expose this as part of a standalone method.
impl<T: AsyncWrite> SyncIoBridge<T> {
@@ -140,4 +148,9 @@ impl<T: Unpin> SyncIoBridge<T> {
pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self {
Self { src, rt }
}
+
+ /// Consume this bridge, returning the underlying stream.
+ pub fn into_inner(self) -> T {
+ self.src
+ }
}
diff --git a/src/lib.rs b/src/lib.rs
index 524fc47..22ad92b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -55,151 +55,6 @@ pub mod sync;
pub mod either;
-#[cfg(any(feature = "io", feature = "codec"))]
-mod util {
- use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+pub use bytes;
- use bytes::{Buf, BufMut};
- use futures_core::ready;
- use std::io::{self, IoSlice};
- use std::mem::MaybeUninit;
- use std::pin::Pin;
- use std::task::{Context, Poll};
-
- /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
- ///
- /// [`BufMut`]: bytes::Buf
- ///
- /// # Example
- ///
- /// ```
- /// use bytes::{Bytes, BytesMut};
- /// use tokio_stream as stream;
- /// use tokio::io::Result;
- /// use tokio_util::io::{StreamReader, poll_read_buf};
- /// use futures::future::poll_fn;
- /// use std::pin::Pin;
- /// # #[tokio::main]
- /// # async fn main() -> std::io::Result<()> {
- ///
- /// // Create a reader from an iterator. This particular reader will always be
- /// // ready.
- /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
- ///
- /// let mut buf = BytesMut::new();
- /// let mut reads = 0;
- ///
- /// loop {
- /// reads += 1;
- /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
- ///
- /// if n == 0 {
- /// break;
- /// }
- /// }
- ///
- /// // one or more reads might be necessary.
- /// assert!(reads >= 1);
- /// assert_eq!(&buf[..], &[0, 1, 2, 3]);
- /// # Ok(())
- /// # }
- /// ```
- #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
- pub fn poll_read_buf<T: AsyncRead, B: BufMut>(
- io: Pin<&mut T>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- if !buf.has_remaining_mut() {
- return Poll::Ready(Ok(0));
- }
-
- let n = {
- let dst = buf.chunk_mut();
-
- // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
- // transparent wrapper around `[MaybeUninit<u8>]`.
- let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
- let mut buf = ReadBuf::uninit(dst);
- let ptr = buf.filled().as_ptr();
- ready!(io.poll_read(cx, &mut buf)?);
-
- // Ensure the pointer does not change from under us
- assert_eq!(ptr, buf.filled().as_ptr());
- buf.filled().len()
- };
-
- // Safety: This is guaranteed to be the number of initialized (and read)
- // bytes due to the invariants provided by `ReadBuf::filled`.
- unsafe {
- buf.advance_mut(n);
- }
-
- Poll::Ready(Ok(n))
- }
-
- /// Try to write data from an implementer of the [`Buf`] trait to an
- /// [`AsyncWrite`], advancing the buffer's internal cursor.
- ///
- /// This function will use [vectored writes] when the [`AsyncWrite`] supports
- /// vectored writes.
- ///
- /// # Examples
- ///
- /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
- /// [`Buf`]:
- ///
- /// ```no_run
- /// use tokio_util::io::poll_write_buf;
- /// use tokio::io;
- /// use tokio::fs::File;
- ///
- /// use bytes::Buf;
- /// use std::io::Cursor;
- /// use std::pin::Pin;
- /// use futures::future::poll_fn;
- ///
- /// #[tokio::main]
- /// async fn main() -> io::Result<()> {
- /// let mut file = File::create("foo.txt").await?;
- /// let mut buf = Cursor::new(b"data to write");
- ///
- /// // Loop until the entire contents of the buffer are written to
- /// // the file.
- /// while buf.has_remaining() {
- /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
- /// }
- ///
- /// Ok(())
- /// }
- /// ```
- ///
- /// [`Buf`]: bytes::Buf
- /// [`AsyncWrite`]: tokio::io::AsyncWrite
- /// [`File`]: tokio::fs::File
- /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
- #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
- pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
- io: Pin<&mut T>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- const MAX_BUFS: usize = 64;
-
- if !buf.has_remaining() {
- return Poll::Ready(Ok(0));
- }
-
- let n = if io.is_write_vectored() {
- let mut slices = [IoSlice::new(&[]); MAX_BUFS];
- let cnt = buf.chunks_vectored(&mut slices);
- ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
- } else {
- ready!(io.poll_write(cx, buf.chunk()))?
- };
-
- buf.advance(n);
-
- Poll::Ready(Ok(n))
- }
-}
+mod util;
diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs
index c44be69..5ef8ba2 100644
--- a/src/sync/cancellation_token.rs
+++ b/src/sync/cancellation_token.rs
@@ -4,6 +4,7 @@ pub(crate) mod guard;
mod tree_node;
use crate::loom::sync::Arc;
+use crate::util::MaybeDangling;
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
@@ -77,11 +78,23 @@ pin_project! {
/// [`CancellationToken`] by value instead of using a reference.
#[must_use = "futures do nothing unless polled"]
pub struct WaitForCancellationFutureOwned {
- // Since `future` is the first field, it is dropped before the
- // cancellation_token field. This ensures that the reference inside the
- // `Notified` remains valid.
+ // This field internally has a reference to the cancellation token, but camouflages
+ // the relationship with `'static`. To avoid Undefined Behavior, we must ensure
+ // that the reference is only used while the cancellation token is still alive. To
+ // do that, we ensure that the future is the first field, so that it is dropped
+ // before the cancellation token.
+ //
+ // We use `MaybeDanglingFuture` here because without it, the compiler could assert
+ // the reference inside `future` to be valid even after the destructor of that
+ // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed
+ // as an argument to a function, the reference can be asserted to be valid for the
+ // rest of that function.) To avoid that, we use `MaybeDangling` which tells the
+ // compiler that the reference stored inside it might not be valid.
+ //
+ // See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
+ // for more info.
#[pin]
- future: tokio::sync::futures::Notified<'static>,
+ future: MaybeDangling<tokio::sync::futures::Notified<'static>>,
cancellation_token: CancellationToken,
}
}
@@ -97,6 +110,8 @@ impl core::fmt::Debug for CancellationToken {
}
impl Clone for CancellationToken {
+ /// Creates a clone of the `CancellationToken` which will get cancelled
+ /// whenever the current token gets cancelled, and vice versa.
fn clone(&self) -> Self {
tree_node::increase_handle_refcount(&self.inner);
CancellationToken {
@@ -118,7 +133,7 @@ impl Default for CancellationToken {
}
impl CancellationToken {
- /// Creates a new CancellationToken in the non-cancelled state.
+ /// Creates a new `CancellationToken` in the non-cancelled state.
pub fn new() -> CancellationToken {
CancellationToken {
inner: Arc::new(tree_node::TreeNode::new()),
@@ -126,7 +141,8 @@ impl CancellationToken {
}
/// Creates a `CancellationToken` which will get cancelled whenever the
- /// current token gets cancelled.
+ /// current token gets cancelled. Unlike a cloned `CancellationToken`,
+ /// cancelling a child token does not cancel the parent token.
///
/// If the current token is already cancelled, the child token will get
/// returned in cancelled state.
@@ -276,7 +292,7 @@ impl WaitForCancellationFutureOwned {
// # Safety
//
// cancellation_token is dropped after future due to the field ordering.
- future: unsafe { Self::new_future(&cancellation_token) },
+ future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }),
cancellation_token,
}
}
@@ -317,8 +333,9 @@ impl Future for WaitForCancellationFutureOwned {
// # Safety
//
// cancellation_token is dropped after future due to the field ordering.
- this.future
- .set(unsafe { Self::new_future(this.cancellation_token) });
+ this.future.set(MaybeDangling::new(unsafe {
+ Self::new_future(this.cancellation_token)
+ }));
}
}
}
diff --git a/src/sync/cancellation_token/tree_node.rs b/src/sync/cancellation_token/tree_node.rs
index 8f97dee..b7a9805 100644
--- a/src/sync/cancellation_token/tree_node.rs
+++ b/src/sync/cancellation_token/tree_node.rs
@@ -1,12 +1,12 @@
//! This mod provides the logic for the inner tree structure of the CancellationToken.
//!
-//! CancellationTokens are only light handles with references to TreeNode.
-//! All the logic is actually implemented in the TreeNode.
+//! CancellationTokens are only light handles with references to [`TreeNode`].
+//! All the logic is actually implemented in the [`TreeNode`].
//!
-//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of
+//! A [`TreeNode`] is part of the cancellation tree and may have one parent and an arbitrary number of
//! children.
//!
-//! A TreeNode can receive the request to perform a cancellation through a CancellationToken.
+//! A [`TreeNode`] can receive the request to perform a cancellation through a CancellationToken.
//! This cancellation request will cancel the node and all of its descendants.
//!
//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no
@@ -151,47 +151,43 @@ fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret
where
F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret,
{
- let mut potential_parent = {
- let locked_node = node.inner.lock().unwrap();
- match locked_node.parent.clone() {
- Some(parent) => parent,
- // If we locked the node and its parent is `None`, we are in a valid state
- // and can return.
- None => return func(locked_node, None),
- }
- };
+ use std::sync::TryLockError;
+ let mut locked_node = node.inner.lock().unwrap();
+
+ // Every time this fails, the number of ancestors of the node decreases,
+ // so the loop must succeed after a finite number of iterations.
loop {
- // Deadlock safety:
- //
- // Due to invariant #2, we know that we have to lock the parent first, and then the child.
- // This is true even if the potential_parent is no longer the current parent or even its
- // sibling, as the invariant still holds.
- let locked_parent = potential_parent.inner.lock().unwrap();
- let locked_node = node.inner.lock().unwrap();
-
- let actual_parent = match locked_node.parent.clone() {
- Some(parent) => parent,
- // If we locked the node and its parent is `None`, we are in a valid state
- // and can return.
- None => {
- // Was the wrong parent, so unlock it before calling `func`
- drop(locked_parent);
- return func(locked_node, None);
+ // Look up the parent of the currently locked node.
+ let potential_parent = match locked_node.parent.as_ref() {
+ Some(potential_parent) => potential_parent.clone(),
+ None => return func(locked_node, None),
+ };
+
+ // Lock the parent. This may require unlocking the child first.
+ let locked_parent = match potential_parent.inner.try_lock() {
+ Ok(locked_parent) => locked_parent,
+ Err(TryLockError::WouldBlock) => {
+ drop(locked_node);
+ // Deadlock safety:
+ //
+ // Due to invariant #2, the potential parent must come before
+ // the child in the creation order. Therefore, we can safely
+ // lock the child while holding the parent lock.
+ let locked_parent = potential_parent.inner.lock().unwrap();
+ locked_node = node.inner.lock().unwrap();
+ locked_parent
}
+ Err(TryLockError::Poisoned(err)) => Err(err).unwrap(),
};
- // Loop until we managed to lock both the node and its parent
- if Arc::ptr_eq(&actual_parent, &potential_parent) {
- return func(locked_node, Some(locked_parent));
+ // If we unlocked the child, then the parent may have changed. Check
+ // that we still have the right parent.
+ if let Some(actual_parent) = locked_node.parent.as_ref() {
+ if Arc::ptr_eq(actual_parent, &potential_parent) {
+ return func(locked_node, Some(locked_parent));
+ }
}
-
- // Drop locked_parent before reassigning to potential_parent,
- // as potential_parent is borrowed in it
- drop(locked_node);
- drop(locked_parent);
-
- potential_parent = actual_parent;
}
}
@@ -243,11 +239,7 @@ fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) {
let len = parent.children.len();
if 4 * len <= parent.children.capacity() {
- // equal to:
- // parent.children.shrink_to(2 * len);
- // but shrink_to was not yet stabilized in our minimal compatible version
- let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len));
- parent.children.extend(old_children);
+ parent.children.shrink_to(2 * len);
}
}
diff --git a/src/sync/mpsc.rs b/src/sync/mpsc.rs
index 55ed5c4..fd48c72 100644
--- a/src/sync/mpsc.rs
+++ b/src/sync/mpsc.rs
@@ -44,7 +44,7 @@ enum State<T> {
pub struct PollSender<T> {
sender: Option<Sender<T>>,
state: State<T>,
- acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>,
+ acquire: PollSenderFuture<T>,
}
// Creates a future for acquiring a permit from the underlying channel. This is used to ensure
@@ -64,13 +64,56 @@ async fn make_acquire_future<T>(
}
}
-impl<T: Send + 'static> PollSender<T> {
+type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>;
+
+#[derive(Debug)]
+// TODO: This should be replace with a type_alias_impl_trait to eliminate `'static` and all the transmutes
+struct PollSenderFuture<T>(InnerFuture<'static, T>);
+
+impl<T> PollSenderFuture<T> {
+ /// Create with an empty inner future with no `Send` bound.
+ fn empty() -> Self {
+ // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
+ // compatible with the transitive bounds required by `Sender<T>`.
+ Self(ReusableBoxFuture::new(async { unreachable!() }))
+ }
+}
+
+impl<T: Send> PollSenderFuture<T> {
+ /// Create with an empty inner future.
+ fn new() -> Self {
+ let v = InnerFuture::new(make_acquire_future(None));
+ // This is safe because `make_acquire_future(None)` is actually `'static`
+ Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
+ }
+
+ /// Poll the inner future.
+ fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
+ self.0.poll(cx)
+ }
+
+ /// Replace the inner future.
+ fn set(&mut self, sender: Option<Sender<T>>) {
+ let inner: *mut InnerFuture<'static, T> = &mut self.0;
+ let inner: *mut InnerFuture<'_, T> = inner.cast();
+ // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T`
+ // becomes invalid, and this casts away the type-level lifetime check for that. However, the
+ // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not
+ // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed
+ // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so
+ // this is ok.
+ let inner = unsafe { &mut *inner };
+ inner.set(make_acquire_future(sender));
+ }
+}
+
+impl<T: Send> PollSender<T> {
/// Creates a new `PollSender`.
pub fn new(sender: Sender<T>) -> Self {
Self {
sender: Some(sender.clone()),
state: State::Idle(sender),
- acquire: ReusableBoxFuture::new(make_acquire_future(None)),
+ acquire: PollSenderFuture::new(),
}
}
@@ -97,7 +140,7 @@ impl<T: Send + 'static> PollSender<T> {
State::Idle(sender) => {
// Start trying to acquire a permit to reserve a slot for our send, and
// immediately loop back around to poll it the first time.
- self.acquire.set(make_acquire_future(Some(sender)));
+ self.acquire.set(Some(sender));
(None, State::Acquiring)
}
State::Acquiring => match self.acquire.poll(cx) {
@@ -194,7 +237,7 @@ impl<T: Send + 'static> PollSender<T> {
match self.state {
State::Idle(_) => self.state = State::Closed,
State::Acquiring => {
- self.acquire.set(make_acquire_future(None));
+ self.acquire.set(None);
self.state = State::Closed;
}
_ => {}
@@ -215,7 +258,7 @@ impl<T: Send + 'static> PollSender<T> {
// We're currently trying to reserve a slot to send into.
State::Acquiring => {
// Replacing the future drops the in-flight one.
- self.acquire.set(make_acquire_future(None));
+ self.acquire.set(None);
// If we haven't closed yet, we have to clone our stored sender since we have no way
// to get it back from the acquire future we just dropped.
@@ -255,9 +298,7 @@ impl<T> Clone for PollSender<T> {
Self {
sender,
state,
- // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
- // compatible with the transitive bounds required by `Sender<T>`.
- acquire: ReusableBoxFuture::new(async { unreachable!() }),
+ acquire: PollSenderFuture::empty(),
}
}
}
diff --git a/src/sync/poll_semaphore.rs b/src/sync/poll_semaphore.rs
index 6b44574..4960a7c 100644
--- a/src/sync/poll_semaphore.rs
+++ b/src/sync/poll_semaphore.rs
@@ -29,7 +29,7 @@ impl PollSemaphore {
/// Closes the semaphore.
pub fn close(&self) {
- self.semaphore.close()
+ self.semaphore.close();
}
/// Obtain a clone of the inner semaphore.
diff --git a/src/sync/reusable_box.rs b/src/sync/reusable_box.rs
index 1b8ef60..1fae38c 100644
--- a/src/sync/reusable_box.rs
+++ b/src/sync/reusable_box.rs
@@ -1,7 +1,6 @@
use std::alloc::Layout;
use std::fmt;
-use std::future::Future;
-use std::marker::PhantomData;
+use std::future::{self, Future};
use std::mem::{self, ManuallyDrop};
use std::pin::Pin;
use std::ptr;
@@ -61,7 +60,7 @@ impl<'a, T> ReusableBoxFuture<'a, T> {
F: Future + Send + 'a,
{
// future::Pending<T> is a ZST so this never allocates.
- let boxed = mem::replace(&mut this.boxed, Box::pin(Pending(PhantomData)));
+ let boxed = mem::replace(&mut this.boxed, Box::pin(future::pending()));
reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed))
}
@@ -156,16 +155,3 @@ impl<O, F: FnOnce() -> O> Drop for CallOnDrop<O, F> {
f();
}
}
-
-/// The same as `std::future::Pending<T>`; we can't use that type directly because on rustc
-/// versions <1.60 it didn't unconditionally implement `Send`.
-// FIXME: use `std::future::Pending<T>` once the MSRV is >=1.60
-struct Pending<T>(PhantomData<fn() -> T>);
-
-impl<T> Future for Pending<T> {
- type Output = T;
-
- fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
- Poll::Pending
- }
-}
diff --git a/src/task/join_map.rs b/src/task/join_map.rs
index c6bf5bc..1fbe274 100644
--- a/src/task/join_map.rs
+++ b/src/task/join_map.rs
@@ -5,6 +5,7 @@ use std::collections::hash_map::RandomState;
use std::fmt;
use std::future::Future;
use std::hash::{BuildHasher, Hash, Hasher};
+use std::marker::PhantomData;
use tokio::runtime::Handle;
use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
@@ -316,6 +317,60 @@ where
self.insert(key, task);
}
+ /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
+ /// key.
+ ///
+ /// If a task previously existed in the `JoinMap` for this key, that task
+ /// will be cancelled and replaced with the new one. The previous task will
+ /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
+ /// *not* return a cancelled [`JoinError`] for that task.
+ ///
+ /// Note that blocking tasks cannot be cancelled after execution starts.
+ /// Replaced blocking tasks will still run to completion if the task has begun
+ /// to execute when it is replaced. A blocking task which is replaced before
+ /// it has been scheduled on a blocking worker thread will be cancelled.
+ ///
+ /// # Panics
+ ///
+ /// This method panics if called outside of a Tokio runtime.
+ ///
+ /// [`join_next`]: Self::join_next
+ #[track_caller]
+ pub fn spawn_blocking<F>(&mut self, key: K, f: F)
+ where
+ F: FnOnce() -> V,
+ F: Send + 'static,
+ V: Send,
+ {
+ let task = self.tasks.spawn_blocking(f);
+ self.insert(key, task)
+ }
+
+ /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
+ /// `JoinMap` with the provided key.
+ ///
+ /// If a task previously existed in the `JoinMap` for this key, that task
+ /// will be cancelled and replaced with the new one. The previous task will
+ /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
+ /// *not* return a cancelled [`JoinError`] for that task.
+ ///
+ /// Note that blocking tasks cannot be cancelled after execution starts.
+ /// Replaced blocking tasks will still run to completion if the task has begun
+ /// to execute when it is replaced. A blocking task which is replaced before
+ /// it has been scheduled on a blocking worker thread will be cancelled.
+ ///
+ /// [`join_next`]: Self::join_next
+ #[track_caller]
+ pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
+ where
+ F: FnOnce() -> V,
+ F: Send + 'static,
+ V: Send,
+ {
+ let task = self.tasks.spawn_blocking_on(f, handle);
+ self.insert(key, task);
+ }
+
/// Spawn the provided task on the current [`LocalSet`] and store it in this
/// `JoinMap` with the provided key.
///
@@ -572,6 +627,19 @@ where
}
}
+ /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
+ ///
+ /// If a task has completed, but its output hasn't yet been consumed by a
+ /// call to [`join_next`], this method will still return its key.
+ ///
+ /// [`join_next`]: fn@Self::join_next
+ pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
+ JoinMapKeys {
+ iter: self.tasks_by_key.keys(),
+ _value: PhantomData,
+ }
+ }
+
/// Returns `true` if this `JoinMap` contains a task for the provided key.
///
/// If the task has completed, but its output hasn't yet been consumed by a
@@ -805,3 +873,32 @@ impl<K: PartialEq> PartialEq for Key<K> {
}
impl<K: Eq> Eq for Key<K> {}
+
+/// An iterator over the keys of a [`JoinMap`].
+#[derive(Debug, Clone)]
+pub struct JoinMapKeys<'a, K, V> {
+ iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
+ /// To make it easier to change JoinMap in the future, keep V as a generic
+ /// parameter.
+ _value: PhantomData<&'a V>,
+}
+
+impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
+ type Item = &'a K;
+
+ fn next(&mut self) -> Option<&'a K> {
+ self.iter.next().map(|key| &key.key)
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.iter.size_hint()
+ }
+}
+
+impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
+ fn len(&self) -> usize {
+ self.iter.len()
+ }
+}
+
+impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}
diff --git a/src/task/mod.rs b/src/task/mod.rs
index de41dd5..e37015a 100644
--- a/src/task/mod.rs
+++ b/src/task/mod.rs
@@ -9,4 +9,7 @@ pub use spawn_pinned::LocalPoolHandle;
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))]
-pub use join_map::JoinMap;
+pub use join_map::{JoinMap, JoinMapKeys};
+
+pub mod task_tracker;
+pub use task_tracker::TaskTracker;
diff --git a/src/task/task_tracker.rs b/src/task/task_tracker.rs
new file mode 100644
index 0000000..d8f3bb4
--- /dev/null
+++ b/src/task/task_tracker.rs
@@ -0,0 +1,719 @@
+//! Types related to the [`TaskTracker`] collection.
+//!
+//! See the documentation of [`TaskTracker`] for more information.
+
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use tokio::sync::{futures::Notified, Notify};
+
+#[cfg(feature = "rt")]
+use tokio::{
+ runtime::Handle,
+ task::{JoinHandle, LocalSet},
+};
+
+/// A task tracker used for waiting until tasks exit.
+///
+/// This is usually used together with [`CancellationToken`] to implement [graceful shutdown]. The
+/// `CancellationToken` is used to signal to tasks that they should shut down, and the
+/// `TaskTracker` is used to wait for them to finish shutting down.
+///
+/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case
+/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the
+/// [`wait`] method will wait until *both* of the following happen at the same time:
+///
+/// * The `TaskTracker` must be closed using the [`close`] method.
+/// * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited.
+///
+/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that
+/// the destructor of the future has finished running. However, there might be a short amount of
+/// time where [`JoinHandle::is_finished`] returns false.
+///
+/// # Comparison to `JoinSet`
+///
+/// The main Tokio crate has a similar collection known as [`JoinSet`]. The `JoinSet` type has a
+/// lot more features than `TaskTracker`, so `TaskTracker` should only be used when one of its
+/// unique features is required:
+///
+/// 1. When tasks exit, a `TaskTracker` will allow the task to immediately free its memory.
+/// 2. By not closing the `TaskTracker`, [`wait`] will be prevented from returning even if
+/// the `TaskTracker` is empty.
+/// 3. A `TaskTracker` does not require mutable access to insert tasks.
+/// 4. A `TaskTracker` can be cloned to share it with many tasks.
+///
+/// The first point is the most important one. A [`JoinSet`] keeps track of the return value of
+/// every inserted task. This means that if the caller keeps inserting tasks and never calls
+/// [`join_next`], then their return values will keep building up and consuming memory, _even if_
+/// most of the tasks have already exited. This can cause the process to run out of memory. With a
+/// `TaskTracker`, this does not happen. Once tasks exit, they are immediately removed from the
+/// `TaskTracker`.
+///
+/// # Examples
+///
+/// For more examples, please see the topic page on [graceful shutdown].
+///
+/// ## Spawn tasks and wait for them to exit
+///
+/// This is a simple example. For this case, [`JoinSet`] should probably be used instead.
+///
+/// ```
+/// use tokio_util::task::TaskTracker;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let tracker = TaskTracker::new();
+///
+/// for i in 0..10 {
+/// tracker.spawn(async move {
+/// println!("Task {} is running!", i);
+/// });
+/// }
+/// // Once we spawned everything, we close the tracker.
+/// tracker.close();
+///
+/// // Wait for everything to finish.
+/// tracker.wait().await;
+///
+/// println!("This is printed after all of the tasks.");
+/// }
+/// ```
+///
+/// ## Wait for tasks to exit
+///
+/// This example shows the intended use-case of `TaskTracker`. It is used together with
+/// [`CancellationToken`] to implement graceful shutdown.
+/// ```
+/// use tokio_util::sync::CancellationToken;
+/// use tokio_util::task::TaskTracker;
+/// use tokio::time::{self, Duration};
+///
+/// async fn background_task(num: u64) {
+/// for i in 0..10 {
+/// time::sleep(Duration::from_millis(100*num)).await;
+/// println!("Background task {} in iteration {}.", num, i);
+/// }
+/// }
+///
+/// #[tokio::main]
+/// # async fn _hidden() {}
+/// # #[tokio::main(flavor = "current_thread", start_paused = true)]
+/// async fn main() {
+/// let tracker = TaskTracker::new();
+/// let token = CancellationToken::new();
+///
+/// for i in 0..10 {
+/// let token = token.clone();
+/// tracker.spawn(async move {
+/// // Use a `tokio::select!` to kill the background task if the token is
+/// // cancelled.
+/// tokio::select! {
+/// () = background_task(i) => {
+/// println!("Task {} exiting normally.", i);
+/// },
+/// () = token.cancelled() => {
+/// // Do some cleanup before we really exit.
+/// time::sleep(Duration::from_millis(50)).await;
+/// println!("Task {} finished cleanup.", i);
+/// },
+/// }
+/// });
+/// }
+///
+/// // Spawn a background task that will send the shutdown signal.
+/// {
+/// let tracker = tracker.clone();
+/// tokio::spawn(async move {
+/// // Normally you would use something like ctrl-c instead of
+/// // sleeping.
+/// time::sleep(Duration::from_secs(2)).await;
+/// tracker.close();
+/// token.cancel();
+/// });
+/// }
+///
+/// // Wait for all tasks to exit.
+/// tracker.wait().await;
+///
+/// println!("All tasks have exited now.");
+/// }
+/// ```
+///
+/// [`CancellationToken`]: crate::sync::CancellationToken
+/// [`JoinHandle::is_finished`]: tokio::task::JoinHandle::is_finished
+/// [`JoinSet`]: tokio::task::JoinSet
+/// [`close`]: Self::close
+/// [`join_next`]: tokio::task::JoinSet::join_next
+/// [`wait`]: Self::wait
+/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown
+pub struct TaskTracker {
+ inner: Arc<TaskTrackerInner>,
+}
+
+/// Represents a task tracked by a [`TaskTracker`].
+#[must_use]
+#[derive(Debug)]
+pub struct TaskTrackerToken {
+ task_tracker: TaskTracker,
+}
+
+struct TaskTrackerInner {
+ /// Keeps track of the state.
+ ///
+ /// The lowest bit is whether the task tracker is closed.
+ ///
+ /// The rest of the bits count the number of tracked tasks.
+ state: AtomicUsize,
+ /// Used to notify when the last task exits.
+ on_last_exit: Notify,
+}
+
+pin_project! {
+ /// A future that is tracked as a task by a [`TaskTracker`].
+ ///
+ /// The associated [`TaskTracker`] cannot complete until this future is dropped.
+ ///
+ /// This future is returned by [`TaskTracker::track_future`].
+ #[must_use = "futures do nothing unless polled"]
+ pub struct TrackedFuture<F> {
+ #[pin]
+ future: F,
+ token: TaskTrackerToken,
+ }
+}
+
+pin_project! {
+ /// A future that completes when the [`TaskTracker`] is empty and closed.
+ ///
+ /// This future is returned by [`TaskTracker::wait`].
+ #[must_use = "futures do nothing unless polled"]
+ pub struct TaskTrackerWaitFuture<'a> {
+ #[pin]
+ future: Notified<'a>,
+ inner: Option<&'a TaskTrackerInner>,
+ }
+}
+
+impl TaskTrackerInner {
+ #[inline]
+ fn new() -> Self {
+ Self {
+ state: AtomicUsize::new(0),
+ on_last_exit: Notify::new(),
+ }
+ }
+
+ #[inline]
+ fn is_closed_and_empty(&self) -> bool {
+ // If empty and closed bit set, then we are done.
+ //
+ // The acquire load will synchronize with the release store of any previous call to
+ // `set_closed` and `drop_task`.
+ self.state.load(Ordering::Acquire) == 1
+ }
+
+ #[inline]
+ fn set_closed(&self) -> bool {
+ // The AcqRel ordering makes the closed bit behave like a `Mutex<bool>` for synchronization
+ // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}`
+ // more meaningful for the user. Without these orderings, this assert could fail:
+ // ```
+ // // thread 1
+ // some_other_atomic.store(true, Relaxed);
+ // tracker.close();
+ //
+ // // thread 2
+ // if tracker.reopen() {
+ // assert!(some_other_atomic.load(Relaxed));
+ // }
+ // ```
+ // However, with the AcqRel ordering, we establish a happens-before relationship from the
+ // call to `close` and the later call to `reopen` that returned true.
+ let state = self.state.fetch_or(1, Ordering::AcqRel);
+
+ // If there are no tasks, and if it was not already closed:
+ if state == 0 {
+ self.notify_now();
+ }
+
+ (state & 1) == 0
+ }
+
+ #[inline]
+ fn set_open(&self) -> bool {
+ // See `set_closed` regarding the AcqRel ordering.
+ let state = self.state.fetch_and(!1, Ordering::AcqRel);
+ (state & 1) == 1
+ }
+
+ #[inline]
+ fn add_task(&self) {
+ self.state.fetch_add(2, Ordering::Relaxed);
+ }
+
+ #[inline]
+ fn drop_task(&self) {
+ let state = self.state.fetch_sub(2, Ordering::Release);
+
+ // If this was the last task and we are closed:
+ if state == 3 {
+ self.notify_now();
+ }
+ }
+
+ #[cold]
+ fn notify_now(&self) {
+ // Insert an acquire fence. This matters for `drop_task` but doesn't matter for
+ // `set_closed` since it already uses AcqRel.
+ //
+ // This synchronizes with the release store of any other call to `drop_task`, and with the
+ // release store in the call to `set_closed`. That ensures that everything that happened
+ // before those other calls to `drop_task` or `set_closed` will be visible after this load,
+ // and those things will also be visible to anything woken by the call to `notify_waiters`.
+ self.state.load(Ordering::Acquire);
+
+ self.on_last_exit.notify_waiters();
+ }
+}
+
+impl TaskTracker {
+ /// Creates a new `TaskTracker`.
+ ///
+ /// The `TaskTracker` will start out as open.
+ #[must_use]
+ pub fn new() -> Self {
+ Self {
+ inner: Arc::new(TaskTrackerInner::new()),
+ }
+ }
+
+ /// Waits until this `TaskTracker` is both closed and empty.
+ ///
+ /// If the `TaskTracker` is already closed and empty when this method is called, then it
+ /// returns immediately.
+ ///
+ /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker`
+ /// becomes both closed and empty for a short amount of time, then it is guaranteed that all
+ /// `wait` futures that were created before the short time interval will trigger, even if they
+ /// are not polled during that short time interval.
+ ///
+ /// # Cancel safety
+ ///
+ /// This method is cancel safe.
+ ///
+ /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the
+ /// condition in a `tokio::select!` loop.
+ ///
+ /// [aba]: https://en.wikipedia.org/wiki/ABA_problem
+ #[inline]
+ pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
+ TaskTrackerWaitFuture {
+ future: self.inner.on_last_exit.notified(),
+ inner: if self.inner.is_closed_and_empty() {
+ None
+ } else {
+ Some(&self.inner)
+ },
+ }
+ }
+
+ /// Close this `TaskTracker`.
+ ///
+ /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks.
+ ///
+ /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed.
+ ///
+ /// [`wait`]: Self::wait
+ #[inline]
+ pub fn close(&self) -> bool {
+ self.inner.set_closed()
+ }
+
+ /// Reopen this `TaskTracker`.
+ ///
+ /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty.
+ ///
+ /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open.
+ ///
+ /// [`wait`]: Self::wait
+ #[inline]
+ pub fn reopen(&self) -> bool {
+ self.inner.set_open()
+ }
+
+ /// Returns `true` if this `TaskTracker` is [closed](Self::close).
+ #[inline]
+ #[must_use]
+ pub fn is_closed(&self) -> bool {
+ (self.inner.state.load(Ordering::Acquire) & 1) != 0
+ }
+
+ /// Returns the number of tasks tracked by this `TaskTracker`.
+ #[inline]
+ #[must_use]
+ pub fn len(&self) -> usize {
+ self.inner.state.load(Ordering::Acquire) >> 1
+ }
+
+ /// Returns `true` if there are no tasks in this `TaskTracker`.
+ #[inline]
+ #[must_use]
+ pub fn is_empty(&self) -> bool {
+ self.inner.state.load(Ordering::Acquire) <= 1
+ }
+
+ /// Spawn the provided future on the current Tokio runtime, and track it in this `TaskTracker`.
+ ///
+ /// This is equivalent to `tokio::spawn(tracker.track_future(task))`.
+ #[inline]
+ #[track_caller]
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
+ where
+ F: Future + Send + 'static,
+ F::Output: Send + 'static,
+ {
+ tokio::task::spawn(self.track_future(task))
+ }
+
+ /// Spawn the provided future on the provided Tokio runtime, and track it in this `TaskTracker`.
+ ///
+ /// This is equivalent to `handle.spawn(tracker.track_future(task))`.
+ #[inline]
+ #[track_caller]
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn spawn_on<F>(&self, task: F, handle: &Handle) -> JoinHandle<F::Output>
+ where
+ F: Future + Send + 'static,
+ F::Output: Send + 'static,
+ {
+ handle.spawn(self.track_future(task))
+ }
+
+ /// Spawn the provided future on the current [`LocalSet`], and track it in this `TaskTracker`.
+ ///
+ /// This is equivalent to `tokio::task::spawn_local(tracker.track_future(task))`.
+ ///
+ /// [`LocalSet`]: tokio::task::LocalSet
+ #[inline]
+ #[track_caller]
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
+ where
+ F: Future + 'static,
+ F::Output: 'static,
+ {
+ tokio::task::spawn_local(self.track_future(task))
+ }
+
+ /// Spawn the provided future on the provided [`LocalSet`], and track it in this `TaskTracker`.
+ ///
+ /// This is equivalent to `local_set.spawn_local(tracker.track_future(task))`.
+ ///
+ /// [`LocalSet`]: tokio::task::LocalSet
+ #[inline]
+ #[track_caller]
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn spawn_local_on<F>(&self, task: F, local_set: &LocalSet) -> JoinHandle<F::Output>
+ where
+ F: Future + 'static,
+ F::Output: 'static,
+ {
+ local_set.spawn_local(self.track_future(task))
+ }
+
+ /// Spawn the provided blocking task on the current Tokio runtime, and track it in this `TaskTracker`.
+ ///
+ /// This is similar to `tokio::task::spawn_blocking(task)`, except that the task is tracked by this `TaskTracker` until it completes.
+ #[inline]
+ #[track_caller]
+ #[cfg(feature = "rt")]
+ #[cfg(not(target_family = "wasm"))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
+ where
+ F: FnOnce() -> T,
+ F: Send + 'static,
+ T: Send + 'static,
+ {
+ let token = self.token();
+ tokio::task::spawn_blocking(move || {
+ let res = task();
+ drop(token);
+ res
+ })
+ }
+
+ /// Spawn the provided blocking task on the provided Tokio runtime, and track it in this `TaskTracker`.
+ ///
+ /// This is similar to `handle.spawn_blocking(task)`, except that the task is tracked by this `TaskTracker` until it completes.
+ #[inline]
+ #[track_caller]
+ #[cfg(feature = "rt")]
+ #[cfg(not(target_family = "wasm"))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
+ where
+ F: FnOnce() -> T,
+ F: Send + 'static,
+ T: Send + 'static,
+ {
+ let token = self.token();
+ handle.spawn_blocking(move || {
+ let res = task();
+ drop(token);
+ res
+ })
+ }
+
+ /// Track the provided future.
+ ///
+ /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will
+ /// prevent calls to [`wait`] from returning until the task is dropped.
+ ///
+ /// The task is removed from the collection when it is dropped, not when [`poll`] returns
+ /// [`Poll::Ready`].
+ ///
+ /// # Examples
+ ///
+ /// Track a future spawned with [`tokio::spawn`].
+ ///
+ /// ```
+ /// # async fn my_async_fn() {}
+ /// use tokio_util::task::TaskTracker;
+ ///
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() {
+ /// let tracker = TaskTracker::new();
+ ///
+ /// tokio::spawn(tracker.track_future(my_async_fn()));
+ /// # }
+ /// ```
+ ///
+ /// Track a future spawned on a [`JoinSet`].
+ /// ```
+ /// # async fn my_async_fn() {}
+ /// use tokio::task::JoinSet;
+ /// use tokio_util::task::TaskTracker;
+ ///
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() {
+ /// let tracker = TaskTracker::new();
+ /// let mut join_set = JoinSet::new();
+ ///
+ /// join_set.spawn(tracker.track_future(my_async_fn()));
+ /// # }
+ /// ```
+ ///
+ /// [`JoinSet`]: tokio::task::JoinSet
+ /// [`Poll::Pending`]: std::task::Poll::Pending
+ /// [`poll`]: std::future::Future::poll
+ /// [`wait`]: Self::wait
+ #[inline]
+ pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
+ TrackedFuture {
+ future,
+ token: self.token(),
+ }
+ }
+
+ /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`.
+ ///
+ /// This token is a lower-level utility than the spawn methods. Each token is considered to
+ /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete.
+ /// Furthermore, the count returned by the [`len`] method will include the tokens in the count.
+ ///
+ /// Dropping the token indicates to the `TaskTracker` that the task has exited.
+ ///
+ /// [`len`]: TaskTracker::len
+ #[inline]
+ pub fn token(&self) -> TaskTrackerToken {
+ self.inner.add_task();
+ TaskTrackerToken {
+ task_tracker: self.clone(),
+ }
+ }
+
+ /// Returns `true` if both task trackers correspond to the same set of tasks.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::task::TaskTracker;
+ ///
+ /// let tracker_1 = TaskTracker::new();
+ /// let tracker_2 = TaskTracker::new();
+ /// let tracker_1_clone = tracker_1.clone();
+ ///
+ /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone));
+ /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2));
+ /// ```
+ #[inline]
+ #[must_use]
+ pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
+ Arc::ptr_eq(&left.inner, &right.inner)
+ }
+}
+
+impl Default for TaskTracker {
+ /// Creates a new `TaskTracker`.
+ ///
+ /// The `TaskTracker` will start out as open.
+ #[inline]
+ fn default() -> TaskTracker {
+ TaskTracker::new()
+ }
+}
+
+impl Clone for TaskTracker {
+ /// Returns a new `TaskTracker` that tracks the same set of tasks.
+ ///
+ /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in
+ /// all other clones.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::task::TaskTracker;
+ ///
+ /// #[tokio::main]
+ /// # async fn _hidden() {}
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// async fn main() {
+ /// let tracker = TaskTracker::new();
+ /// let cloned = tracker.clone();
+ ///
+ /// // Spawns on `tracker` are visible in `cloned`.
+ /// tracker.spawn(std::future::pending::<()>());
+ /// assert_eq!(cloned.len(), 1);
+ ///
+ /// // Spawns on `cloned` are visible in `tracker`.
+ /// cloned.spawn(std::future::pending::<()>());
+ /// assert_eq!(tracker.len(), 2);
+ ///
+ /// // Calling `close` is visible to `cloned`.
+ /// tracker.close();
+ /// assert!(cloned.is_closed());
+ ///
+ /// // Calling `reopen` is visible to `tracker`.
+ /// cloned.reopen();
+ /// assert!(!tracker.is_closed());
+ /// }
+ /// ```
+ #[inline]
+ fn clone(&self) -> TaskTracker {
+ Self {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let state = inner.state.load(Ordering::Acquire);
+ let is_closed = (state & 1) != 0;
+ let len = state >> 1;
+
+ f.debug_struct("TaskTracker")
+ .field("len", &len)
+ .field("is_closed", &is_closed)
+ .field("inner", &(inner as *const TaskTrackerInner))
+ .finish()
+}
+
+impl fmt::Debug for TaskTracker {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ debug_inner(&self.inner, f)
+ }
+}
+
+impl TaskTrackerToken {
+ /// Returns the [`TaskTracker`] that this token is associated with.
+ #[inline]
+ #[must_use]
+ pub fn task_tracker(&self) -> &TaskTracker {
+ &self.task_tracker
+ }
+}
+
+impl Clone for TaskTrackerToken {
+ /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`].
+ ///
+ /// This is equivalent to `token.task_tracker().token()`.
+ #[inline]
+ fn clone(&self) -> TaskTrackerToken {
+ self.task_tracker.token()
+ }
+}
+
+impl Drop for TaskTrackerToken {
+ /// Dropping the token indicates to the [`TaskTracker`] that the task has exited.
+ #[inline]
+ fn drop(&mut self) {
+ self.task_tracker.inner.drop_task();
+ }
+}
+
+impl<F: Future> Future for TrackedFuture<F> {
+ type Output = F::Output;
+
+ #[inline]
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
+ self.project().future.poll(cx)
+ }
+}
+
+impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("TrackedFuture")
+ .field("future", &self.future)
+ .field("task_tracker", self.token.task_tracker())
+ .finish()
+ }
+}
+
+impl<'a> Future for TaskTrackerWaitFuture<'a> {
+ type Output = ();
+
+ #[inline]
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+ let me = self.project();
+
+ let inner = match me.inner.as_ref() {
+ None => return Poll::Ready(()),
+ Some(inner) => inner,
+ };
+
+ let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready();
+ if ready {
+ *me.inner = None;
+ Poll::Ready(())
+ } else {
+ Poll::Pending
+ }
+ }
+}
+
+impl<'a> fmt::Debug for TaskTrackerWaitFuture<'a> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ struct Helper<'a>(&'a TaskTrackerInner);
+
+ impl fmt::Debug for Helper<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ debug_inner(self.0, f)
+ }
+ }
+
+ f.debug_struct("TaskTrackerWaitFuture")
+ .field("future", &self.future)
+ .field("task_tracker", &self.inner.map(Helper))
+ .finish()
+ }
+}
diff --git a/src/time/delay_queue.rs b/src/time/delay_queue.rs
index ee66adb..9136d90 100644
--- a/src/time/delay_queue.rs
+++ b/src/time/delay_queue.rs
@@ -62,7 +62,7 @@ use std::task::{self, Poll, Waker};
/// performance and scalability benefits.
///
/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation,
-/// and allows reuse of the memory allocated for expired entires.
+/// and allows reuse of the memory allocated for expired entries.
///
/// Capacity can be checked using [`capacity`] and allocated preemptively by using
/// the [`reserve`] method.
@@ -874,6 +874,41 @@ impl<T> DelayQueue<T> {
self.slab.compact();
}
+ /// Gets the [`Key`] that [`poll_expired`] will pull out of the queue next, without
+ /// pulling it out or waiting for the deadline to expire.
+ ///
+ /// Entries that have already expired may be returned in any order, but it is
+ /// guaranteed that this method returns them in the same order as when items
+ /// are popped from the `DelayQueue`.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// let key1 = delay_queue.insert("foo", Duration::from_secs(10));
+ /// let key2 = delay_queue.insert("bar", Duration::from_secs(5));
+ /// let key3 = delay_queue.insert("baz", Duration::from_secs(15));
+ ///
+ /// assert_eq!(delay_queue.peek().unwrap(), key2);
+ /// # }
+ /// ```
+ ///
+ /// [`Key`]: struct@Key
+ /// [`poll_expired`]: method@Self::poll_expired
+ pub fn peek(&self) -> Option<Key> {
+ use self::wheel::Stack;
+
+ self.expired.peek().or_else(|| self.wheel.peek())
+ }
+
/// Returns the next time to poll as determined by the wheel
fn next_deadline(&mut self) -> Option<Instant> {
self.wheel
@@ -1166,6 +1201,10 @@ impl<T> wheel::Stack for Stack<T> {
}
}
+ fn peek(&self) -> Option<Self::Owned> {
+ self.head
+ }
+
#[track_caller]
fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) {
let key = *item;
diff --git a/src/time/wheel/level.rs b/src/time/wheel/level.rs
index 8ea30af..4290acf 100644
--- a/src/time/wheel/level.rs
+++ b/src/time/wheel/level.rs
@@ -140,11 +140,31 @@ impl<T: Stack> Level<T> {
// TODO: This can probably be simplified w/ power of 2 math
let level_start = now - (now % level_range);
- let deadline = level_start + slot as u64 * slot_range;
-
+ let mut deadline = level_start + slot as u64 * slot_range;
+ if deadline < now {
+ // A timer is in a slot "prior" to the current time. This can occur
+ // because we do not have an infinite hierarchy of timer levels, and
+ // eventually a timer scheduled for a very distant time might end up
+ // being placed in a slot that is beyond the end of all of the
+ // arrays.
+ //
+ // To deal with this, we first limit timers to being scheduled no
+ // more than MAX_DURATION ticks in the future; that is, they're at
+ // most one rotation of the top level away. Then, we force timers
+ // that logically would go into the top+1 level, to instead go into
+ // the top level's slots.
+ //
+ // What this means is that the top level's slots act as a
+ // pseudo-ring buffer, and we rotate around them indefinitely. If we
+ // compute a deadline before now, and it's the top level, it
+ // therefore means we're actually looking at a slot in the future.
+ debug_assert_eq!(self.level, super::NUM_LEVELS - 1);
+
+ deadline += level_range;
+ }
debug_assert!(
deadline >= now,
- "deadline={}; now={}; level={}; slot={}; occupied={:b}",
+ "deadline={:016X}; now={:016X}; level={}; slot={}; occupied={:b}",
deadline,
now,
self.level,
@@ -206,6 +226,10 @@ impl<T: Stack> Level<T> {
ret
}
+
+ pub(crate) fn peek_entry_slot(&self, slot: usize) -> Option<T::Owned> {
+ self.slot[slot].peek()
+ }
}
impl<T> fmt::Debug for Level<T> {
diff --git a/src/time/wheel/mod.rs b/src/time/wheel/mod.rs
index ffa05ab..10a9900 100644
--- a/src/time/wheel/mod.rs
+++ b/src/time/wheel/mod.rs
@@ -139,6 +139,12 @@ where
self.next_expiration().map(|expiration| expiration.deadline)
}
+ /// Next key that will expire
+ pub(crate) fn peek(&self) -> Option<T::Owned> {
+ self.next_expiration()
+ .and_then(|expiration| self.peek_entry(&expiration))
+ }
+
/// Advances the timer up to the instant represented by `now`.
pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> {
loop {
@@ -244,6 +250,10 @@ where
self.levels[expiration.level].pop_entry_slot(expiration.slot, store)
}
+ fn peek_entry(&self, expiration: &Expiration) -> Option<T::Owned> {
+ self.levels[expiration.level].peek_entry_slot(expiration.slot)
+ }
+
fn level_for(&self, when: u64) -> usize {
level_for(self.elapsed, when)
}
@@ -254,8 +264,11 @@ fn level_for(elapsed: u64, when: u64) -> usize {
// Mask in the trailing bits ignored by the level calculation in order to cap
// the possible leading zeros
- let masked = elapsed ^ when | SLOT_MASK;
-
+ let mut masked = elapsed ^ when | SLOT_MASK;
+ if masked >= MAX_DURATION {
+ // Fudge the timer into the top level
+ masked = MAX_DURATION - 1;
+ }
let leading_zeros = masked.leading_zeros() as usize;
let significant = 63 - leading_zeros;
significant / 6
diff --git a/src/time/wheel/stack.rs b/src/time/wheel/stack.rs
index c87adca..7d32f27 100644
--- a/src/time/wheel/stack.rs
+++ b/src/time/wheel/stack.rs
@@ -22,6 +22,9 @@ pub(crate) trait Stack: Default {
/// Pop an item from the stack
fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>;
+ /// Peek into the stack.
+ fn peek(&self) -> Option<Self::Owned>;
+
fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);
fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
diff --git a/src/util/maybe_dangling.rs b/src/util/maybe_dangling.rs
new file mode 100644
index 0000000..c29a089
--- /dev/null
+++ b/src/util/maybe_dangling.rs
@@ -0,0 +1,67 @@
+use core::future::Future;
+use core::mem::MaybeUninit;
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+/// A wrapper type that tells the compiler that the contents might not be valid.
+///
+/// This is necessary mainly when `T` contains a reference. In that case, the
+/// compiler will sometimes assume that the reference is always valid; in some
+/// cases it will assume this even after the destructor of `T` runs. For
+/// example, when a reference is used as a function argument, then the compiler
+/// will assume that the reference is valid until the function returns, even if
+/// the reference is destroyed during the function. When the reference is used
+/// as part of a self-referential struct, that assumption can be false. Wrapping
+/// the reference in this type prevents the compiler from making that
+/// assumption.
+///
+/// # Invariants
+///
+/// The `MaybeUninit` will always contain a valid value until the destructor runs.
+//
+// Reference
+// See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
+//
+// TODO: replace this with an official solution once RFC #3336 or similar is available.
+// <https://github.com/rust-lang/rfcs/pull/3336>
+#[repr(transparent)]
+pub(crate) struct MaybeDangling<T>(MaybeUninit<T>);
+
+impl<T> Drop for MaybeDangling<T> {
+ fn drop(&mut self) {
+ // Safety: `0` is always initialized.
+ unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) };
+ }
+}
+
+impl<T> MaybeDangling<T> {
+ pub(crate) fn new(inner: T) -> Self {
+ Self(MaybeUninit::new(inner))
+ }
+}
+
+impl<F: Future> Future for MaybeDangling<F> {
+ type Output = F::Output;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ // Safety: `0` is always initialized.
+ let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) };
+ fut.poll(cx)
+ }
+}
+
+#[test]
+fn maybedangling_runs_drop() {
+ struct SetOnDrop<'a>(&'a mut bool);
+
+ impl Drop for SetOnDrop<'_> {
+ fn drop(&mut self) {
+ *self.0 = true;
+ }
+ }
+
+ let mut success = false;
+
+ drop(MaybeDangling::new(SetOnDrop(&mut success)));
+ assert!(success);
+}
diff --git a/src/util/mod.rs b/src/util/mod.rs
new file mode 100644
index 0000000..a17f25a
--- /dev/null
+++ b/src/util/mod.rs
@@ -0,0 +1,8 @@
+mod maybe_dangling;
+#[cfg(any(feature = "io", feature = "codec"))]
+mod poll_buf;
+
+pub(crate) use maybe_dangling::MaybeDangling;
+#[cfg(any(feature = "io", feature = "codec"))]
+#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+pub use poll_buf::{poll_read_buf, poll_write_buf};
diff --git a/src/util/poll_buf.rs b/src/util/poll_buf.rs
new file mode 100644
index 0000000..82af1bb
--- /dev/null
+++ b/src/util/poll_buf.rs
@@ -0,0 +1,145 @@
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+use bytes::{Buf, BufMut};
+use futures_core::ready;
+use std::io::{self, IoSlice};
+use std::mem::MaybeUninit;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
+///
+/// [`BufMut`]: bytes::Buf
+///
+/// # Example
+///
+/// ```
+/// use bytes::{Bytes, BytesMut};
+/// use tokio_stream as stream;
+/// use tokio::io::Result;
+/// use tokio_util::io::{StreamReader, poll_read_buf};
+/// use futures::future::poll_fn;
+/// use std::pin::Pin;
+/// # #[tokio::main]
+/// # async fn main() -> std::io::Result<()> {
+///
+/// // Create a reader from an iterator. This particular reader will always be
+/// // ready.
+/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
+///
+/// let mut buf = BytesMut::new();
+/// let mut reads = 0;
+///
+/// loop {
+/// reads += 1;
+/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
+///
+/// if n == 0 {
+/// break;
+/// }
+/// }
+///
+/// // one or more reads might be necessary.
+/// assert!(reads >= 1);
+/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
+/// # Ok(())
+/// # }
+/// ```
+#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+pub fn poll_read_buf<T: AsyncRead, B: BufMut>(
+ io: Pin<&mut T>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+) -> Poll<io::Result<usize>> {
+ if !buf.has_remaining_mut() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = {
+ let dst = buf.chunk_mut();
+
+ // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
+ // transparent wrapper around `[MaybeUninit<u8>]`.
+ let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut buf = ReadBuf::uninit(dst);
+ let ptr = buf.filled().as_ptr();
+ ready!(io.poll_read(cx, &mut buf)?);
+
+ // Ensure the pointer does not change from under us
+ assert_eq!(ptr, buf.filled().as_ptr());
+ buf.filled().len()
+ };
+
+ // Safety: This is guaranteed to be the number of initialized (and read)
+ // bytes due to the invariants provided by `ReadBuf::filled`.
+ unsafe {
+ buf.advance_mut(n);
+ }
+
+ Poll::Ready(Ok(n))
+}
+
+/// Try to write data from an implementer of the [`Buf`] trait to an
+/// [`AsyncWrite`], advancing the buffer's internal cursor.
+///
+/// This function will use [vectored writes] when the [`AsyncWrite`] supports
+/// vectored writes.
+///
+/// # Examples
+///
+/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
+/// [`Buf`]:
+///
+/// ```no_run
+/// use tokio_util::io::poll_write_buf;
+/// use tokio::io;
+/// use tokio::fs::File;
+///
+/// use bytes::Buf;
+/// use std::io::Cursor;
+/// use std::pin::Pin;
+/// use futures::future::poll_fn;
+///
+/// #[tokio::main]
+/// async fn main() -> io::Result<()> {
+/// let mut file = File::create("foo.txt").await?;
+/// let mut buf = Cursor::new(b"data to write");
+///
+/// // Loop until the entire contents of the buffer are written to
+/// // the file.
+/// while buf.has_remaining() {
+/// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
+/// }
+///
+/// Ok(())
+/// }
+/// ```
+///
+/// [`Buf`]: bytes::Buf
+/// [`AsyncWrite`]: tokio::io::AsyncWrite
+/// [`File`]: tokio::fs::File
+/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
+#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
+ io: Pin<&mut T>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+) -> Poll<io::Result<usize>> {
+ const MAX_BUFS: usize = 64;
+
+ if !buf.has_remaining() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = if io.is_write_vectored() {
+ let mut slices = [IoSlice::new(&[]); MAX_BUFS];
+ let cnt = buf.chunks_vectored(&mut slices);
+ ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
+ } else {
+ ready!(io.poll_write(cx, buf.chunk()))?
+ };
+
+ buf.advance(n);
+
+ Poll::Ready(Ok(n))
+}
diff --git a/tests/compat.rs b/tests/compat.rs
new file mode 100644
index 0000000..278ebfc
--- /dev/null
+++ b/tests/compat.rs
@@ -0,0 +1,43 @@
+#![cfg(all(feature = "compat"))]
+#![cfg(not(target_os = "wasi"))] // WASI does not support all fs operations
+#![warn(rust_2018_idioms)]
+
+use futures_io::SeekFrom;
+use futures_util::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
+use tempfile::NamedTempFile;
+use tokio::fs::OpenOptions;
+use tokio_util::compat::TokioAsyncWriteCompatExt;
+
+#[tokio::test]
+async fn compat_file_seek() -> futures_util::io::Result<()> {
+ let temp_file = NamedTempFile::new()?;
+ let mut file = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .create(true)
+ .open(temp_file)
+ .await?
+ .compat_write();
+
+ file.write_all(&[0, 1, 2, 3, 4, 5]).await?;
+ file.write_all(&[6, 7]).await?;
+
+ assert_eq!(file.stream_position().await?, 8);
+
+ // Modify elements at position 2.
+ assert_eq!(file.seek(SeekFrom::Start(2)).await?, 2);
+ file.write_all(&[8, 9]).await?;
+
+ file.flush().await?;
+
+ // Verify we still have 8 elements.
+ assert_eq!(file.seek(SeekFrom::End(0)).await?, 8);
+ // Seek back to the start of the file to read and verify contents.
+ file.seek(SeekFrom::Start(0)).await?;
+
+ let mut buf = Vec::new();
+ let num_bytes = file.read_to_end(&mut buf).await?;
+ assert_eq!(&buf[..num_bytes], &[0, 1, 8, 9, 4, 5, 6, 7]);
+
+ Ok(())
+}
diff --git a/tests/io_sync_bridge.rs b/tests/io_sync_bridge.rs
index 76bbd0b..50d0e89 100644
--- a/tests/io_sync_bridge.rs
+++ b/tests/io_sync_bridge.rs
@@ -44,6 +44,18 @@ async fn test_async_write_to_sync() -> Result<(), Box<dyn Error>> {
}
#[tokio::test]
+async fn test_into_inner() -> Result<(), Box<dyn Error>> {
+ let mut buf = Vec::new();
+ SyncIoBridge::new(tokio::io::empty())
+ .into_inner()
+ .read_to_end(&mut buf)
+ .await
+ .unwrap();
+ assert_eq!(buf.len(), 0);
+ Ok(())
+}
+
+#[tokio::test]
async fn test_shutdown() -> Result<(), Box<dyn Error>> {
let (s1, mut s2) = tokio::io::duplex(1024);
let (_rh, wh) = tokio::io::split(s1);
diff --git a/tests/length_delimited.rs b/tests/length_delimited.rs
index 126e41b..ed5590f 100644
--- a/tests/length_delimited.rs
+++ b/tests/length_delimited.rs
@@ -12,7 +12,6 @@ use futures::{pin_mut, Sink, Stream};
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
-use std::task::Poll::*;
use std::task::{Context, Poll};
macro_rules! mock {
@@ -39,10 +38,10 @@ macro_rules! assert_next_eq {
macro_rules! assert_next_pending {
($io:ident) => {{
task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) {
- Ready(Some(Ok(v))) => panic!("value = {:?}", v),
- Ready(Some(Err(e))) => panic!("error = {:?}", e),
- Ready(None) => panic!("done"),
- Pending => {}
+ Poll::Ready(Some(Ok(v))) => panic!("value = {:?}", v),
+ Poll::Ready(Some(Err(e))) => panic!("error = {:?}", e),
+ Poll::Ready(None) => panic!("done"),
+ Poll::Pending => {}
});
}};
}
@@ -50,10 +49,10 @@ macro_rules! assert_next_pending {
macro_rules! assert_next_err {
($io:ident) => {{
task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) {
- Ready(Some(Ok(v))) => panic!("value = {:?}", v),
- Ready(Some(Err(_))) => {}
- Ready(None) => panic!("done"),
- Pending => panic!("pending"),
+ Poll::Ready(Some(Ok(v))) => panic!("value = {:?}", v),
+ Poll::Ready(Some(Err(_))) => {}
+ Poll::Ready(None) => panic!("done"),
+ Poll::Pending => panic!("pending"),
});
}};
}
@@ -186,11 +185,11 @@ fn read_single_frame_multi_packet_wait() {
let io = FramedRead::new(
mock! {
data(b"\x00\x00"),
- Pending,
+ Poll::Pending,
data(b"\x00\x09abc"),
- Pending,
+ Poll::Pending,
data(b"defghi"),
- Pending,
+ Poll::Pending,
},
LengthDelimitedCodec::new(),
);
@@ -208,15 +207,15 @@ fn read_multi_frame_multi_packet_wait() {
let io = FramedRead::new(
mock! {
data(b"\x00\x00"),
- Pending,
+ Poll::Pending,
data(b"\x00\x09abc"),
- Pending,
+ Poll::Pending,
data(b"defghi"),
- Pending,
+ Poll::Pending,
data(b"\x00\x00\x00\x0312"),
- Pending,
+ Poll::Pending,
data(b"3\x00\x00\x00\x0bhello world"),
- Pending,
+ Poll::Pending,
},
LengthDelimitedCodec::new(),
);
@@ -250,9 +249,9 @@ fn read_incomplete_head() {
fn read_incomplete_head_multi() {
let io = FramedRead::new(
mock! {
- Pending,
+ Poll::Pending,
data(b"\x00"),
- Pending,
+ Poll::Pending,
},
LengthDelimitedCodec::new(),
);
@@ -268,9 +267,9 @@ fn read_incomplete_payload() {
let io = FramedRead::new(
mock! {
data(b"\x00\x00\x00\x09ab"),
- Pending,
+ Poll::Pending,
data(b"cd"),
- Pending,
+ Poll::Pending,
},
LengthDelimitedCodec::new(),
);
@@ -310,7 +309,7 @@ fn read_update_max_frame_len_at_rest() {
fn read_update_max_frame_len_in_flight() {
let io = length_delimited::Builder::new().new_read(mock! {
data(b"\x00\x00\x00\x09abcd"),
- Pending,
+ Poll::Pending,
data(b"efghi"),
data(b"\x00\x00\x00\x09abcdefghi"),
});
@@ -533,9 +532,9 @@ fn write_single_multi_frame_multi_packet() {
fn write_single_frame_would_block() {
let io = FramedWrite::new(
mock! {
- Pending,
+ Poll::Pending,
data(b"\x00\x00"),
- Pending,
+ Poll::Pending,
data(b"\x00\x09"),
data(b"abcdefghi"),
flush(),
@@ -640,7 +639,7 @@ fn write_update_max_frame_len_in_flight() {
let io = length_delimited::Builder::new().new_write(mock! {
data(b"\x00\x00\x00\x06"),
data(b"ab"),
- Pending,
+ Poll::Pending,
data(b"cdef"),
flush(),
});
@@ -701,8 +700,6 @@ enum Op {
Flush,
}
-use self::Op::*;
-
impl AsyncRead for Mock {
fn poll_read(
mut self: Pin<&mut Self>,
@@ -710,15 +707,15 @@ impl AsyncRead for Mock {
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.calls.pop_front() {
- Some(Ready(Ok(Op::Data(data)))) => {
+ Some(Poll::Ready(Ok(Op::Data(data)))) => {
debug_assert!(dst.remaining() >= data.len());
dst.put_slice(&data);
- Ready(Ok(()))
+ Poll::Ready(Ok(()))
}
- Some(Ready(Ok(_))) => panic!(),
- Some(Ready(Err(e))) => Ready(Err(e)),
- Some(Pending) => Pending,
- None => Ready(Ok(())),
+ Some(Poll::Ready(Ok(_))) => panic!(),
+ Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
+ Some(Poll::Pending) => Poll::Pending,
+ None => Poll::Ready(Ok(())),
}
}
}
@@ -730,31 +727,31 @@ impl AsyncWrite for Mock {
src: &[u8],
) -> Poll<Result<usize, io::Error>> {
match self.calls.pop_front() {
- Some(Ready(Ok(Op::Data(data)))) => {
+ Some(Poll::Ready(Ok(Op::Data(data)))) => {
let len = data.len();
assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src);
assert_eq!(&data[..], &src[..len]);
- Ready(Ok(len))
+ Poll::Ready(Ok(len))
}
- Some(Ready(Ok(_))) => panic!(),
- Some(Ready(Err(e))) => Ready(Err(e)),
- Some(Pending) => Pending,
- None => Ready(Ok(0)),
+ Some(Poll::Ready(Ok(_))) => panic!(),
+ Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
+ Some(Poll::Pending) => Poll::Pending,
+ None => Poll::Ready(Ok(0)),
}
}
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.calls.pop_front() {
- Some(Ready(Ok(Op::Flush))) => Ready(Ok(())),
- Some(Ready(Ok(_))) => panic!(),
- Some(Ready(Err(e))) => Ready(Err(e)),
- Some(Pending) => Pending,
- None => Ready(Ok(())),
+ Some(Poll::Ready(Ok(Op::Flush))) => Poll::Ready(Ok(())),
+ Some(Poll::Ready(Ok(_))) => panic!(),
+ Some(Poll::Ready(Err(e))) => Poll::Ready(Err(e)),
+ Some(Poll::Pending) => Poll::Pending,
+ None => Poll::Ready(Ok(())),
}
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
- Ready(Ok(()))
+ Poll::Ready(Ok(()))
}
}
@@ -771,9 +768,9 @@ impl From<Vec<u8>> for Op {
}
fn data(bytes: &[u8]) -> Poll<io::Result<Op>> {
- Ready(Ok(bytes.into()))
+ Poll::Ready(Ok(bytes.into()))
}
fn flush() -> Poll<io::Result<Op>> {
- Ready(Ok(Flush))
+ Poll::Ready(Ok(Op::Flush))
}
diff --git a/tests/mpsc.rs b/tests/mpsc.rs
index a3c164d..74b83c2 100644
--- a/tests/mpsc.rs
+++ b/tests/mpsc.rs
@@ -28,6 +28,29 @@ async fn simple() {
}
#[tokio::test]
+async fn simple_ref() {
+ let v = vec![1, 2, 3i32];
+
+ let (send, mut recv) = channel(3);
+ let mut send = PollSender::new(send);
+
+ for vi in v.iter() {
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ send.send_item(vi).unwrap();
+ }
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_pending!(reserve.poll());
+
+ assert_eq!(*recv.recv().await.unwrap(), 1);
+ assert!(reserve.is_woken());
+ assert_ready_ok!(reserve.poll());
+ drop(recv);
+ send.send_item(&42).unwrap();
+}
+
+#[tokio::test]
async fn repeated_poll_reserve() {
let (send, mut recv) = channel::<i32>(1);
let mut send = PollSender::new(send);
diff --git a/tests/task_join_map.rs b/tests/task_join_map.rs
index cef08b2..1ab5f9b 100644
--- a/tests/task_join_map.rs
+++ b/tests/task_join_map.rs
@@ -109,6 +109,30 @@ async fn alternating() {
}
}
+#[tokio::test]
+async fn test_keys() {
+ use std::collections::HashSet;
+
+ let mut map = JoinMap::new();
+
+ assert_eq!(map.len(), 0);
+ map.spawn(1, async {});
+ assert_eq!(map.len(), 1);
+ map.spawn(2, async {});
+ assert_eq!(map.len(), 2);
+
+ let keys = map.keys().collect::<HashSet<&u32>>();
+ assert!(keys.contains(&1));
+ assert!(keys.contains(&2));
+
+ let _ = map.join_next().await.unwrap();
+ let _ = map.join_next().await.unwrap();
+
+ assert_eq!(map.len(), 0);
+ let keys = map.keys().collect::<HashSet<&u32>>();
+ assert!(keys.is_empty());
+}
+
#[tokio::test(start_paused = true)]
async fn abort_by_key() {
let mut map = JoinMap::new();
diff --git a/tests/task_tracker.rs b/tests/task_tracker.rs
new file mode 100644
index 0000000..f0eb244
--- /dev/null
+++ b/tests/task_tracker.rs
@@ -0,0 +1,178 @@
+#![warn(rust_2018_idioms)]
+
+use tokio_test::{assert_pending, assert_ready, task};
+use tokio_util::task::TaskTracker;
+
+#[test]
+fn open_close() {
+ let tracker = TaskTracker::new();
+ assert!(!tracker.is_closed());
+ assert!(tracker.is_empty());
+ assert_eq!(tracker.len(), 0);
+
+ tracker.close();
+ assert!(tracker.is_closed());
+ assert!(tracker.is_empty());
+ assert_eq!(tracker.len(), 0);
+
+ tracker.reopen();
+ assert!(!tracker.is_closed());
+ tracker.reopen();
+ assert!(!tracker.is_closed());
+
+ assert!(tracker.is_empty());
+ assert_eq!(tracker.len(), 0);
+
+ tracker.close();
+ assert!(tracker.is_closed());
+ tracker.close();
+ assert!(tracker.is_closed());
+
+ assert!(tracker.is_empty());
+ assert_eq!(tracker.len(), 0);
+}
+
+#[test]
+fn token_len() {
+ let tracker = TaskTracker::new();
+
+ let mut tokens = Vec::new();
+ for i in 0..10 {
+ assert_eq!(tracker.len(), i);
+ tokens.push(tracker.token());
+ }
+
+ assert!(!tracker.is_empty());
+ assert_eq!(tracker.len(), 10);
+
+ for (i, token) in tokens.into_iter().enumerate() {
+ drop(token);
+ assert_eq!(tracker.len(), 9 - i);
+ }
+}
+
+#[test]
+fn notify_immediately() {
+ let tracker = TaskTracker::new();
+ tracker.close();
+
+ let mut wait = task::spawn(tracker.wait());
+ assert_ready!(wait.poll());
+}
+
+#[test]
+fn notify_immediately_on_reopen() {
+ let tracker = TaskTracker::new();
+ tracker.close();
+
+ let mut wait = task::spawn(tracker.wait());
+ tracker.reopen();
+ assert_ready!(wait.poll());
+}
+
+#[test]
+fn notify_on_close() {
+ let tracker = TaskTracker::new();
+
+ let mut wait = task::spawn(tracker.wait());
+
+ assert_pending!(wait.poll());
+ tracker.close();
+ assert_ready!(wait.poll());
+}
+
+#[test]
+fn notify_on_close_reopen() {
+ let tracker = TaskTracker::new();
+
+ let mut wait = task::spawn(tracker.wait());
+
+ assert_pending!(wait.poll());
+ tracker.close();
+ tracker.reopen();
+ assert_ready!(wait.poll());
+}
+
+#[test]
+fn notify_on_last_task() {
+ let tracker = TaskTracker::new();
+ tracker.close();
+ let token = tracker.token();
+
+ let mut wait = task::spawn(tracker.wait());
+ assert_pending!(wait.poll());
+ drop(token);
+ assert_ready!(wait.poll());
+}
+
+#[test]
+fn notify_on_last_task_respawn() {
+ let tracker = TaskTracker::new();
+ tracker.close();
+ let token = tracker.token();
+
+ let mut wait = task::spawn(tracker.wait());
+ assert_pending!(wait.poll());
+ drop(token);
+ let token2 = tracker.token();
+ assert_ready!(wait.poll());
+ drop(token2);
+}
+
+#[test]
+fn no_notify_on_respawn_if_open() {
+ let tracker = TaskTracker::new();
+ let token = tracker.token();
+
+ let mut wait = task::spawn(tracker.wait());
+ assert_pending!(wait.poll());
+ drop(token);
+ let token2 = tracker.token();
+ assert_pending!(wait.poll());
+ drop(token2);
+}
+
+#[test]
+fn close_during_exit() {
+ const ITERS: usize = 5;
+
+ for close_spot in 0..=ITERS {
+ let tracker = TaskTracker::new();
+ let tokens: Vec<_> = (0..ITERS).map(|_| tracker.token()).collect();
+
+ let mut wait = task::spawn(tracker.wait());
+
+ for (i, token) in tokens.into_iter().enumerate() {
+ assert_pending!(wait.poll());
+ if i == close_spot {
+ tracker.close();
+ assert_pending!(wait.poll());
+ }
+ drop(token);
+ }
+
+ if close_spot == ITERS {
+ assert_pending!(wait.poll());
+ tracker.close();
+ }
+
+ assert_ready!(wait.poll());
+ }
+}
+
+#[test]
+fn notify_many() {
+ let tracker = TaskTracker::new();
+
+ let mut waits: Vec<_> = (0..10).map(|_| task::spawn(tracker.wait())).collect();
+
+ for wait in &mut waits {
+ assert_pending!(wait.poll());
+ }
+
+ tracker.close();
+
+ for wait in &mut waits {
+ assert_ready!(wait.poll());
+ }
+}
diff --git a/tests/time_delay_queue.rs b/tests/time_delay_queue.rs
index 9ceae34..9b7b6cc 100644
--- a/tests/time_delay_queue.rs
+++ b/tests/time_delay_queue.rs
@@ -2,6 +2,7 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
+use futures::StreamExt;
use tokio::time::{self, sleep, sleep_until, Duration, Instant};
use tokio_test::{assert_pending, assert_ready, task};
use tokio_util::time::DelayQueue;
@@ -257,6 +258,10 @@ async fn reset_twice() {
#[tokio::test]
async fn repeatedly_reset_entry_inserted_as_expired() {
time::pause();
+
+ // Instants before the start of the test seem to break in wasm.
+ time::sleep(ms(1000)).await;
+
let mut queue = task::spawn(DelayQueue::new());
let now = Instant::now();
@@ -556,6 +561,10 @@ async fn reset_later_after_slot_starts() {
#[tokio::test]
async fn reset_inserted_expired() {
time::pause();
+
+ // Instants before the start of the test seem to break in wasm.
+ time::sleep(ms(1000)).await;
+
let mut queue = task::spawn(DelayQueue::new());
let now = Instant::now();
@@ -778,6 +787,22 @@ async fn compact_change_deadline() {
assert!(entry.is_none());
}
+#[tokio::test(start_paused = true)]
+async fn item_expiry_greater_than_wheel() {
+ // This function tests that a delay queue that has existed for at least 2^36 milliseconds won't panic when a new item is inserted.
+ let mut queue = DelayQueue::new();
+ for _ in 0..2 {
+ tokio::time::advance(Duration::from_millis(1 << 35)).await;
+ queue.insert(0, Duration::from_millis(0));
+ queue.next().await;
+ }
+ // This should not panic
+ let no_panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+ queue.insert(1, Duration::from_millis(1));
+ }));
+ assert!(no_panic.is_ok());
+}
+
#[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")]
#[tokio::test(start_paused = true)]
async fn remove_after_compact() {
@@ -815,6 +840,44 @@ async fn remove_after_compact_poll() {
assert!(panic.is_err());
}
+#[tokio::test(start_paused = true)]
+async fn peek() {
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let key = queue.insert_at("foo", now + ms(5));
+ let key2 = queue.insert_at("bar", now);
+ let key3 = queue.insert_at("baz", now + ms(10));
+
+ assert_eq!(queue.peek(), Some(key2));
+
+ sleep(ms(6)).await;
+
+ assert_eq!(queue.peek(), Some(key2));
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(entry.get_ref(), &"bar");
+
+ assert_eq!(queue.peek(), Some(key));
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(entry.get_ref(), &"foo");
+
+ assert_eq!(queue.peek(), Some(key3));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(5)).await;
+
+ assert_eq!(queue.peek(), Some(key3));
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(entry.get_ref(), &"baz");
+
+ assert!(queue.peek().is_none());
+}
+
fn ms(n: u64) -> Duration {
Duration::from_millis(n)
}
diff --git a/tests/udp.rs b/tests/udp.rs
index 1b99806..db726a3 100644
--- a/tests/udp.rs
+++ b/tests/udp.rs
@@ -13,7 +13,10 @@ use futures::sink::SinkExt;
use std::io;
use std::sync::Arc;
-#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))]
+#[cfg_attr(
+ any(target_os = "macos", target_os = "ios", target_os = "tvos"),
+ allow(unused_assignments)
+)]
#[tokio::test]
async fn send_framed_byte_codec() -> std::io::Result<()> {
let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?;
@@ -41,7 +44,7 @@ async fn send_framed_byte_codec() -> std::io::Result<()> {
b_soc = b.into_inner();
}
- #[cfg(not(any(target_os = "macos", target_os = "ios")))]
+ #[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "tvos")))]
// test sending & receiving an empty message
{
let mut a = UdpFramed::new(a_soc, ByteCodec);