author    Fedor Tsarev <ftsarev@google.com>  2023-03-27 08:58:56 +0000
committer Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>  2023-03-27 08:58:56 +0000
commit    3bfdc11718fc001a7fa1d300b679c089379eac8c (patch)
tree      795e44577cac958437be45a482f418434173be49
parent    8d7c957c7242d7b0076a69d3ea9bbee4a285c79f (diff)
parent    ff8ca542a4c1702a894d16bf7f13cdee71b67536 (diff)
download  tokio-util-3bfdc11718fc001a7fa1d300b679c089379eac8c.tar.gz
Import tokio-util crate am: b3b4e5b0f0 am: 0145148b25 am: 48663d53f9 am: 4fb43fe351 am: ff8ca542a4
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/tokio-util/+/2506418

Change-Id: If2575f0c48049ad556a4c07cb455b3b18613b798
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--  CHANGELOG.md  334
-rw-r--r--  Cargo.toml  131
-rw-r--r--  Cargo.toml.orig  66
-rw-r--r--  LICENSE  25
-rw-r--r--  METADATA  19
-rw-r--r--  MODULE_LICENSE_MIT  0
-rw-r--r--  OWNERS  1
-rw-r--r--  README.md  13
-rw-r--r--  src/cfg.rs  71
-rw-r--r--  src/codec/any_delimiter_codec.rs  263
-rw-r--r--  src/codec/bytes_codec.rs  86
-rw-r--r--  src/codec/decoder.rs  184
-rw-r--r--  src/codec/encoder.rs  25
-rw-r--r--  src/codec/framed.rs  383
-rw-r--r--  src/codec/framed_impl.rs  312
-rw-r--r--  src/codec/framed_read.rs  199
-rw-r--r--  src/codec/framed_write.rs  188
-rw-r--r--  src/codec/length_delimited.rs  1043
-rw-r--r--  src/codec/lines_codec.rs  230
-rw-r--r--  src/codec/mod.rs  290
-rw-r--r--  src/compat.rs  274
-rw-r--r--  src/context.rs  190
-rw-r--r--  src/either.rs  188
-rw-r--r--  src/io/copy_to_bytes.rs  68
-rw-r--r--  src/io/inspect.rs  134
-rw-r--r--  src/io/mod.rs  31
-rw-r--r--  src/io/read_buf.rs  65
-rw-r--r--  src/io/reader_stream.rs  118
-rw-r--r--  src/io/sink_writer.rs  117
-rw-r--r--  src/io/stream_reader.rs  326
-rw-r--r--  src/io/sync_bridge.rs  143
-rw-r--r--  src/lib.rs  205
-rw-r--r--  src/loom.rs  1
-rw-r--r--  src/net/mod.rs  97
-rw-r--r--  src/net/unix/mod.rs  18
-rw-r--r--  src/sync/cancellation_token.rs  324
-rw-r--r--  src/sync/cancellation_token/guard.rs  27
-rw-r--r--  src/sync/cancellation_token/tree_node.rs  373
-rw-r--r--  src/sync/mod.rs  15
-rw-r--r--  src/sync/mpsc.rs  284
-rw-r--r--  src/sync/poll_semaphore.rs  171
-rw-r--r--  src/sync/reusable_box.rs  171
-rw-r--r--  src/sync/tests/loom_cancellation_token.rs  176
-rw-r--r--  src/sync/tests/mod.rs  1
-rw-r--r--  src/task/join_map.rs  807
-rw-r--r--  src/task/mod.rs  12
-rw-r--r--  src/task/spawn_pinned.rs  436
-rw-r--r--  src/time/delay_queue.rs  1271
-rw-r--r--  src/time/mod.rs  47
-rw-r--r--  src/time/wheel/level.rs  253
-rw-r--r--  src/time/wheel/mod.rs  315
-rw-r--r--  src/time/wheel/stack.rs  28
-rw-r--r--  src/udp/frame.rs  249
-rw-r--r--  src/udp/mod.rs  4
-rw-r--r--  tests/_require_full.rs  2
-rw-r--r--  tests/codecs.rs  464
-rw-r--r--  tests/context.rs  25
-rw-r--r--  tests/framed.rs  152
-rw-r--r--  tests/framed_read.rs  339
-rw-r--r--  tests/framed_stream.rs  38
-rw-r--r--  tests/framed_write.rs  212
-rw-r--r--  tests/io_inspect.rs  194
-rw-r--r--  tests/io_reader_stream.rs  65
-rw-r--r--  tests/io_sink_writer.rs  72
-rw-r--r--  tests/io_stream_reader.rs  35
-rw-r--r--  tests/io_sync_bridge.rs  62
-rw-r--r--  tests/length_delimited.rs  779
-rw-r--r--  tests/mpsc.rs  239
-rw-r--r--  tests/panic.rs  226
-rw-r--r--  tests/poll_semaphore.rs  84
-rw-r--r--  tests/reusable_box.rs  85
-rw-r--r--  tests/spawn_pinned.rs  237
-rw-r--r--  tests/sync_cancellation_token.rs  447
-rw-r--r--  tests/task_join_map.rs  275
-rw-r--r--  tests/time_delay_queue.rs  820
-rw-r--r--  tests/udp.rs  133
76 files changed, 15787 insertions, 0 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..0c11b21
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,334 @@
+# 0.7.7 (February 12, 2023)
+
+This release reverts the removal of the `Encoder` bound on the `FramedParts`
+constructor from [#5280] since it turned out to be a breaking change. ([#5450])
+
+[#5450]: https://github.com/tokio-rs/tokio/pull/5450
+
+# 0.7.6 (February 10, 2023)
+
+This release fixes a compilation failure in 0.7.5 when it is used together with
+Tokio version 1.21 and unstable features are enabled. ([#5445])
+
+[#5445]: https://github.com/tokio-rs/tokio/pull/5445
+
+# 0.7.5 (February 9, 2023)
+
+This release fixes an accidental breaking change where `UnwindSafe` was
+removed from `CancellationToken`.
+
+### Added
+- codec: add `Framed::backpressure_boundary` ([#5124])
+- io: add `InspectReader` and `InspectWriter` ([#5033])
+- io: add `tokio_util::io::{CopyToBytes, SinkWriter}` ([#5070], [#5436])
+- io: impl `std::io::BufRead` on `SyncIoBridge` ([#5265])
+- sync: add `PollSemaphore::poll_acquire_many` ([#5137])
+- sync: add owned future for `CancellationToken` ([#5153])
+- time: add `DelayQueue::try_remove` ([#5052])
+
+### Fixed
+- codec: fix `LengthDelimitedCodec` buffer over-reservation ([#4997])
+- sync: impl `UnwindSafe` on `CancellationToken` ([#5438])
+- util: remove `Encoder` bound on `FramedParts` constructor ([#5280])
+
+### Documented
+- io: add lines example for `StreamReader` ([#5145])
+
+[#4997]: https://github.com/tokio-rs/tokio/pull/4997
+[#5033]: https://github.com/tokio-rs/tokio/pull/5033
+[#5052]: https://github.com/tokio-rs/tokio/pull/5052
+[#5070]: https://github.com/tokio-rs/tokio/pull/5070
+[#5124]: https://github.com/tokio-rs/tokio/pull/5124
+[#5137]: https://github.com/tokio-rs/tokio/pull/5137
+[#5145]: https://github.com/tokio-rs/tokio/pull/5145
+[#5153]: https://github.com/tokio-rs/tokio/pull/5153
+[#5265]: https://github.com/tokio-rs/tokio/pull/5265
+[#5280]: https://github.com/tokio-rs/tokio/pull/5280
+[#5436]: https://github.com/tokio-rs/tokio/pull/5436
+[#5438]: https://github.com/tokio-rs/tokio/pull/5438
+
+# 0.7.4 (September 8, 2022)
+
+### Added
+
+- io: add `SyncIoBridge::shutdown()` ([#4938])
+- task: improve `LocalPoolHandle` ([#4680])
+
+### Fixed
+
+- util: add `track_caller` to public APIs ([#4785])
+
+### Unstable
+
+- task: fix compilation errors in `JoinMap` with Tokio v1.21.0 ([#4755])
+- task: remove the unstable, deprecated `JoinMap::join_one` ([#4920])
+
+[#4680]: https://github.com/tokio-rs/tokio/pull/4680
+[#4755]: https://github.com/tokio-rs/tokio/pull/4755
+[#4785]: https://github.com/tokio-rs/tokio/pull/4785
+[#4920]: https://github.com/tokio-rs/tokio/pull/4920
+[#4938]: https://github.com/tokio-rs/tokio/pull/4938
+
+# 0.7.3 (June 4, 2022)
+
+### Changed
+
+- tracing: don't require default tracing features ([#4592])
+- util: simplify implementation of `ReusableBoxFuture` ([#4675])
+
+### Added (unstable)
+
+- task: add `JoinMap` ([#4640], [#4697])
+
+[#4592]: https://github.com/tokio-rs/tokio/pull/4592
+[#4640]: https://github.com/tokio-rs/tokio/pull/4640
+[#4675]: https://github.com/tokio-rs/tokio/pull/4675
+[#4697]: https://github.com/tokio-rs/tokio/pull/4697
+
+# 0.7.2 (May 14, 2022)
+
+This release contains a rewrite of `CancellationToken` that fixes a memory leak. ([#4652])
+
+[#4652]: https://github.com/tokio-rs/tokio/pull/4652
+
+# 0.7.1 (February 21, 2022)
+
+### Added
+
+- codec: add `length_field_type` to `LengthDelimitedCodec` builder ([#4508])
+- io: add `StreamReader::into_inner_with_chunk()` ([#4559])
+
+### Changed
+
+- switch from log to tracing ([#4539])
+
+### Fixed
+
+- sync: fix waker update condition in `CancellationToken` ([#4497])
+- bumped tokio dependency to 1.6 to satisfy minimum requirements ([#4490])
+
+[#4490]: https://github.com/tokio-rs/tokio/pull/4490
+[#4497]: https://github.com/tokio-rs/tokio/pull/4497
+[#4508]: https://github.com/tokio-rs/tokio/pull/4508
+[#4539]: https://github.com/tokio-rs/tokio/pull/4539
+[#4559]: https://github.com/tokio-rs/tokio/pull/4559
+
+# 0.7.0 (February 9, 2022)
+
+### Added
+
+- task: add `spawn_pinned` ([#3370])
+- time: add `shrink_to_fit` and `compact` methods to `DelayQueue` ([#4170])
+- codec: improve `Builder::max_frame_length` docs ([#4352])
+- codec: add mutable reference getters for codecs to pinned `Framed` ([#4372])
+- net: add generic trait to combine `UnixListener` and `TcpListener` ([#4385])
+- codec: implement `Framed::map_codec` ([#4427])
+- codec: implement `Encoder<BytesMut>` for `BytesCodec` ([#4465])
+
+### Changed
+
+- sync: add lifetime parameter to `ReusableBoxFuture` ([#3762])
+- sync: refactored `PollSender<T>` to fix a subtly broken `Sink<T>` implementation ([#4214])
+- time: remove error case from the infallible `DelayQueue::poll_elapsed` ([#4241])
+
+[#3370]: https://github.com/tokio-rs/tokio/pull/3370
+[#4170]: https://github.com/tokio-rs/tokio/pull/4170
+[#4352]: https://github.com/tokio-rs/tokio/pull/4352
+[#4372]: https://github.com/tokio-rs/tokio/pull/4372
+[#4385]: https://github.com/tokio-rs/tokio/pull/4385
+[#4427]: https://github.com/tokio-rs/tokio/pull/4427
+[#4465]: https://github.com/tokio-rs/tokio/pull/4465
+[#3762]: https://github.com/tokio-rs/tokio/pull/3762
+[#4214]: https://github.com/tokio-rs/tokio/pull/4214
+[#4241]: https://github.com/tokio-rs/tokio/pull/4241
+
+# 0.6.10 (May 14, 2021)
+
+This release backports the fix for the `CancellationToken` memory leak that was originally fixed in 0.7.2. ([#4652])
+
+[#4652]: https://github.com/tokio-rs/tokio/pull/4652
+
+# 0.6.9 (October 29, 2021)
+
+### Added
+
+- codec: implement `Clone` for `LengthDelimitedCodec` ([#4089])
+- io: add `SyncIoBridge` ([#4146])
+
+### Fixed
+
+- time: update deadline on removal in `DelayQueue` ([#4178])
+- codec: update `Stream` impl for `Framed` to return `None` after an error ([#4166])
+
+[#4089]: https://github.com/tokio-rs/tokio/pull/4089
+[#4146]: https://github.com/tokio-rs/tokio/pull/4146
+[#4166]: https://github.com/tokio-rs/tokio/pull/4166
+[#4178]: https://github.com/tokio-rs/tokio/pull/4178
+
+# 0.6.8 (September 3, 2021)
+
+### Added
+
+- sync: add drop guard for `CancellationToken` ([#3839])
+- compat: added `AsyncSeek` compat ([#4078])
+- time: expose `Key` used in `DelayQueue`'s `Expired` ([#4081])
+- io: add `with_capacity` to `ReaderStream` ([#4086])
+
+### Fixed
+
+- codec: remove unnecessary `doc(cfg(...))` ([#3989])
+
+[#3839]: https://github.com/tokio-rs/tokio/pull/3839
+[#4078]: https://github.com/tokio-rs/tokio/pull/4078
+[#4081]: https://github.com/tokio-rs/tokio/pull/4081
+[#4086]: https://github.com/tokio-rs/tokio/pull/4086
+[#3989]: https://github.com/tokio-rs/tokio/pull/3989
+
+# 0.6.7 (May 14, 2021)
+
+### Added
+
+- udp: make `UdpFramed` take `Borrow<UdpSocket>` ([#3451])
+- compat: implement `AsRawFd`/`AsRawHandle` for `Compat<T>` ([#3765])
+
+[#3451]: https://github.com/tokio-rs/tokio/pull/3451
+[#3765]: https://github.com/tokio-rs/tokio/pull/3765
+
+# 0.6.6 (April 12, 2021)
+
+### Added
+
+- util: make `Framed` and `FramedStream` resumable after EOF ([#3272])
+- util: add `PollSemaphore::{add_permits, available_permits}` ([#3683])
+
+### Fixed
+
+- chore: avoid allocation if `PollSemaphore` is unused ([#3634])
+
+[#3272]: https://github.com/tokio-rs/tokio/pull/3272
+[#3634]: https://github.com/tokio-rs/tokio/pull/3634
+[#3683]: https://github.com/tokio-rs/tokio/pull/3683
+
+# 0.6.5 (March 20, 2021)
+
+### Fixed
+
+- util: annotate time module as requiring `time` feature ([#3606])
+
+[#3606]: https://github.com/tokio-rs/tokio/pull/3606
+
+# 0.6.4 (March 9, 2021)
+
+### Added
+
+- codec: `AnyDelimiter` codec ([#3406])
+- sync: add pollable `mpsc::Sender` ([#3490])
+
+### Fixed
+
+- codec: `LinesCodec` should only return `MaxLineLengthExceeded` once per line ([#3556])
+- sync: fuse `PollSemaphore` ([#3578])
+
+[#3406]: https://github.com/tokio-rs/tokio/pull/3406
+[#3490]: https://github.com/tokio-rs/tokio/pull/3490
+[#3556]: https://github.com/tokio-rs/tokio/pull/3556
+[#3578]: https://github.com/tokio-rs/tokio/pull/3578
+
+# 0.6.3 (January 31, 2021)
+
+### Added
+
+- sync: add `ReusableBoxFuture` utility ([#3464])
+
+### Changed
+
+- sync: use `ReusableBoxFuture` for `PollSemaphore` ([#3463])
+- deps: remove `async-stream` dependency ([#3463])
+- deps: remove `tokio-stream` dependency ([#3487])
+
+# 0.6.2 (January 21, 2021)
+
+### Added
+
+- sync: add pollable `Semaphore` ([#3444])
+
+### Fixed
+
+- time: fix panics on updating `DelayQueue` entries ([#3270])
+
+# 0.6.1 (January 12, 2021)
+
+### Added
+
+- codec: `get_ref()`, `get_mut()`, `get_pin_mut()` and `into_inner()` for
+ `Framed`, `FramedRead`, `FramedWrite` and `StreamReader` ([#3364]).
+- codec: `write_buffer()` and `write_buffer_mut()` for `Framed` and
+ `FramedWrite` ([#3387]).
+
+# 0.6.0 (December 23, 2020)
+
+### Changed
+- depend on `tokio` 1.0.
+
+### Added
+- rt: add constructors to `TokioContext` (#3221).
+
+# 0.5.1 (December 3, 2020)
+
+### Added
+- io: `poll_read_buf` util fn (#2972).
+- io: `poll_write_buf` util fn with vectored write support (#3156).
+
+# 0.5.0 (October 30, 2020)
+
+### Changed
+- io: update `bytes` to 0.6 (#3071).
+
+# 0.4.0 (October 15, 2020)
+
+### Added
+- sync: `CancellationToken` for coordinating task cancellation (#2747).
+- rt: `TokioContext` sets the Tokio runtime for the duration of a future (#2791)
+- io: `StreamReader`/`ReaderStream` map between `AsyncRead` values and `Stream`
+ of bytes (#2788).
+- time: `DelayQueue` to manage many delays (#2897).
+
+# 0.3.1 (March 18, 2020)
+
+### Fixed
+
+- Adjust minimum-supported Tokio version to v0.2.5 to account for an internal
+ dependency on features in that version of Tokio. ([#2326])
+
+# 0.3.0 (March 4, 2020)
+
+### Changed
+
+- **Breaking Change**: Change `Encoder` trait to take a generic `Item` parameter, which allows
+ codec writers to pass references into `Framed` and `FramedWrite` types. ([#1746])
+
+### Added
+
+- Add futures-io/tokio::io compatibility layer. ([#2117])
+- Add `Framed::with_capacity`. ([#2215])
+
+### Fixed
+
+- Use advance over split_to when data is not needed. ([#2198])
+
+# 0.2.0 (November 26, 2019)
+
+- Initial release
+
+[#3487]: https://github.com/tokio-rs/tokio/pull/3487
+[#3464]: https://github.com/tokio-rs/tokio/pull/3464
+[#3463]: https://github.com/tokio-rs/tokio/pull/3463
+[#3444]: https://github.com/tokio-rs/tokio/pull/3444
+[#3387]: https://github.com/tokio-rs/tokio/pull/3387
+[#3364]: https://github.com/tokio-rs/tokio/pull/3364
+[#3270]: https://github.com/tokio-rs/tokio/pull/3270
+[#2326]: https://github.com/tokio-rs/tokio/pull/2326
+[#2215]: https://github.com/tokio-rs/tokio/pull/2215
+[#2198]: https://github.com/tokio-rs/tokio/pull/2198
+[#2117]: https://github.com/tokio-rs/tokio/pull/2117
+[#1746]: https://github.com/tokio-rs/tokio/pull/1746
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..221f2e7
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,131 @@
+# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
+#
+# When uploading crates to the registry Cargo will automatically
+# "normalize" Cargo.toml files for maximal compatibility
+# with all versions of Cargo and also rewrite `path` dependencies
+# to registry (e.g., crates.io) dependencies.
+#
+# If you are reading this file be aware that the original Cargo.toml
+# will likely look very different (and much more reasonable).
+# See Cargo.toml.orig for the original contents.
+
+[package]
+edition = "2018"
+rust-version = "1.49"
+name = "tokio-util"
+version = "0.7.7"
+authors = ["Tokio Contributors <team@tokio.rs>"]
+description = """
+Additional utilities for working with Tokio.
+"""
+homepage = "https://tokio.rs"
+readme = "README.md"
+categories = ["asynchronous"]
+license = "MIT"
+repository = "https://github.com/tokio-rs/tokio"
+
+[package.metadata.docs.rs]
+all-features = true
+rustdoc-args = [
+ "--cfg",
+ "docsrs",
+ "--cfg",
+ "tokio_unstable",
+]
+rustc-args = [
+ "--cfg",
+ "docsrs",
+ "--cfg",
+ "tokio_unstable",
+]
+
+[dependencies.bytes]
+version = "1.0.0"
+
+[dependencies.futures-core]
+version = "0.3.0"
+
+[dependencies.futures-io]
+version = "0.3.0"
+optional = true
+
+[dependencies.futures-sink]
+version = "0.3.0"
+
+[dependencies.futures-util]
+version = "0.3.0"
+optional = true
+
+[dependencies.pin-project-lite]
+version = "0.2.0"
+
+[dependencies.slab]
+version = "0.4.4"
+optional = true
+
+[dependencies.tokio]
+version = "1.22.0"
+features = ["sync"]
+
+[dependencies.tracing]
+version = "0.1.25"
+features = ["std"]
+optional = true
+default-features = false
+
+[dev-dependencies.async-stream]
+version = "0.3.0"
+
+[dev-dependencies.futures]
+version = "0.3.0"
+
+[dev-dependencies.futures-test]
+version = "0.3.5"
+
+[dev-dependencies.parking_lot]
+version = "0.12.0"
+
+[dev-dependencies.tokio]
+version = "1.0.0"
+features = ["full"]
+
+[dev-dependencies.tokio-stream]
+version = "0.1"
+
+[dev-dependencies.tokio-test]
+version = "0.4.0"
+
+[features]
+__docs_rs = ["futures-util"]
+codec = ["tracing"]
+compat = ["futures-io"]
+default = []
+full = [
+ "codec",
+ "compat",
+ "io-util",
+ "time",
+ "net",
+ "rt",
+]
+io = []
+io-util = [
+ "io",
+ "tokio/rt",
+ "tokio/io-util",
+]
+net = ["tokio/net"]
+rt = [
+ "tokio/rt",
+ "tokio/sync",
+ "futures-util",
+ "hashbrown",
+]
+time = [
+ "tokio/time",
+ "slab",
+]
+
+[target."cfg(tokio_unstable)".dependencies.hashbrown]
+version = "0.12.0"
+optional = true
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
new file mode 100644
index 0000000..267662b
--- /dev/null
+++ b/Cargo.toml.orig
@@ -0,0 +1,66 @@
+[package]
+name = "tokio-util"
+# When releasing to crates.io:
+# - Remove path dependencies
+# - Update CHANGELOG.md.
+# - Create "tokio-util-0.7.x" git tag.
+version = "0.7.7"
+edition = "2018"
+rust-version = "1.49"
+authors = ["Tokio Contributors <team@tokio.rs>"]
+license = "MIT"
+repository = "https://github.com/tokio-rs/tokio"
+homepage = "https://tokio.rs"
+description = """
+Additional utilities for working with Tokio.
+"""
+categories = ["asynchronous"]
+
+[features]
+# No features on by default
+default = []
+
+# Shorthand for enabling everything
+full = ["codec", "compat", "io-util", "time", "net", "rt"]
+
+net = ["tokio/net"]
+compat = ["futures-io",]
+codec = ["tracing"]
+time = ["tokio/time","slab"]
+io = []
+io-util = ["io", "tokio/rt", "tokio/io-util"]
+rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"]
+
+__docs_rs = ["futures-util"]
+
+[dependencies]
+tokio = { version = "1.22.0", path = "../tokio", features = ["sync"] }
+bytes = "1.0.0"
+futures-core = "0.3.0"
+futures-sink = "0.3.0"
+futures-io = { version = "0.3.0", optional = true }
+futures-util = { version = "0.3.0", optional = true }
+pin-project-lite = "0.2.0"
+slab = { version = "0.4.4", optional = true } # Backs `DelayQueue`
+tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true }
+
+[target.'cfg(tokio_unstable)'.dependencies]
+hashbrown = { version = "0.12.0", optional = true }
+
+[dev-dependencies]
+tokio = { version = "1.0.0", path = "../tokio", features = ["full"] }
+tokio-test = { version = "0.4.0", path = "../tokio-test" }
+tokio-stream = { version = "0.1", path = "../tokio-stream" }
+
+async-stream = "0.3.0"
+futures = "0.3.0"
+futures-test = "0.3.5"
+parking_lot = "0.12.0"
+
+[package.metadata.docs.rs]
+all-features = true
+# enable unstable features in the documentation
+rustdoc-args = ["--cfg", "docsrs", "--cfg", "tokio_unstable"]
+# it's necessary to _also_ pass `--cfg tokio_unstable` to rustc, or else
+# dependencies will not be enabled, and the docs build will fail.
+rustc-args = ["--cfg", "docsrs", "--cfg", "tokio_unstable"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..8bdf6bd
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,25 @@
+Copyright (c) 2023 Tokio Contributors
+
+Permission is hereby granted, free of charge, to any
+person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the
+Software without restriction, including without
+limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software
+is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice
+shall be included in all copies or substantial portions
+of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/METADATA b/METADATA
new file mode 100644
index 0000000..28dbb73
--- /dev/null
+++ b/METADATA
@@ -0,0 +1,19 @@
+name: "tokio-util"
+description: "Utilities for working with Tokio."
+third_party {
+ url {
+ type: HOMEPAGE
+ value: "https://crates.io/crates/tokio-util"
+ }
+ url {
+ type: ARCHIVE
+ value: "https://static.crates.io/crates/tokio-util/tokio-util-0.7.7.crate"
+ }
+ version: "0.7.7"
+ license_type: NOTICE
+ last_upgrade_date {
+ year: 2023
+ month: 3
+ day: 3
+ }
+}
diff --git a/MODULE_LICENSE_MIT b/MODULE_LICENSE_MIT
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/MODULE_LICENSE_MIT
diff --git a/OWNERS b/OWNERS
new file mode 100644
index 0000000..45dc4dd
--- /dev/null
+++ b/OWNERS
@@ -0,0 +1 @@
+include platform/prebuilts/rust:master:/OWNERS
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0d74f36
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+# tokio-util
+
+Utilities for working with Tokio.
+
+## License
+
+This project is licensed under the [MIT license](LICENSE).
+
+### Contribution
+
+Unless you explicitly state otherwise, any contribution intentionally submitted
+for inclusion in Tokio by you, shall be licensed as MIT, without any additional
+terms or conditions.
diff --git a/src/cfg.rs b/src/cfg.rs
new file mode 100644
index 0000000..4035255
--- /dev/null
+++ b/src/cfg.rs
@@ -0,0 +1,71 @@
+macro_rules! cfg_codec {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "codec")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "codec")))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_compat {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "compat")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "compat")))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_net {
+ ($($item:item)*) => {
+ $(
+ #[cfg(all(feature = "net", feature = "codec"))]
+ #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_io {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "io")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "io")))]
+ $item
+ )*
+ }
+}
+
+cfg_io! {
+ macro_rules! cfg_io_util {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "io-util")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
+ $item
+ )*
+ }
+ }
+}
+
+macro_rules! cfg_rt {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_time {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "time")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
+ $item
+ )*
+ }
+}
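
Each of these macros gates the items it wraps behind a Cargo feature and, on docs.rs builds, attaches a `doc(cfg(...))` annotation so the rendered docs show which feature is required. A minimal sketch of a call site (the crate's `src/lib.rs`, not shown in this section, wraps its modules the same way):

```
// Expands to `#[cfg(feature = "codec")] pub mod codec;` plus the
// docs.rs-only `doc(cfg(feature = "codec"))` attribute.
cfg_codec! {
    pub mod codec;
}

cfg_time! {
    pub mod time;
}
```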
diff --git a/src/codec/any_delimiter_codec.rs b/src/codec/any_delimiter_codec.rs
new file mode 100644
index 0000000..3dbfd45
--- /dev/null
+++ b/src/codec/any_delimiter_codec.rs
@@ -0,0 +1,263 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use std::{cmp, fmt, io, str, usize};
+
+const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r";
+const DEFAULT_SEQUENCE_WRITER: &[u8] = b",";
+/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
+///
+/// # Example
+/// Decode a string of bytes containing several different delimiters.
+///
+/// [`BytesMut`]: bytes::BytesMut
+/// [`Error`]: std::io::Error
+///
+/// ```
+/// use tokio_util::codec::{AnyDelimiterCodec, Decoder};
+/// use bytes::{BufMut, BytesMut};
+///
+/// #
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> Result<(), std::io::Error> {
+/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec());
+/// let buf = &mut BytesMut::new();
+/// buf.reserve(200);
+/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r");
+/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!("", codec.decode(buf).unwrap().unwrap());
+/// assert_eq!(None, codec.decode(buf).unwrap());
+/// # Ok(())
+/// # }
+/// ```
+///
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct AnyDelimiterCodec {
+ // Stored index of the next index to examine for the delimiter character.
+ // This is used to optimize searching.
+ // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`,
+ // because that is the next index to examine.
+ // The next time `decode` is called with `abcde}`, the method will
+ // only look at `de}` before returning.
+ next_index: usize,
+
+ /// The maximum length for a given chunk. If `usize::MAX`, chunks will be
+ /// read until a delimiter character is reached.
+ max_length: usize,
+
+ /// Are we currently discarding the remainder of a chunk which was over
+ /// the length limit?
+ is_discarding: bool,
+
+    /// The bytes used to search for a chunk boundary during decoding.
+ seek_delimiters: Vec<u8>,
+
+    /// The bytes written out as the chunk separator during encoding.
+ sequence_writer: Vec<u8>,
+}
+
+impl AnyDelimiterCodec {
+    /// Returns an `AnyDelimiterCodec` for splitting up data into chunks.
+ ///
+ /// # Note
+ ///
+ /// The returned `AnyDelimiterCodec` will not have an upper bound on the length
+ /// of a buffered chunk. See the documentation for [`new_with_max_length`]
+ /// for information on why this could be a potential security risk.
+ ///
+ /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length()
+ pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec {
+ AnyDelimiterCodec {
+ next_index: 0,
+ max_length: usize::MAX,
+ is_discarding: false,
+ seek_delimiters,
+ sequence_writer,
+ }
+ }
+
+    /// Returns an `AnyDelimiterCodec` with a maximum chunk length limit.
+ ///
+    /// If this is set, calls to `AnyDelimiterCodec::decode` will return an
+    /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls
+    /// will discard up to `limit` bytes from that chunk until a delimiter
+    /// character is reached, returning `None` until the chunk over the limit
+ /// has been fully discarded. After that point, calls to `decode` will
+ /// function as normal.
+ ///
+ /// # Note
+ ///
+ /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which
+ /// will be exposed to untrusted input. Otherwise, the size of the buffer
+ /// that holds the chunk currently being read is unbounded. An attacker could
+ /// exploit this unbounded buffer by sending an unbounded amount of input
+ /// without any delimiter characters, causing unbounded memory consumption.
+ ///
+ /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError
+ pub fn new_with_max_length(
+ seek_delimiters: Vec<u8>,
+ sequence_writer: Vec<u8>,
+ max_length: usize,
+ ) -> Self {
+ AnyDelimiterCodec {
+ max_length,
+ ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer)
+ }
+ }
+
+ /// Returns the maximum chunk length when decoding.
+ ///
+ /// ```
+ /// use std::usize;
+ /// use tokio_util::codec::AnyDelimiterCodec;
+ ///
+ /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec());
+ /// assert_eq!(codec.max_length(), usize::MAX);
+ /// ```
+ /// ```
+ /// use tokio_util::codec::AnyDelimiterCodec;
+ ///
+ /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256);
+ /// assert_eq!(codec.max_length(), 256);
+ /// ```
+ pub fn max_length(&self) -> usize {
+ self.max_length
+ }
+}
+
+impl Decoder for AnyDelimiterCodec {
+ type Item = Bytes;
+ type Error = AnyDelimiterCodecError;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
+ loop {
+ // Determine how far into the buffer we'll search for a delimiter. If
+ // there's no max_length set, we'll read to the end of the buffer.
+ let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
+
+ let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| {
+ self.seek_delimiters
+ .iter()
+ .any(|delimiter| *b == *delimiter)
+ });
+
+ match (self.is_discarding, new_chunk_offset) {
+ (true, Some(offset)) => {
+ // If we found a new chunk, discard up to that offset and
+ // then stop discarding. On the next iteration, we'll try
+ // to read a chunk normally.
+ buf.advance(offset + self.next_index + 1);
+ self.is_discarding = false;
+ self.next_index = 0;
+ }
+ (true, None) => {
+ // Otherwise, we didn't find a new chunk, so we'll discard
+ // everything we read. On the next iteration, we'll continue
+ // discarding up to max_len bytes unless we find a new chunk.
+ buf.advance(read_to);
+ self.next_index = 0;
+ if buf.is_empty() {
+ return Ok(None);
+ }
+ }
+ (false, Some(offset)) => {
+ // Found a chunk!
+ let new_chunk_index = offset + self.next_index;
+ self.next_index = 0;
+ let mut chunk = buf.split_to(new_chunk_index + 1);
+ chunk.truncate(chunk.len() - 1);
+ let chunk = chunk.freeze();
+ return Ok(Some(chunk));
+ }
+ (false, None) if buf.len() > self.max_length => {
+ // Reached the maximum length without finding a
+ // new chunk, return an error and start discarding on the
+ // next call.
+ self.is_discarding = true;
+ return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded);
+ }
+ (false, None) => {
+ // We didn't find a chunk or reach the length limit, so the next
+ // call will resume searching at the current offset.
+ self.next_index = read_to;
+ return Ok(None);
+ }
+ }
+ }
+ }
+
+ fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
+ Ok(match self.decode(buf)? {
+ Some(frame) => Some(frame),
+ None => {
+ // return remaining data, if any
+ if buf.is_empty() {
+ None
+ } else {
+ let chunk = buf.split_to(buf.len());
+ self.next_index = 0;
+ Some(chunk.freeze())
+ }
+ }
+ })
+ }
+}
+
+impl<T> Encoder<T> for AnyDelimiterCodec
+where
+ T: AsRef<str>,
+{
+ type Error = AnyDelimiterCodecError;
+
+ fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> {
+ let chunk = chunk.as_ref();
+ buf.reserve(chunk.len() + 1);
+ buf.put(chunk.as_bytes());
+ buf.put(self.sequence_writer.as_ref());
+
+ Ok(())
+ }
+}
+
+impl Default for AnyDelimiterCodec {
+ fn default() -> Self {
+ Self::new(
+ DEFAULT_SEEK_DELIMITERS.to_vec(),
+ DEFAULT_SEQUENCE_WRITER.to_vec(),
+ )
+ }
+}
+
+/// An error occurred while encoding or decoding a chunk.
+#[derive(Debug)]
+pub enum AnyDelimiterCodecError {
+ /// The maximum chunk length was exceeded.
+ MaxChunkLengthExceeded,
+ /// An IO error occurred.
+ Io(io::Error),
+}
+
+impl fmt::Display for AnyDelimiterCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ AnyDelimiterCodecError::MaxChunkLengthExceeded => {
+ write!(f, "max chunk length exceeded")
+ }
+ AnyDelimiterCodecError::Io(e) => write!(f, "{}", e),
+ }
+ }
+}
+
+impl From<io::Error> for AnyDelimiterCodecError {
+ fn from(e: io::Error) -> AnyDelimiterCodecError {
+ AnyDelimiterCodecError::Io(e)
+ }
+}
+
+impl std::error::Error for AnyDelimiterCodecError {}
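
The discard behavior documented on `new_with_max_length` is easier to see in motion. A quick sketch (illustrative only, not part of this import): an over-length chunk yields `MaxChunkLengthExceeded`, the codec silently discards up to the next delimiter, and decoding then resumes.

```
use bytes::BytesMut;
use tokio_util::codec::{AnyDelimiterCodec, AnyDelimiterCodecError, Decoder};

fn main() {
    // Chunks longer than 5 bytes exceed the limit.
    let mut codec = AnyDelimiterCodec::new_with_max_length(b",".to_vec(), b",".to_vec(), 5);
    let buf = &mut BytesMut::from(&b"toolongchunk,ok,"[..]);

    // The over-long chunk produces an error and flips the codec into
    // discard mode.
    assert!(matches!(
        codec.decode(buf),
        Err(AnyDelimiterCodecError::MaxChunkLengthExceeded)
    ));

    // The offending chunk is discarded up to its trailing delimiter, and
    // decoding picks back up with the next chunk.
    assert_eq!("ok", codec.decode(buf).unwrap().unwrap());
}
```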
diff --git a/src/codec/bytes_codec.rs b/src/codec/bytes_codec.rs
new file mode 100644
index 0000000..ceab228
--- /dev/null
+++ b/src/codec/bytes_codec.rs
@@ -0,0 +1,86 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use bytes::{BufMut, Bytes, BytesMut};
+use std::io;
+
+/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
+///
+/// # Example
+///
+/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`.
+///
+/// [`AsyncRead`]: tokio::io::AsyncRead
+/// [`BytesMut`]: bytes::BytesMut
+/// [`Error`]: std::io::Error
+///
+/// ```
+/// # mod hidden {
+/// # #[allow(unused_imports)]
+/// use tokio::fs::File;
+/// # }
+/// use tokio::io::AsyncRead;
+/// use tokio_util::codec::{FramedRead, BytesCodec};
+///
+/// # enum File {}
+/// # impl File {
+/// # async fn open(_name: &str) -> Result<impl AsyncRead, std::io::Error> {
+/// # use std::io::Cursor;
+/// # Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5]))
+/// # }
+/// # }
+/// #
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> Result<(), std::io::Error> {
+/// let my_async_read = File::open("filename.txt").await?;
+/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new());
+/// # Ok(())
+/// # }
+/// ```
+///
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
+pub struct BytesCodec(());
+
+impl BytesCodec {
+ /// Creates a new `BytesCodec` for shipping around raw bytes.
+ pub fn new() -> BytesCodec {
+ BytesCodec(())
+ }
+}
+
+impl Decoder for BytesCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
+ if !buf.is_empty() {
+ let len = buf.len();
+ Ok(Some(buf.split_to(len)))
+ } else {
+ Ok(None)
+ }
+ }
+}
+
+impl Encoder<Bytes> for BytesCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
+ buf.reserve(data.len());
+ buf.put(data);
+ Ok(())
+ }
+}
+
+impl Encoder<BytesMut> for BytesCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
+ buf.reserve(data.len());
+ buf.put(data);
+ Ok(())
+ }
+}
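
The doc example above covers the read half; a matching sketch of the write half (assuming the `codec` feature, and relying on tokio's `AsyncWrite` impl for `Vec<u8>` to keep it self-contained):

```
use bytes::Bytes;
use futures::SinkExt;
use tokio_util::codec::{BytesCodec, FramedWrite};

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), std::io::Error> {
    // Any AsyncWrite works here; a Vec<u8> stands in for a socket.
    let mut sink = FramedWrite::new(Vec::new(), BytesCodec::new());
    sink.send(Bytes::from_static(b"hello ")).await?;
    sink.send(Bytes::from_static(b"world")).await?;
    assert_eq!(sink.get_ref().as_slice(), b"hello world");
    Ok(())
}
```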
diff --git a/src/codec/decoder.rs b/src/codec/decoder.rs
new file mode 100644
index 0000000..84ecb22
--- /dev/null
+++ b/src/codec/decoder.rs
@@ -0,0 +1,184 @@
+use crate::codec::Framed;
+
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use std::io;
+
+/// Decoding of frames via buffers.
+///
+/// This trait is used when constructing an instance of [`Framed`] or
+/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has
+/// already been buffered in `src` and decodes the data into a stream of
+/// `Self::Item` frames.
+///
+/// Implementations are able to track state on `self`, which enables
+/// implementing stateful streaming parsers. In many cases, though, this type
+/// will simply be a unit struct (e.g. `struct HttpDecoder`).
+///
+/// For some underlying data-sources, namely files and FIFOs,
+/// it's possible to temporarily read 0 bytes by reaching EOF.
+///
+/// In these cases `decode_eof` will be called until it signals
+/// fulfillment of all closing frames by returning `Ok(None)`.
+/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`]
+/// will not invoke `decode` or `decode_eof` again, until data can be read
+/// during a retry.
+///
+/// It is up to the Decoder to keep track of a restart after an EOF,
+/// and to decide how to handle such an event by, for example,
+/// allowing frames to cross EOF boundaries, re-emitting opening frames, or
+/// resetting the entire internal state.
+///
+/// [`Framed`]: crate::codec::Framed
+/// [`FramedRead`]: crate::codec::FramedRead
+pub trait Decoder {
+ /// The type of decoded frames.
+ type Item;
+
+ /// The type of unrecoverable frame decoding errors.
+ ///
+ /// If an individual message is ill-formed but can be ignored without
+ /// interfering with the processing of future messages, it may be more
+ /// useful to report the failure as an `Item`.
+ ///
+ /// `From<io::Error>` is required in the interest of making `Error` suitable
+ /// for returning directly from a [`FramedRead`], and to enable the default
+ /// implementation of `decode_eof` to yield an `io::Error` when the decoder
+ /// fails to consume all available data.
+ ///
+ /// Note that implementors of this trait can simply indicate `type Error =
+ /// io::Error` to use I/O errors as this type.
+ ///
+ /// [`FramedRead`]: crate::codec::FramedRead
+ type Error: From<io::Error>;
+
+ /// Attempts to decode a frame from the provided buffer of bytes.
+ ///
+ /// This method is called by [`FramedRead`] whenever bytes are ready to be
+ /// parsed. The provided buffer of bytes is what's been read so far, and
+    /// this instance of `Decoder` can determine whether an entire frame is in
+ /// the buffer and is ready to be returned.
+ ///
+ /// If an entire frame is available, then this instance will remove those
+ /// bytes from the buffer provided and return them as a decoded
+    /// frame. Note that removing bytes from the provided buffer doesn't
+    /// necessarily copy the bytes, so this should be an efficient operation in
+ /// most circumstances.
+ ///
+ /// If the bytes look valid, but a frame isn't fully available yet, then
+ /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that
+ /// it needs to read some more bytes before calling this method again.
+ ///
+ /// Note that the bytes provided may be empty. If a previous call to
+ /// `decode` consumed all the bytes in the buffer then `decode` will be
+ /// called again until it returns `Ok(None)`, indicating that more bytes need to
+ /// be read.
+ ///
+ /// Finally, if the bytes in the buffer are malformed then an error is
+ /// returned indicating why. This informs [`Framed`] that the stream is now
+ /// corrupt and should be terminated.
+ ///
+ /// [`Framed`]: crate::codec::Framed
+ /// [`FramedRead`]: crate::codec::FramedRead
+ ///
+ /// # Buffer management
+ ///
+ /// Before returning from the function, implementations should ensure that
+ /// the buffer has appropriate capacity in anticipation of future calls to
+ /// `decode`. Failing to do so leads to inefficiency.
+ ///
+ /// For example, if frames have a fixed length, or if the length of the
+ /// current frame is known from a header, a possible buffer management
+ /// strategy is:
+ ///
+ /// ```no_run
+ /// # use std::io;
+ /// #
+ /// # use bytes::BytesMut;
+ /// # use tokio_util::codec::Decoder;
+ /// #
+ /// # struct MyCodec;
+ /// #
+ /// impl Decoder for MyCodec {
+ /// // ...
+ /// # type Item = BytesMut;
+ /// # type Error = io::Error;
+ ///
+ /// fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
+ /// // ...
+ ///
+ /// // Reserve enough to complete decoding of the current frame.
+ /// let current_frame_len: usize = 1000; // Example.
+ /// // And to start decoding the next frame.
+ /// let next_frame_header_len: usize = 10; // Example.
+ /// src.reserve(current_frame_len + next_frame_header_len);
+ ///
+ /// return Ok(None);
+ /// }
+ /// }
+ /// ```
+ ///
+ /// An optimal buffer management strategy minimizes reallocations and
+ /// over-allocations.
+ fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;
+
+ /// A default method available to be called when there are no more bytes
+ /// available to be read from the underlying I/O.
+ ///
+ /// This method defaults to calling `decode` and returns an error if
+ /// `Ok(None)` is returned while there is unconsumed data in `buf`.
+ /// Typically this doesn't need to be implemented unless the framing
+ /// protocol differs near the end of the stream, or if you need to construct
+ /// frames _across_ eof boundaries on sources that can be resumed.
+ ///
+ /// Note that the `buf` argument may be empty. If a previous call to
+ /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be
+ /// called again until it returns `None`, indicating that there are no more
+ /// frames to yield. This behavior enables returning finalization frames
+ /// that may not be based on inbound data.
+ ///
+ /// Once `None` has been returned, `decode_eof` won't be called again until
+ /// an attempt to resume the stream has been made, where the underlying stream
+ /// actually returned more data.
+ fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
+ match self.decode(buf)? {
+ Some(frame) => Ok(Some(frame)),
+ None => {
+ if buf.is_empty() {
+ Ok(None)
+ } else {
+ Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into())
+ }
+ }
+ }
+ }
+
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+    /// `Io` object, using `Decoder` and `Encoder` to read and write the raw data.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the `Codec`
+    /// traits to handle encoding and decoding of message frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both `Stream` and
+ /// `Sink`; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling `split` on the [`Framed`] returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`Framed`]: crate::codec::Framed
+ fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self>
+ where
+ Self: Sized,
+ {
+ Framed::new(io, self)
+ }
+}
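
To make the trait concrete, here is a minimal hand-rolled `Decoder` for a hypothetical wire format where each frame is prefixed with a big-endian `u32` length. (The crate ships `LengthDelimitedCodec` for real use; this sketch only illustrates the buffer discipline described above.)

```
use bytes::{Buf, BytesMut};
use tokio_util::codec::Decoder;

struct LengthPrefixedDecoder;

impl Decoder for LengthPrefixedDecoder {
    type Item = BytesMut;
    type Error = std::io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if src.len() < 4 {
            // The length header has not fully arrived yet.
            return Ok(None);
        }
        let mut header = [0u8; 4];
        header.copy_from_slice(&src[..4]);
        let len = u32::from_be_bytes(header) as usize;

        if src.len() < 4 + len {
            // Header complete, body incomplete: reserve the exact shortfall
            // (per the buffer-management advice above) and wait for more.
            src.reserve(4 + len - src.len());
            return Ok(None);
        }

        src.advance(4); // consume the header
        Ok(Some(src.split_to(len))) // hand out the body without copying
    }
}
```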
diff --git a/src/codec/encoder.rs b/src/codec/encoder.rs
new file mode 100644
index 0000000..770a10f
--- /dev/null
+++ b/src/codec/encoder.rs
@@ -0,0 +1,25 @@
+use bytes::BytesMut;
+use std::io;
+
+/// Trait of helper objects to write out messages as bytes, for use with
+/// [`FramedWrite`].
+///
+/// [`FramedWrite`]: crate::codec::FramedWrite
+pub trait Encoder<Item> {
+ /// The type of encoding errors.
+ ///
+    /// [`FramedWrite`] requires an `Encoder`'s errors to implement `From<io::Error>`
+    /// in the interest of letting it return `Error`s directly.
+ ///
+ /// [`FramedWrite`]: crate::codec::FramedWrite
+ type Error: From<io::Error>;
+
+ /// Encodes a frame into the buffer provided.
+ ///
+ /// This method will encode `item` into the byte buffer provided by `dst`.
+ /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and
+ /// will be written out when possible.
+ ///
+ /// [`FramedWrite`]: crate::codec::FramedWrite
+ fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>;
+}
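
A matching `Encoder` for the same hypothetical length-prefixed format. The generic `Item` parameter is what lets it accept borrowed frames, which is the motivation for the `Encoder` redesign noted under 0.3.0 ([#1746]) in the changelog above:

```
use bytes::{BufMut, BytesMut};
use tokio_util::codec::Encoder;

struct LengthPrefixedEncoder;

impl<'a> Encoder<&'a [u8]> for LengthPrefixedEncoder {
    type Error = std::io::Error;

    fn encode(&mut self, item: &'a [u8], dst: &mut BytesMut) -> Result<(), Self::Error> {
        // Assumes frames fit in a u32; a production codec would check.
        dst.reserve(4 + item.len());
        dst.put_u32(item.len() as u32); // big-endian length header
        dst.put_slice(item);
        Ok(())
    }
}
```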
diff --git a/src/codec/framed.rs b/src/codec/framed.rs
new file mode 100644
index 0000000..8a344f9
--- /dev/null
+++ b/src/codec/framed.rs
@@ -0,0 +1,383 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};
+
+use futures_core::Stream;
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using
+ /// the `Encoder` and `Decoder` traits to encode and decode frames.
+ ///
+ /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or
+ /// by using the `new` function seen below.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Decoder::framed`]: crate::codec::Decoder::framed()
+ pub struct Framed<T, U> {
+ #[pin]
+ inner: FramedImpl<T, U, RWFrames>
+ }
+}
+
+impl<T, U> Framed<T, U>
+where
+ T: AsyncRead + AsyncWrite,
+{
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the codec
+    /// traits to handle encoding and decoding of message frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both [`Stream`] and
+ /// [`Sink`]; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling [`split`] on the `Framed` returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// Note that, for some byte sources, the stream can be resumed after an EOF
+ /// by reading from it, even after it has returned `None`. Repeated attempts
+ /// to do so, without new data available, continue to return `None` without
+ /// creating more (closing) frames.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+    /// [`Decoder`]: crate::codec::Decoder
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+ pub fn new(inner: T, codec: U) -> Framed<T, U> {
+ Framed {
+ inner: FramedImpl {
+ inner,
+ codec,
+ state: Default::default(),
+ },
+ }
+ }
+
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data,
+ /// with a specific read buffer initial capacity.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the codec
+    /// traits to handle encoding and decoding of message frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both [`Stream`] and
+ /// [`Sink`]; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling [`split`] on the `Framed` returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+    /// [`Decoder`]: crate::codec::Decoder
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+ pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed<T, U> {
+ Framed {
+ inner: FramedImpl {
+ inner,
+ codec,
+ state: RWFrames {
+ read: ReadFrame {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(capacity),
+ has_errored: false,
+ },
+ write: WriteFrame::default(),
+ },
+ },
+ }
+ }
+}
+
+impl<T, U> Framed<T, U> {
+ /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
+ /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
+ ///
+ /// Raw I/O objects work with byte sequences, but higher-level code usually
+ /// wants to batch these into meaningful chunks, called "frames". This
+ /// method layers framing on top of an I/O object, by using the `Codec`
+    /// traits to handle encoding and decoding of message frames. Note that
+ /// the incoming and outgoing frame types may be distinct.
+ ///
+ /// This function returns a *single* object that is both [`Stream`] and
+ /// [`Sink`]; grouping this into a single object is often useful for layering
+ /// things like gzip or TLS, which require both read and write access to the
+ /// underlying object.
+ ///
+    /// This object takes a stream, a read buffer, and a write buffer. These fields
+    /// can be obtained from an existing `Framed` with the [`into_parts`] method.
+ ///
+ /// If you want to work more directly with the streams and sink, consider
+ /// calling [`split`] on the `Framed` returned by this method, which will
+ /// break them into separate objects, allowing them to interact more easily.
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`Sink`]: futures_sink::Sink
+ /// [`Decoder`]: crate::codec::Decoder
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`into_parts`]: crate::codec::Framed::into_parts()
+ /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+ pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
+ Framed {
+ inner: FramedImpl {
+ inner: parts.io,
+ codec: parts.codec,
+ state: RWFrames {
+ read: parts.read_buf.into(),
+ write: parts.write_buf.into(),
+ },
+ },
+ }
+ }
+
+ /// Returns a reference to the underlying I/O stream wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_ref(&self) -> &T {
+ &self.inner.inner
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner.inner
+ }
+
+ /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+ self.project().inner.project().inner
+ }
+
+ /// Returns a reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec(&self) -> &U {
+ &self.inner.codec
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_mut(&mut self) -> &mut U {
+ &mut self.inner.codec
+ }
+
+ /// Maps the codec `U` to `C`, preserving the read and write buffers
+ /// wrapped by `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn map_codec<C, F>(self, map: F) -> Framed<T, C>
+ where
+ F: FnOnce(U) -> C,
+ {
+ // This could be potentially simplified once rust-lang/rust#86555 hits stable
+ let parts = self.into_parts();
+ Framed::from_parts(FramedParts {
+ io: parts.io,
+ codec: map(parts.codec),
+ read_buf: parts.read_buf,
+ write_buf: parts.write_buf,
+ _priv: (),
+ })
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `Framed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U {
+ self.project().inner.project().codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.inner.state.read.buffer
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.read.buffer
+ }
+
+ /// Returns a reference to the write buffer.
+ pub fn write_buffer(&self) -> &BytesMut {
+ &self.inner.state.write.buffer
+ }
+
+ /// Returns a mutable reference to the write buffer.
+ pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.write.buffer
+ }
+
+    /// Returns the current backpressure boundary.
+ pub fn backpressure_boundary(&self) -> usize {
+ self.inner.state.write.backpressure_boundary
+ }
+
+    /// Updates the backpressure boundary.
+ pub fn set_backpressure_boundary(&mut self, boundary: usize) {
+ self.inner.state.write.backpressure_boundary = boundary;
+ }
+
+ /// Consumes the `Framed`, returning its underlying I/O stream.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_inner(self) -> T {
+ self.inner.inner
+ }
+
+ /// Consumes the `Framed`, returning its underlying I/O stream, the buffer
+ /// with unprocessed data, and the codec.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_parts(self) -> FramedParts<T, U> {
+ FramedParts {
+ io: self.inner.inner,
+ codec: self.inner.codec,
+ read_buf: self.inner.state.read.buffer,
+ write_buf: self.inner.state.write.buffer,
+ _priv: (),
+ }
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, U> Stream for Framed<T, U>
+where
+ T: AsyncRead,
+ U: Decoder,
+{
+ type Item = Result<U::Item, U::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.poll_next(cx)
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, I, U> Sink<I> for Framed<T, U>
+where
+ T: AsyncWrite,
+ U: Encoder<I>,
+ U::Error: From<io::Error>,
+{
+ type Error = U::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_close(cx)
+ }
+}
+
+impl<T, U> fmt::Debug for Framed<T, U>
+where
+ T: fmt::Debug,
+ U: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Framed")
+ .field("io", self.get_ref())
+ .field("codec", self.codec())
+ .finish()
+ }
+}
+
+/// `FramedParts` contains an export of the data of a Framed transport.
+/// It can be used to construct a new [`Framed`] with a different codec.
+/// It contains all current buffers and the inner transport.
+///
+/// [`Framed`]: crate::codec::Framed
+#[derive(Debug)]
+#[allow(clippy::manual_non_exhaustive)]
+pub struct FramedParts<T, U> {
+    /// The inner transport used to read bytes from and write bytes to
+ pub io: T,
+
+ /// The codec
+ pub codec: U,
+
+ /// The buffer with read but unprocessed data.
+ pub read_buf: BytesMut,
+
+    /// A buffer with data that has not yet been written to the transport.
+ pub write_buf: BytesMut,
+
+ /// This private field allows us to add additional fields in the future in a
+ /// backwards compatible way.
+ _priv: (),
+}
+
+impl<T, U> FramedParts<T, U> {
+    /// Creates a new, default `FramedParts`.
+ pub fn new<I>(io: T, codec: U) -> FramedParts<T, U>
+ where
+ U: Encoder<I>,
+ {
+ FramedParts {
+ io,
+ codec,
+ read_buf: BytesMut::new(),
+ write_buf: BytesMut::new(),
+ _priv: (),
+ }
+ }
+}
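
`FramedParts` is what makes a mid-stream codec swap possible: tear the `Framed` down with `into_parts`, carry the buffers across, and rebuild with `from_parts`. A sketch under assumptions (the `upgrade` helper and the codec pairing are hypothetical):

```
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts, LinesCodec};

// Switch a connection from line framing to delimiter framing without
// losing bytes that were read but not yet decoded.
fn upgrade<T>(framed: Framed<T, LinesCodec>) -> Framed<T, AnyDelimiterCodec> {
    let old = framed.into_parts();
    // The turbofish supplies an item type satisfying the `Encoder` bound
    // on `FramedParts::new` (restored in 0.7.7, per the changelog above).
    let mut parts = FramedParts::new::<String>(old.io, AnyDelimiterCodec::default());
    parts.read_buf = old.read_buf;
    parts.write_buf = old.write_buf;
    Framed::from_parts(parts)
}
```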
diff --git a/src/codec/framed_impl.rs b/src/codec/framed_impl.rs
new file mode 100644
index 0000000..8f3fa49
--- /dev/null
+++ b/src/codec/framed_impl.rs
@@ -0,0 +1,312 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use futures_core::Stream;
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use futures_core::ready;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::borrow::{Borrow, BorrowMut};
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tracing::trace;
+
+pin_project! {
+ #[derive(Debug)]
+ pub(crate) struct FramedImpl<T, U, State> {
+ #[pin]
+ pub(crate) inner: T,
+ pub(crate) state: State,
+ pub(crate) codec: U,
+ }
+}
+
+const INITIAL_CAPACITY: usize = 8 * 1024;
+
+#[derive(Debug)]
+pub(crate) struct ReadFrame {
+ pub(crate) eof: bool,
+ pub(crate) is_readable: bool,
+ pub(crate) buffer: BytesMut,
+ pub(crate) has_errored: bool,
+}
+
+pub(crate) struct WriteFrame {
+ pub(crate) buffer: BytesMut,
+ pub(crate) backpressure_boundary: usize,
+}
+
+#[derive(Default)]
+pub(crate) struct RWFrames {
+ pub(crate) read: ReadFrame,
+ pub(crate) write: WriteFrame,
+}
+
+impl Default for ReadFrame {
+ fn default() -> Self {
+ Self {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+ has_errored: false,
+ }
+ }
+}
+
+impl Default for WriteFrame {
+ fn default() -> Self {
+ Self {
+ buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+ backpressure_boundary: INITIAL_CAPACITY,
+ }
+ }
+}
+
+impl From<BytesMut> for ReadFrame {
+ fn from(mut buffer: BytesMut) -> Self {
+ let size = buffer.capacity();
+ if size < INITIAL_CAPACITY {
+ buffer.reserve(INITIAL_CAPACITY - size);
+ }
+
+ Self {
+ buffer,
+ is_readable: size > 0,
+ eof: false,
+ has_errored: false,
+ }
+ }
+}
+
+impl From<BytesMut> for WriteFrame {
+ fn from(mut buffer: BytesMut) -> Self {
+ let size = buffer.capacity();
+ if size < INITIAL_CAPACITY {
+ buffer.reserve(INITIAL_CAPACITY - size);
+ }
+
+ Self {
+ buffer,
+ backpressure_boundary: INITIAL_CAPACITY,
+ }
+ }
+}
+
+impl Borrow<ReadFrame> for RWFrames {
+ fn borrow(&self) -> &ReadFrame {
+ &self.read
+ }
+}
+impl BorrowMut<ReadFrame> for RWFrames {
+ fn borrow_mut(&mut self) -> &mut ReadFrame {
+ &mut self.read
+ }
+}
+impl Borrow<WriteFrame> for RWFrames {
+ fn borrow(&self) -> &WriteFrame {
+ &self.write
+ }
+}
+impl BorrowMut<WriteFrame> for RWFrames {
+ fn borrow_mut(&mut self) -> &mut WriteFrame {
+ &mut self.write
+ }
+}
+impl<T, U, R> Stream for FramedImpl<T, U, R>
+where
+ T: AsyncRead,
+ U: Decoder,
+ R: BorrowMut<ReadFrame>,
+{
+ type Item = Result<U::Item, U::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ use crate::util::poll_read_buf;
+
+ let mut pinned = self.project();
+ let state: &mut ReadFrame = pinned.state.borrow_mut();
+        // The following loop implements a state machine, with each state corresponding
+ // to a combination of the `is_readable` and `eof` flags. States persist across
+ // loop entries and most state transitions occur with a return.
+ //
+ // The initial state is `reading`.
+ //
+ // | state | eof | is_readable | has_errored |
+ // |---------|-------|-------------|-------------|
+ // | reading | false | false | false |
+ // | framing | false | true | false |
+ // | pausing | true | true | false |
+ // | paused | true | false | false |
+ // | errored | <any> | <any> | true |
+ // `decode_eof` returns Err
+ // ┌────────────────────────────────────────────────────────┐
+ // `decode_eof` returns │ │
+ // `Ok(Some)` │ │
+ // ┌─────┐ │ `decode_eof` returns After returning │
+ // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐
+ // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │
+ // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘
+ // Pending read │ │ │ │ │ │
+ // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │
+ // │ │ │ ┌──────┐ │ Pending │ │
+ // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │
+ // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │
+ // └──┬─▲────┘ └─────┬──┬┘ │ │
+ // │ │ │ │ `decode` returns Err │ │
+ // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │
+ // │ read returns Err │
+ // └────────────────────────────────────────────────────────────────────────────────────────────┘
+ loop {
+ // Return `None` if we have encountered an error from the underlying decoder
+ // See: https://github.com/tokio-rs/tokio/issues/3976
+ if state.has_errored {
+ // preparing has_errored -> paused
+ trace!("Returning None and setting paused");
+ state.is_readable = false;
+ state.has_errored = false;
+ return Poll::Ready(None);
+ }
+
+ // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
+ // i.e. it _might_ contain data consumable as a frame or closing frame.
+ // Both signal that there is no such data by returning `None`.
+ //
+ // If `decode` couldn't read a frame and the upstream source has returned eof,
+ // `decode_eof` will attempt to decode the remaining bytes as closing frames.
+ //
+ // If the underlying AsyncRead is resumable, we may continue after an EOF,
+            // but must finish emitting all of its associated `decode_eof` frames.
+ // Furthermore, we don't want to emit any `decode_eof` frames on retried
+ // reads after an EOF unless we've actually read more data.
+ if state.is_readable {
+ // pausing or framing
+ if state.eof {
+ // pausing
+ let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ err
+ })?;
+ if frame.is_none() {
+ state.is_readable = false; // prepare pausing -> paused
+ }
+ // implicit pausing -> pausing or pausing -> paused
+ return Poll::Ready(frame.map(Ok));
+ }
+
+ // framing
+ trace!("attempting to decode a frame");
+
+ if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ op
+ })? {
+ trace!("frame decoded from buffer");
+ // implicit framing -> framing
+ return Poll::Ready(Some(Ok(frame)));
+ }
+
+ // framing -> reading
+ state.is_readable = false;
+ }
+ // reading or paused
+ // If we can't build a frame yet, try to read more data and try again.
+ // Make sure we've got room for at least one byte to read to ensure
+ // that we don't get a spurious 0 that looks like EOF.
+ state.buffer.reserve(1);
+ let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
+ |err| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ err
+ },
+ )? {
+ Poll::Ready(ct) => ct,
+ // implicit reading -> reading or implicit paused -> paused
+ Poll::Pending => return Poll::Pending,
+ };
+ if bytect == 0 {
+ if state.eof {
+ // We're already at an EOF, and since we've reached this path
+ // we're also not readable. This implies that we've already finished
+ // our `decode_eof` handling, so we can simply return `None`.
+ // implicit paused -> paused
+ return Poll::Ready(None);
+ }
+                // prepare reading -> pausing
+ state.eof = true;
+ } else {
+ // prepare paused -> framing or noop reading -> framing
+ state.eof = false;
+ }
+
+ // paused -> framing or reading -> framing or reading -> pausing
+ state.is_readable = true;
+ }
+ }
+}
+
+impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
+where
+ T: AsyncWrite,
+ U: Encoder<I>,
+ U::Error: From<io::Error>,
+ W: BorrowMut<WriteFrame>,
+{
+ type Error = U::Error;
+
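+    // Backpressure: once the write buffer has grown past
+    // `backpressure_boundary`, flush it down before accepting the next item.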
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
+ self.as_mut().poll_flush(cx)
+ } else {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ let pinned = self.project();
+ pinned
+ .codec
+ .encode(item, &mut pinned.state.borrow_mut().buffer)?;
+ Ok(())
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ use crate::util::poll_write_buf;
+ trace!("flushing framed transport");
+ let mut pinned = self.project();
+
+ while !pinned.state.borrow_mut().buffer.is_empty() {
+ let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
+ trace!(remaining = buffer.len(), "writing;");
+
+ let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
+
+ if n == 0 {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::WriteZero,
+ "failed to \
+ write frame to transport",
+ )
+ .into()));
+ }
+ }
+
+ // Try flushing the underlying IO
+ ready!(pinned.inner.poll_flush(cx))?;
+
+ trace!("framed transport flushed");
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_flush(cx))?;
+ ready!(self.project().inner.poll_shutdown(cx))?;
+
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs
new file mode 100644
index 0000000..184c567
--- /dev/null
+++ b/src/codec/framed_read.rs
@@ -0,0 +1,199 @@
+use crate::codec::framed_impl::{FramedImpl, ReadFrame};
+use crate::codec::Decoder;
+
+use futures_core::Stream;
+use tokio::io::AsyncRead;
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A [`Stream`] of messages decoded from an [`AsyncRead`].
+ ///
+ /// [`Stream`]: futures_core::Stream
+ /// [`AsyncRead`]: tokio::io::AsyncRead
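+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch (assuming `io` implements [`AsyncRead`] and `Unpin`,
+    /// and using `StreamExt` from the `futures` crate) of reading
+    /// newline-delimited frames:
+    ///
+    /// ```
+    /// use futures::StreamExt;
+    /// use tokio::io::AsyncRead;
+    /// use tokio_util::codec::{FramedRead, LinesCodec, LinesCodecError};
+    ///
+    /// async fn print_lines<T: AsyncRead + Unpin>(io: T) -> Result<(), LinesCodecError> {
+    ///     let mut framed = FramedRead::new(io, LinesCodec::new());
+    ///     while let Some(line) = framed.next().await {
+    ///         println!("{}", line?);
+    ///     }
+    ///     Ok(())
+    /// }
+    /// ```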
+ pub struct FramedRead<T, D> {
+ #[pin]
+ inner: FramedImpl<T, D, ReadFrame>,
+ }
+}
+
+// ===== impl FramedRead =====
+
+impl<T, D> FramedRead<T, D>
+where
+ T: AsyncRead,
+ D: Decoder,
+{
+ /// Creates a new `FramedRead` with the given `decoder`.
+ pub fn new(inner: T, decoder: D) -> FramedRead<T, D> {
+ FramedRead {
+ inner: FramedImpl {
+ inner,
+ codec: decoder,
+ state: Default::default(),
+ },
+ }
+ }
+
+    /// Creates a new `FramedRead` with the given `decoder` and a read buffer
+    /// with the given initial `capacity`.
+ pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> {
+ FramedRead {
+ inner: FramedImpl {
+ inner,
+ codec: decoder,
+ state: ReadFrame {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(capacity),
+ has_errored: false,
+ },
+ },
+ }
+ }
+}
+
+impl<T, D> FramedRead<T, D> {
+ /// Returns a reference to the underlying I/O stream wrapped by
+ /// `FramedRead`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_ref(&self) -> &T {
+ &self.inner.inner
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by
+ /// `FramedRead`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner.inner
+ }
+
+ /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+ /// `FramedRead`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+ self.project().inner.project().inner
+ }
+
+ /// Consumes the `FramedRead`, returning its underlying I/O stream.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_inner(self) -> T {
+ self.inner.inner
+ }
+
+ /// Returns a reference to the underlying decoder.
+ pub fn decoder(&self) -> &D {
+ &self.inner.codec
+ }
+
+ /// Returns a mutable reference to the underlying decoder.
+ pub fn decoder_mut(&mut self) -> &mut D {
+ &mut self.inner.codec
+ }
+
+    /// Maps the decoder `D` to `C`, preserving the read buffer
+    /// wrapped by `FramedRead`.
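+    ///
+    /// # Examples
+    ///
+    /// A sketch (the codec and limit chosen here are illustrative) of
+    /// swapping in a decoder with a tighter line limit without losing
+    /// buffered bytes:
+    ///
+    /// ```
+    /// use tokio_util::codec::{FramedRead, LinesCodec};
+    ///
+    /// fn tighten<T>(framed: FramedRead<T, LinesCodec>) -> FramedRead<T, LinesCodec> {
+    ///     // The read buffer carries over; only the decoder is replaced.
+    ///     framed.map_decoder(|_old| LinesCodec::new_with_max_length(1024))
+    /// }
+    /// ```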
+ pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
+ where
+ F: FnOnce(D) -> C,
+ {
+ // This could be potentially simplified once rust-lang/rust#86555 hits stable
+ let FramedImpl {
+ inner,
+ state,
+ codec,
+ } = self.inner;
+ FramedRead {
+ inner: FramedImpl {
+ inner,
+ state,
+ codec: map(codec),
+ },
+ }
+ }
+
+    /// Returns a mutable reference to the underlying decoder from a pinned `FramedRead`.
+ pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
+ self.project().inner.project().codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.inner.state.buffer
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.buffer
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, D> Stream for FramedRead<T, D>
+where
+ T: AsyncRead,
+ D: Decoder,
+{
+ type Item = Result<D::Item, D::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.poll_next(cx)
+ }
+}
+
+// This impl just defers to the underlying T: Sink
+impl<T, I, D> Sink<I> for FramedRead<T, D>
+where
+ T: Sink<I>,
+{
+ type Error = T::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ self.project().inner.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.project().inner.poll_close(cx)
+ }
+}
+
+impl<T, D> fmt::Debug for FramedRead<T, D>
+where
+ T: fmt::Debug,
+ D: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("FramedRead")
+ .field("inner", &self.get_ref())
+ .field("decoder", &self.decoder())
+ .field("eof", &self.inner.state.eof)
+ .field("is_readable", &self.inner.state.is_readable)
+ .field("buffer", &self.read_buffer())
+ .finish()
+ }
+}
diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs
new file mode 100644
index 0000000..3f0a340
--- /dev/null
+++ b/src/codec/framed_write.rs
@@ -0,0 +1,188 @@
+use crate::codec::encoder::Encoder;
+use crate::codec::framed_impl::{FramedImpl, WriteFrame};
+
+use futures_core::Stream;
+use tokio::io::AsyncWrite;
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+    /// A [`Sink`] of frames encoded to an [`AsyncWrite`].
+    ///
+    /// [`Sink`]: futures_sink::Sink
+    /// [`AsyncWrite`]: tokio::io::AsyncWrite
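+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch (assuming `io` implements `AsyncWrite` and `Unpin`,
+    /// and using `SinkExt` from the `futures` crate) of writing
+    /// newline-delimited frames:
+    ///
+    /// ```
+    /// use futures::SinkExt;
+    /// use tokio::io::AsyncWrite;
+    /// use tokio_util::codec::{FramedWrite, LinesCodec, LinesCodecError};
+    ///
+    /// async fn write_lines<T: AsyncWrite + Unpin>(io: T) -> Result<(), LinesCodecError> {
+    ///     let mut framed = FramedWrite::new(io, LinesCodec::new());
+    ///     framed.send("hello").await?;
+    ///     framed.send("world").await?;
+    ///     Ok(())
+    /// }
+    /// ```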
+ pub struct FramedWrite<T, E> {
+ #[pin]
+ inner: FramedImpl<T, E, WriteFrame>,
+ }
+}
+
+impl<T, E> FramedWrite<T, E>
+where
+ T: AsyncWrite,
+{
+ /// Creates a new `FramedWrite` with the given `encoder`.
+ pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
+ FramedWrite {
+ inner: FramedImpl {
+ inner,
+ codec: encoder,
+ state: WriteFrame::default(),
+ },
+ }
+ }
+}
+
+impl<T, E> FramedWrite<T, E> {
+ /// Returns a reference to the underlying I/O stream wrapped by
+ /// `FramedWrite`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_ref(&self) -> &T {
+ &self.inner.inner
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by
+ /// `FramedWrite`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner.inner
+ }
+
+ /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+ /// `FramedWrite`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+ self.project().inner.project().inner
+ }
+
+ /// Consumes the `FramedWrite`, returning its underlying I/O stream.
+ ///
+ /// Note that care should be taken to not tamper with the underlying stream
+ /// of data coming in as it may corrupt the stream of frames otherwise
+ /// being worked with.
+ pub fn into_inner(self) -> T {
+ self.inner.inner
+ }
+
+ /// Returns a reference to the underlying encoder.
+ pub fn encoder(&self) -> &E {
+ &self.inner.codec
+ }
+
+ /// Returns a mutable reference to the underlying encoder.
+ pub fn encoder_mut(&mut self) -> &mut E {
+ &mut self.inner.codec
+ }
+
+    /// Maps the encoder `E` to `C`, preserving the write buffer
+    /// wrapped by `FramedWrite`.
+ pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
+ where
+ F: FnOnce(E) -> C,
+ {
+ // This could be potentially simplified once rust-lang/rust#86555 hits stable
+ let FramedImpl {
+ inner,
+ state,
+ codec,
+ } = self.inner;
+ FramedWrite {
+ inner: FramedImpl {
+ inner,
+ state,
+ codec: map(codec),
+ },
+ }
+ }
+
+    /// Returns a mutable reference to the underlying encoder from a pinned `FramedWrite`.
+ pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
+ self.project().inner.project().codec
+ }
+
+ /// Returns a reference to the write buffer.
+ pub fn write_buffer(&self) -> &BytesMut {
+ &self.inner.state.buffer
+ }
+
+ /// Returns a mutable reference to the write buffer.
+ pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.inner.state.buffer
+ }
+
+    /// Returns the backpressure boundary: once the write buffer reaches this
+    /// size, `poll_ready` flushes pending data before accepting further items.
+ pub fn backpressure_boundary(&self) -> usize {
+ self.inner.state.backpressure_boundary
+ }
+
+    /// Updates the backpressure boundary.
+ pub fn set_backpressure_boundary(&mut self, boundary: usize) {
+ self.inner.state.backpressure_boundary = boundary;
+ }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, I, E> Sink<I> for FramedWrite<T, E>
+where
+ T: AsyncWrite,
+ E: Encoder<I>,
+ E::Error: From<io::Error>,
+{
+ type Error = E::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_close(cx)
+ }
+}
+
+// This impl just defers to the underlying T: Stream
+impl<T, D> Stream for FramedWrite<T, D>
+where
+ T: Stream,
+{
+ type Item = T::Item;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.project().inner.project().inner.poll_next(cx)
+ }
+}
+
+impl<T, U> fmt::Debug for FramedWrite<T, U>
+where
+ T: fmt::Debug,
+ U: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("FramedWrite")
+ .field("inner", &self.get_ref())
+ .field("encoder", &self.encoder())
+ .field("buffer", &self.inner.state.buffer)
+ .finish()
+ }
+}
diff --git a/src/codec/length_delimited.rs b/src/codec/length_delimited.rs
new file mode 100644
index 0000000..a182dca
--- /dev/null
+++ b/src/codec/length_delimited.rs
@@ -0,0 +1,1043 @@
+//! Frame a stream of bytes based on a length prefix
+//!
+//! Many protocols delimit their frames by prefacing frame data with a
+//! frame head that specifies the length of the frame. The
+//! `length_delimited` module provides utilities for handling the length
+//! based framing. This allows the consumer to work with entire frames
+//! without having to worry about buffering or other framing logic.
+//!
+//! # Getting started
+//!
+//! If implementing a protocol from scratch, using length delimited framing
+//! is an easy way to get started. [`LengthDelimitedCodec::new()`] will
+//! return a length delimited codec using default configuration values.
+//! This can then be used to construct a framer to adapt a full-duplex
+//! byte stream into a stream of frames.
+//!
+//! ```
+//! use tokio::io::{AsyncRead, AsyncWrite};
+//! use tokio_util::codec::{Framed, LengthDelimitedCodec};
+//!
+//! fn bind_transport<T: AsyncRead + AsyncWrite>(io: T)
+//! -> Framed<T, LengthDelimitedCodec>
+//! {
+//! Framed::new(io, LengthDelimitedCodec::new())
+//! }
+//! # pub fn main() {}
+//! ```
+//!
+//! The returned transport implements `Sink` (accepting `Bytes` frames) and
+//! `Stream` (yielding `BytesMut` frames). It encodes each frame with a
+//! big-endian `u32` header denoting the frame payload length:
+//!
+//! ```text
+//! +----------+--------------------------------+
+//! | len: u32 | frame payload |
+//! +----------+--------------------------------+
+//! ```
+//!
+//! Specifically, given the following:
+//!
+//! ```
+//! use tokio::io::{AsyncRead, AsyncWrite};
+//! use tokio_util::codec::{Framed, LengthDelimitedCodec};
+//!
+//! use futures::SinkExt;
+//! use bytes::Bytes;
+//!
+//! async fn write_frame<T>(io: T) -> Result<(), Box<dyn std::error::Error>>
+//! where
+//! T: AsyncRead + AsyncWrite + Unpin,
+//! {
+//! let mut transport = Framed::new(io, LengthDelimitedCodec::new());
+//! let frame = Bytes::from("hello world");
+//!
+//! transport.send(frame).await?;
+//! Ok(())
+//! }
+//! ```
+//!
+//! The encoded frame will look like this:
+//!
+//! ```text
+//! +---- len: u32 ----+---- data ----+
+//! | \x00\x00\x00\x0b | hello world |
+//! +------------------+--------------+
+//! ```
+//!
+//! # Decoding
+//!
+//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`],
+//! such that each yielded [`BytesMut`] value contains the contents of an
+//! entire frame. There are many configuration parameters enabling
+//! [`FramedRead`] to handle a wide range of protocols. Here are some
+//! examples that will cover the various options at a high level.
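+//!
+//! For instance, the frames yielded by a reader with the default
+//! configuration can be consumed as follows (a minimal sketch, using
+//! `StreamExt` from the `futures` crate):
+//!
+//! ```
+//! use futures::StreamExt;
+//! use tokio::io::AsyncRead;
+//! use tokio_util::codec::{FramedRead, LengthDelimitedCodec};
+//!
+//! async fn read_frames<T: AsyncRead + Unpin>(io: T) -> std::io::Result<()> {
+//!     let mut framed = FramedRead::new(io, LengthDelimitedCodec::new());
+//!     while let Some(frame) = framed.next().await {
+//!         // Each item is one whole frame with the length header stripped.
+//!         println!("got a frame of {} bytes", frame?.len());
+//!     }
+//!     Ok(())
+//! }
+//! ```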
+//!
+//! ## Example 1
+//!
+//! The following will parse a `u16` length field at offset 0, including the
+//! frame head in the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_type::<u16>()
+//! .length_adjustment(0) // default value
+//! .num_skip(0) // Do not strip frame header
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+
+//! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world |
+//! +----------+---------------+ +----------+---------------+
+//! ```
+//!
+//! The value of the length field is 11 (`\x0B`), which represents the length
+//! of the payload, `hello world`. By default, [`FramedRead`] assumes that
+//! the length field represents the number of bytes that **follows** the
+//! length field. Thus, the entire frame has a length of 13: 2 bytes for the
+//! frame head + 11 bytes for the payload.
+//!
+//! ## Example 2
+//!
+//! The following will parse a `u16` length field at offset 0, omitting the
+//! frame head in the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_type::<u16>()
+//! .length_adjustment(0) // default value
+//!     // `num_skip` is not needed; the default is to skip the frame head
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +-- len ---+--- Payload ---+ +--- Payload ---+
+//! | \x00\x0B | Hello world | --> | Hello world |
+//! +----------+---------------+ +---------------+
+//! ```
+//!
+//! This is similar to the first example; the only difference is that the
+//! frame head is **not** included in the yielded `BytesMut` value.
+//!
+//! ## Example 3
+//!
+//! The following will parse a `u16` length field at offset 0, including the
+//! frame head in the yielded `BytesMut`. In this case, the length field
+//! **includes** the frame head length.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_type::<u16>()
+//! .length_adjustment(-2) // size of head
+//! .num_skip(0)
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+
+//! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world |
+//! +----------+---------------+ +----------+---------------+
+//! ```
+//!
+//! In most cases, the length field represents the length of the payload
+//! only, as shown in the previous examples. However, in some protocols the
+//! length field represents the length of the whole frame, including the
+//! head. In such cases, we specify a negative `length_adjustment` to adjust
+//! the value provided in the frame head to represent the payload length.
+//!
+//! ## Example 4
+//!
+//! The following will parse a 3 byte length field at offset 0 in a 5 byte
+//! frame head, including the frame head in the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_length(3)
+//! .length_adjustment(2) // remaining head
+//! .num_skip(0)
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT
+//! +---- len -----+- head -+--- Payload ---+
+//! | \x00\x00\x0B | \xCAFE | Hello world |
+//! +--------------+--------+---------------+
+//!
+//! DECODED
+//! +---- len -----+- head -+--- Payload ---+
+//! | \x00\x00\x0B | \xCAFE | Hello world |
+//! +--------------+--------+---------------+
+//! ```
+//!
+//! This is a more advanced example that shows a case where there is extra
+//! frame head data between the length field and the payload. In such cases,
+//! it is usually desirable to include the frame head as part of the yielded
+//! `BytesMut`. This lets consumers of the length delimited framer
+//! process the frame head as needed.
+//!
+//! The positive `length_adjustment` value lets `FramedRead` factor in the
+//! additional head into the frame length calculation.
+//!
+//! ## Example 5
+//!
+//! The following will parse a `u16` length field at offset 1 of a 4 byte
+//! frame head. The first byte and the length field will be omitted from the
+//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be
+//! included.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(1) // length of hdr1
+//! .length_field_type::<u16>()
+//! .length_adjustment(1) // length of hdr2
+//! .num_skip(3) // length of hdr1 + LEN
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT
+//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+
+//! | \xCA | \x00\x0B | \xFE | Hello world |
+//! +--------+----------+--------+---------------+
+//!
+//! DECODED
+//! +- hdr2 -+--- Payload ---+
+//! | \xFE | Hello world |
+//! +--------+---------------+
+//! ```
+//!
+//! The length field is situated in the middle of the frame head. In this
+//! case, the first byte in the frame head could be a version or some other
+//! identifier that is not needed for processing. On the other hand, the
+//! second half of the head is needed.
+//!
+//! `length_field_offset` indicates how many bytes to skip before starting
+//! to read the length field. `length_adjustment` is the number of bytes to
+//! skip starting at the end of the length field. In this case, it is the
+//! second half of the head.
+//!
+//! ## Example 6
+//!
+//! The following will parse a `u16` length field at offset 1 of a 4 byte
+//! frame head. The first byte and the length field will be omitted from the
+//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be
+//! included. In this case, the length field **includes** the frame head
+//! length.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(1) // length of hdr1
+//! .length_field_type::<u16>()
+//! .length_adjustment(-3) // length of hdr1 + LEN, negative
+//! .num_skip(3)
+//! .new_read(io);
+//! # }
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT
+//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+
+//! | \xCA | \x00\x0F | \xFE | Hello world |
+//! +--------+----------+--------+---------------+
+//!
+//! DECODED
+//! +- hdr2 -+--- Payload ---+
+//! | \xFE | Hello world |
+//! +--------+---------------+
+//! ```
+//!
+//! This is similar to the example above; the difference is that the length
+//! field represents the length of the entire frame instead of just the payload.
+//! The length of `hdr1` and `len` must be counted in `length_adjustment`.
+//! Note that the length of `hdr2` does **not** need to be explicitly set
+//! anywhere because it already is factored into the total frame length that
+//! is read from the byte stream.
+//!
+//! ## Example 7
+//!
+//! The following will parse a 3 byte length field at offset 0 in a 4 byte
+//! frame head, excluding the 4th byte from the yielded `BytesMut`.
+//!
+//! ```
+//! # use tokio::io::AsyncRead;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn bind_read<T: AsyncRead>(io: T) {
+//! LengthDelimitedCodec::builder()
+//! .length_field_offset(0) // default value
+//! .length_field_length(3)
+//! .length_adjustment(0) // default value
+//! .num_skip(4) // skip the first 4 bytes
+//! .new_read(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! The following frame will be decoded as such:
+//!
+//! ```text
+//! INPUT DECODED
+//! +------- len ------+--- Payload ---+ +--- Payload ---+
+//! | \x00\x00\x0B\xFF | Hello world | => | Hello world |
+//! +------------------+---------------+ +---------------+
+//! ```
+//!
+//! This is a simple example where there are unused bytes between the length
+//! field and the payload.
+//!
+//! # Encoding
+//!
+//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`],
+//! such that each submitted [`BytesMut`] is prefaced by a length field.
+//! There are fewer configuration options than for [`FramedRead`]. For
+//! protocols that have more complex frame heads, an encoder should probably
+//! be written by hand using [`Encoder`].
+//!
+//! Here is a simple example, given a `FramedWrite` with the following
+//! configuration:
+//!
+//! ```
+//! # use tokio::io::AsyncWrite;
+//! # use tokio_util::codec::LengthDelimitedCodec;
+//! # fn write_frame<T: AsyncWrite>(io: T) {
+//! # let _ =
+//! LengthDelimitedCodec::builder()
+//! .length_field_type::<u16>()
+//! .new_write(io);
+//! # }
+//! # pub fn main() {}
+//! ```
+//!
+//! A payload of `hello world` will be encoded as:
+//!
+//! ```text
+//! +- len: u16 -+---- data ----+
+//! | \x00\x0b | hello world |
+//! +------------+--------------+
+//! ```
+//!
+//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new
+//! [`FramedRead`]: struct@FramedRead
+//! [`FramedWrite`]: struct@FramedWrite
+//! [`AsyncRead`]: trait@tokio::io::AsyncRead
+//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite
+//! [`Encoder`]: trait@Encoder
+//! [`BytesMut`]: bytes::BytesMut
+
+use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};
+
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use std::error::Error as StdError;
+use std::io::{self, Cursor};
+use std::{cmp, fmt, mem};
+
+/// Configure length delimited `LengthDelimitedCodec`s.
+///
+/// `Builder` enables constructing configured length delimited codecs. Note
+/// that not all configuration settings apply to both encoding and decoding. See
+/// the documentation for specific methods for more detail.
+#[derive(Debug, Clone, Copy)]
+pub struct Builder {
+ // Maximum frame length
+ max_frame_len: usize,
+
+    // Number of bytes used to represent the length field
+ length_field_len: usize,
+
+ // Number of bytes in the header before the length field
+ length_field_offset: usize,
+
+ // Adjust the length specified in the header field by this amount
+ length_adjustment: isize,
+
+    // Total number of bytes to skip before reading the payload; if not set,
+    // defaults to `length_field_len + length_field_offset`
+ num_skip: Option<usize>,
+
+ // Length field byte order (little or big endian)
+ length_field_is_big_endian: bool,
+}
+
+/// An error returned when the number of bytes read exceeds the max frame length.
+pub struct LengthDelimitedCodecError {
+ _priv: (),
+}
+
+/// A codec for frames delimited by a frame head specifying their lengths.
+///
+/// This allows the consumer to work with entire frames without having to worry
+/// about buffering or other framing logic.
+///
+/// See [module level] documentation for more detail.
+///
+/// [module level]: index.html
+#[derive(Debug, Clone)]
+pub struct LengthDelimitedCodec {
+ // Configuration values
+ builder: Builder,
+
+ // Read state
+ state: DecodeState,
+}
+
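+// Decoding happens in two steps: first the head is parsed to learn the
+// payload length (`Head`), then the decoder waits until at least that many
+// bytes are buffered (`Data(n)`) and splits the frame off.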
+#[derive(Debug, Clone, Copy)]
+enum DecodeState {
+ Head,
+ Data(usize),
+}
+
+// ===== impl LengthDelimitedCodec ======
+
+impl LengthDelimitedCodec {
+ /// Creates a new `LengthDelimitedCodec` with the default configuration values.
+ pub fn new() -> Self {
+ Self {
+ builder: Builder::new(),
+ state: DecodeState::Head,
+ }
+ }
+
+ /// Creates a new length delimited codec builder with default configuration
+ /// values.
+ pub fn builder() -> Builder {
+ Builder::new()
+ }
+
+ /// Returns the current max frame setting
+ ///
+ /// This is the largest size this codec will accept from the wire. Larger
+ /// frames will be rejected.
+ pub fn max_frame_length(&self) -> usize {
+ self.builder.max_frame_len
+ }
+
+ /// Updates the max frame setting.
+ ///
+ /// The change takes effect the next time a frame is decoded. In other
+    /// words, if a frame is currently in the process of being decoded with a frame
+ /// size greater than `val` but less than the max frame length in effect
+ /// before calling this function, then the frame will be allowed.
+ pub fn set_max_frame_length(&mut self, val: usize) {
+ self.builder.max_frame_length(val);
+ }
+
+ fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
+ let head_len = self.builder.num_head_bytes();
+ let field_len = self.builder.length_field_len;
+
+ if src.len() < head_len {
+ // Not enough data
+ return Ok(None);
+ }
+
+ let n = {
+ let mut src = Cursor::new(&mut *src);
+
+ // Skip the required bytes
+ src.advance(self.builder.length_field_offset);
+
+ // match endianness
+ let n = if self.builder.length_field_is_big_endian {
+ src.get_uint(field_len)
+ } else {
+ src.get_uint_le(field_len)
+ };
+
+ if n > self.builder.max_frame_len as u64 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ LengthDelimitedCodecError { _priv: () },
+ ));
+ }
+
+ // The check above ensures there is no overflow
+ let n = n as usize;
+
+ // Adjust `n` with bounds checking
+ let n = if self.builder.length_adjustment < 0 {
+ n.checked_sub(-self.builder.length_adjustment as usize)
+ } else {
+ n.checked_add(self.builder.length_adjustment as usize)
+ };
+
+ // Error handling
+ match n {
+ Some(n) => n,
+ None => {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "provided length would overflow after adjustment",
+ ));
+ }
+ }
+ };
+
+ src.advance(self.builder.get_num_skip());
+
+ // Ensure that the buffer has enough space to read the incoming
+ // payload
+ src.reserve(n.saturating_sub(src.len()));
+
+ Ok(Some(n))
+ }
+
+ fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option<BytesMut> {
+ // At this point, the buffer has already had the required capacity
+ // reserved. All there is to do is read.
+ if src.len() < n {
+ return None;
+ }
+
+ Some(src.split_to(n))
+ }
+}
+
+impl Decoder for LengthDelimitedCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
+ let n = match self.state {
+ DecodeState::Head => match self.decode_head(src)? {
+ Some(n) => {
+ self.state = DecodeState::Data(n);
+ n
+ }
+ None => return Ok(None),
+ },
+ DecodeState::Data(n) => n,
+ };
+
+ match self.decode_data(n, src) {
+ Some(data) => {
+ // Update the decode state
+ self.state = DecodeState::Head;
+
+ // Make sure the buffer has enough space to read the next head
+ src.reserve(self.builder.num_head_bytes().saturating_sub(src.len()));
+
+ Ok(Some(data))
+ }
+ None => Ok(None),
+ }
+ }
+}
+
+impl Encoder<Bytes> for LengthDelimitedCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
+ let n = data.len();
+
+ if n > self.builder.max_frame_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ LengthDelimitedCodecError { _priv: () },
+ ));
+ }
+
+ // Adjust `n` with bounds checking
+ let n = if self.builder.length_adjustment < 0 {
+ n.checked_add(-self.builder.length_adjustment as usize)
+ } else {
+ n.checked_sub(self.builder.length_adjustment as usize)
+ };
+
+ let n = n.ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "provided length would overflow after adjustment",
+ )
+ })?;
+
+ // Reserve capacity in the destination buffer to fit the frame and
+ // length field (plus adjustment).
+ dst.reserve(self.builder.length_field_len + n);
+
+ if self.builder.length_field_is_big_endian {
+ dst.put_uint(n as u64, self.builder.length_field_len);
+ } else {
+ dst.put_uint_le(n as u64, self.builder.length_field_len);
+ }
+
+ // Write the frame to the buffer
+ dst.extend_from_slice(&data[..]);
+
+ Ok(())
+ }
+}
+
+impl Default for LengthDelimitedCodec {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// ===== impl Builder =====
+
+mod builder {
+ /// Types that can be used with `Builder::length_field_type`.
+ pub trait LengthFieldType {}
+
+ impl LengthFieldType for u8 {}
+ impl LengthFieldType for u16 {}
+ impl LengthFieldType for u32 {}
+ impl LengthFieldType for u64 {}
+
+ #[cfg(any(
+ target_pointer_width = "8",
+ target_pointer_width = "16",
+ target_pointer_width = "32",
+ target_pointer_width = "64",
+ ))]
+ impl LengthFieldType for usize {}
+}
+
+impl Builder {
+ /// Creates a new length delimited codec builder with default configuration
+ /// values.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_type::<u16>()
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new() -> Builder {
+ Builder {
+ // Default max frame length of 8MB
+ max_frame_len: 8 * 1_024 * 1_024,
+
+ // Default byte length of 4
+ length_field_len: 4,
+
+ // Default to the header field being at the start of the header.
+ length_field_offset: 0,
+
+ length_adjustment: 0,
+
+            // Total number of bytes to skip before reading the payload; if not set,
+            // defaults to `length_field_len + length_field_offset`
+ num_skip: None,
+
+ // Default to reading the length field in network (big) endian.
+ length_field_is_big_endian: true,
+ }
+ }
+
+ /// Read the length field as a big endian integer
+ ///
+ /// This is the default setting.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .big_endian()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn big_endian(&mut self) -> &mut Self {
+ self.length_field_is_big_endian = true;
+ self
+ }
+
+ /// Read the length field as a little endian integer
+ ///
+ /// The default setting is big endian.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .little_endian()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn little_endian(&mut self) -> &mut Self {
+ self.length_field_is_big_endian = false;
+ self
+ }
+
+ /// Read the length field as a native endian integer
+ ///
+ /// The default setting is big endian.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .native_endian()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn native_endian(&mut self) -> &mut Self {
+ if cfg!(target_endian = "big") {
+ self.big_endian()
+ } else {
+ self.little_endian()
+ }
+ }
+
+ /// Sets the max frame length in bytes
+ ///
+ /// This configuration option applies to both encoding and decoding. The
+ /// default value is 8MB.
+ ///
+ /// When decoding, the length field read from the byte stream is checked
+ /// against this setting **before** any adjustments are applied. When
+ /// encoding, the length of the submitted payload is checked against this
+ /// setting.
+ ///
+    /// When frames exceed the max length, an `io::Error` wrapping a
+    /// `LengthDelimitedCodecError` will be returned.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .max_frame_length(8 * 1024 * 1024)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn max_frame_length(&mut self, val: usize) -> &mut Self {
+ self.max_frame_len = val;
+ self
+ }
+
+ /// Sets the unsigned integer type used to represent the length field.
+ ///
+ /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on
+ /// 64-bit targets).
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u32>()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ ///
+ /// Unlike [`Builder::length_field_length`], this does not fail at runtime
+ /// and instead produces a compile error:
+ ///
+ /// ```compile_fail
+ /// # use tokio::io::AsyncRead;
+ /// # use tokio_util::codec::LengthDelimitedCodec;
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u128>()
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_field_type<T: builder::LengthFieldType>(&mut self) -> &mut Self {
+ self.length_field_length(mem::size_of::<T>())
+ }
+
+ /// Sets the number of bytes used to represent the length field
+ ///
+ /// The default value is `4`. The max value is `8`.
+ ///
+ /// This configuration option applies to both encoding and decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_length(4)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_field_length(&mut self, val: usize) -> &mut Self {
+ assert!(val > 0 && val <= 8, "invalid length field length");
+ self.length_field_len = val;
+ self
+ }
+
+ /// Sets the number of bytes in the header before the length field
+ ///
+ /// This configuration option only applies to decoding.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(1)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_field_offset(&mut self, val: usize) -> &mut Self {
+ self.length_field_offset = val;
+ self
+ }
+
+ /// Delta between the payload length specified in the header and the real
+ /// payload length
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_adjustment(-2)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn length_adjustment(&mut self, val: isize) -> &mut Self {
+ self.length_adjustment = val;
+ self
+ }
+
+ /// Sets the number of bytes to skip before reading the payload
+ ///
+ /// Default value is `length_field_len + length_field_offset`
+ ///
+ /// This configuration option only applies to decoding
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .num_skip(4)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn num_skip(&mut self, val: usize) -> &mut Self {
+ self.num_skip = Some(val);
+ self
+ }
+
+ /// Create a configured length delimited `LengthDelimitedCodec`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ /// # pub fn main() {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_type::<u16>()
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_codec();
+ /// # }
+ /// ```
+ pub fn new_codec(&self) -> LengthDelimitedCodec {
+ LengthDelimitedCodec {
+ builder: *self,
+ state: DecodeState::Head,
+ }
+ }
+
+ /// Create a configured length delimited `FramedRead`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::LengthDelimitedCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_type::<u16>()
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec>
+ where
+ T: AsyncRead,
+ {
+ FramedRead::new(upstream, self.new_codec())
+ }
+
+ /// Create a configured length delimited `FramedWrite`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncWrite;
+ /// # use tokio_util::codec::LengthDelimitedCodec;
+ /// # fn write_frame<T: AsyncWrite>(io: T) {
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u16>()
+ /// .new_write(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec>
+ where
+ T: AsyncWrite,
+ {
+ FramedWrite::new(inner, self.new_codec())
+ }
+
+ /// Create a configured length delimited `Framed`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::{AsyncRead, AsyncWrite};
+ /// # use tokio_util::codec::LengthDelimitedCodec;
+ /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) {
+ /// # let _ =
+ /// LengthDelimitedCodec::builder()
+ /// .length_field_type::<u16>()
+ /// .new_framed(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_framed<T>(&self, inner: T) -> Framed<T, LengthDelimitedCodec>
+ where
+ T: AsyncRead + AsyncWrite,
+ {
+ Framed::new(inner, self.new_codec())
+ }
+
+ fn num_head_bytes(&self) -> usize {
+ let num = self.length_field_offset + self.length_field_len;
+ cmp::max(num, self.num_skip.unwrap_or(0))
+ }
+
+ fn get_num_skip(&self) -> usize {
+ self.num_skip
+ .unwrap_or(self.length_field_offset + self.length_field_len)
+ }
+}
+
+impl Default for Builder {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// ===== impl LengthDelimitedCodecError =====
+
+impl fmt::Debug for LengthDelimitedCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("LengthDelimitedCodecError").finish()
+ }
+}
+
+impl fmt::Display for LengthDelimitedCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("frame size too big")
+ }
+}
+
+impl StdError for LengthDelimitedCodecError {}
diff --git a/src/codec/lines_codec.rs b/src/codec/lines_codec.rs
new file mode 100644
index 0000000..7a0a8f0
--- /dev/null
+++ b/src/codec/lines_codec.rs
@@ -0,0 +1,230 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use bytes::{Buf, BufMut, BytesMut};
+use std::{cmp, fmt, io, str, usize};
+
+/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
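+///
+/// # Examples
+///
+/// A small sketch of driving the codec by hand on an in-memory buffer:
+///
+/// ```
+/// use bytes::BytesMut;
+/// use tokio_util::codec::{Decoder, LinesCodec};
+///
+/// let mut codec = LinesCodec::new();
+/// let mut buf = BytesMut::from(&b"hello\nwor"[..]);
+///
+/// // A complete line is returned with the `\n` stripped...
+/// assert_eq!(codec.decode(&mut buf).unwrap(), Some("hello".to_string()));
+/// // ...while a partial line yields `None` until more data arrives.
+/// assert_eq!(codec.decode(&mut buf).unwrap(), None);
+/// ```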
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct LinesCodec {
+ // Stored index of the next index to examine for a `\n` character.
+ // This is used to optimize searching.
+ // For example, if `decode` was called with `abc`, it would hold `3`,
+ // because that is the next index to examine.
+ // The next time `decode` is called with `abcde\n`, the method will
+ // only look at `de\n` before returning.
+ next_index: usize,
+
+ /// The maximum length for a given line. If `usize::MAX`, lines will be
+ /// read until a `\n` character is reached.
+ max_length: usize,
+
+ /// Are we currently discarding the remainder of a line which was over
+ /// the length limit?
+ is_discarding: bool,
+}
+
+impl LinesCodec {
+ /// Returns a `LinesCodec` for splitting up data into lines.
+ ///
+ /// # Note
+ ///
+ /// The returned `LinesCodec` will not have an upper bound on the length
+ /// of a buffered line. See the documentation for [`new_with_max_length`]
+ /// for information on why this could be a potential security risk.
+ ///
+ /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length()
+ pub fn new() -> LinesCodec {
+ LinesCodec {
+ next_index: 0,
+ max_length: usize::MAX,
+ is_discarding: false,
+ }
+ }
+
+ /// Returns a `LinesCodec` with a maximum line length limit.
+ ///
+ /// If this is set, calls to `LinesCodec::decode` will return a
+ /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls
+ /// will discard up to `limit` bytes from that line until a newline
+ /// character is reached, returning `None` until the line over the limit
+ /// has been fully discarded. After that point, calls to `decode` will
+ /// function as normal.
+ ///
+ /// # Note
+ ///
+ /// Setting a length limit is highly recommended for any `LinesCodec` which
+ /// will be exposed to untrusted input. Otherwise, the size of the buffer
+ /// that holds the line currently being read is unbounded. An attacker could
+ /// exploit this unbounded buffer by sending an unbounded amount of input
+ /// without any `\n` characters, causing unbounded memory consumption.
+ ///
+ /// [`LinesCodecError`]: crate::codec::LinesCodecError
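+    ///
+    /// # Examples
+    ///
+    /// A sketch of the discard behavior described above:
+    ///
+    /// ```
+    /// use bytes::BytesMut;
+    /// use tokio_util::codec::{Decoder, LinesCodec, LinesCodecError};
+    ///
+    /// let mut codec = LinesCodec::new_with_max_length(4);
+    /// let mut buf = BytesMut::from(&b"too long line\nok\n"[..]);
+    ///
+    /// // The oversized line is rejected...
+    /// assert!(matches!(
+    ///     codec.decode(&mut buf),
+    ///     Err(LinesCodecError::MaxLineLengthExceeded)
+    /// ));
+    /// // ...its remainder is discarded, and decoding then resumes normally.
+    /// assert_eq!(codec.decode(&mut buf).unwrap(), Some("ok".to_string()));
+    /// ```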
+ pub fn new_with_max_length(max_length: usize) -> Self {
+ LinesCodec {
+ max_length,
+ ..LinesCodec::new()
+ }
+ }
+
+ /// Returns the maximum line length when decoding.
+ ///
+ /// ```
+ /// use std::usize;
+ /// use tokio_util::codec::LinesCodec;
+ ///
+ /// let codec = LinesCodec::new();
+ /// assert_eq!(codec.max_length(), usize::MAX);
+ /// ```
+ /// ```
+ /// use tokio_util::codec::LinesCodec;
+ ///
+ /// let codec = LinesCodec::new_with_max_length(256);
+ /// assert_eq!(codec.max_length(), 256);
+ /// ```
+ pub fn max_length(&self) -> usize {
+ self.max_length
+ }
+}
+
+fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
+ str::from_utf8(buf)
+ .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8"))
+}
+
+fn without_carriage_return(s: &[u8]) -> &[u8] {
+ if let Some(&b'\r') = s.last() {
+ &s[..s.len() - 1]
+ } else {
+ s
+ }
+}
+
+impl Decoder for LinesCodec {
+ type Item = String;
+ type Error = LinesCodecError;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
+ loop {
+ // Determine how far into the buffer we'll search for a newline. If
+ // there's no max_length set, we'll read to the end of the buffer.
+ let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
+
+ let newline_offset = buf[self.next_index..read_to]
+ .iter()
+ .position(|b| *b == b'\n');
+
+ match (self.is_discarding, newline_offset) {
+ (true, Some(offset)) => {
+ // If we found a newline, discard up to that offset and
+ // then stop discarding. On the next iteration, we'll try
+ // to read a line normally.
+ buf.advance(offset + self.next_index + 1);
+ self.is_discarding = false;
+ self.next_index = 0;
+ }
+ (true, None) => {
+ // Otherwise, we didn't find a newline, so we'll discard
+ // everything we read. On the next iteration, we'll continue
+                    // discarding up to `max_length` bytes unless we find a newline.
+ buf.advance(read_to);
+ self.next_index = 0;
+ if buf.is_empty() {
+ return Ok(None);
+ }
+ }
+ (false, Some(offset)) => {
+ // Found a line!
+ let newline_index = offset + self.next_index;
+ self.next_index = 0;
+ let line = buf.split_to(newline_index + 1);
+ let line = &line[..line.len() - 1];
+ let line = without_carriage_return(line);
+ let line = utf8(line)?;
+ return Ok(Some(line.to_string()));
+ }
+ (false, None) if buf.len() > self.max_length => {
+ // Reached the maximum length without finding a
+ // newline, return an error and start discarding on the
+ // next call.
+ self.is_discarding = true;
+ return Err(LinesCodecError::MaxLineLengthExceeded);
+ }
+ (false, None) => {
+ // We didn't find a line or reach the length limit, so the next
+ // call will resume searching at the current offset.
+ self.next_index = read_to;
+ return Ok(None);
+ }
+ }
+ }
+ }
+
+ fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
+ Ok(match self.decode(buf)? {
+ Some(frame) => Some(frame),
+ None => {
+ // No terminating newline - return remaining data, if any
+ if buf.is_empty() || buf == &b"\r"[..] {
+ None
+ } else {
+ let line = buf.split_to(buf.len());
+ let line = without_carriage_return(&line);
+ let line = utf8(line)?;
+ self.next_index = 0;
+ Some(line.to_string())
+ }
+ }
+ })
+ }
+}
+
+impl<T> Encoder<T> for LinesCodec
+where
+ T: AsRef<str>,
+{
+ type Error = LinesCodecError;
+
+ fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> {
+ let line = line.as_ref();
+ buf.reserve(line.len() + 1);
+ buf.put(line.as_bytes());
+ buf.put_u8(b'\n');
+ Ok(())
+ }
+}
+
+impl Default for LinesCodec {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// An error occurred while encoding or decoding a line.
+#[derive(Debug)]
+pub enum LinesCodecError {
+ /// The maximum line length was exceeded.
+ MaxLineLengthExceeded,
+ /// An IO error occurred.
+ Io(io::Error),
+}
+
+impl fmt::Display for LinesCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"),
+ LinesCodecError::Io(e) => write!(f, "{}", e),
+ }
+ }
+}
+
+impl From<io::Error> for LinesCodecError {
+ fn from(e: io::Error) -> LinesCodecError {
+ LinesCodecError::Io(e)
+ }
+}
+
+impl std::error::Error for LinesCodecError {}
diff --git a/src/codec/mod.rs b/src/codec/mod.rs
new file mode 100644
index 0000000..2295176
--- /dev/null
+++ b/src/codec/mod.rs
@@ -0,0 +1,290 @@
+//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink
+//!
+//! Raw I/O objects work with byte sequences, but higher-level code usually
+//! wants to batch these into meaningful chunks, called "frames".
+//!
+//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and
+//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`].
+//! Framed streams are also known as transports.
+//!
+//! # The Decoder trait
+//!
+//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an
+//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify
+//! how sequences of bytes are turned into a sequence of frames, and to
+//! determine where the boundaries between frames are. The job of the
+//! `FramedRead` is to repeatedly switch between reading more data from the IO
+//! resource, and asking the decoder whether we have received enough data to
+//! decode another frame of data.
+//!
+//! The main method on the `Decoder` trait is the [`decode`] method. This method
+//! takes as argument the data that has been read so far, and when it is called,
+//! it will be in one of the following situations:
+//!
+//! 1. The buffer contains less than a full frame.
+//! 2. The buffer contains exactly a full frame.
+//! 3. The buffer contains more than a full frame.
+//!
+//! In the first situation, the decoder should return `Ok(None)`.
+//!
+//! In the second situation, the decoder should clear the provided buffer and
+//! return `Ok(Some(the_decoded_frame))`.
+//!
+//! In the third situation, the decoder should use a method such as [`split_to`]
+//! or [`advance`] to modify the buffer such that the frame is removed from the
+//! buffer, but any data in the buffer after that frame should still remain in
+//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in
+//! this case.
+//!
+//! Finally the decoder may return an error if the data is invalid in some way.
+//! The decoder should _not_ return an error just because it has yet to receive
+//! a full frame.
+//!
+//! It is guaranteed that, from one call to `decode` to another, the provided
+//! buffer will contain the exact same data as before, except that if more data
+//! has arrived through the IO resource, that data will have been appended to
+//! the buffer. This means that reading frames from a `FramedRead` is
+//! essentially equivalent to the following loop:
+//!
+//! ```no_run
+//! use tokio::io::AsyncReadExt;
+//! # // This uses async_stream to create an example that compiles.
+//! # fn foo() -> impl futures_core::Stream<Item = std::io::Result<bytes::BytesMut>> { async_stream::try_stream! {
+//! # use tokio_util::codec::Decoder;
+//! # let mut decoder = tokio_util::codec::BytesCodec::new();
+//! # let io_resource = &mut &[0u8, 1, 2, 3][..];
+//!
+//! let mut buf = bytes::BytesMut::new();
+//! loop {
+//! // The read_buf call will append to buf rather than overwrite existing data.
+//! let len = io_resource.read_buf(&mut buf).await?;
+//!
+//! if len == 0 {
+//! while let Some(frame) = decoder.decode_eof(&mut buf)? {
+//! yield frame;
+//! }
+//! break;
+//! }
+//!
+//! while let Some(frame) = decoder.decode(&mut buf)? {
+//! yield frame;
+//! }
+//! }
+//! # }}
+//! ```
+//! The example above uses `yield` whenever the `Stream` produces an item.
+//!
+//! ## Example decoder
+//!
+//! As an example, consider a protocol that can be used to send strings where
+//! each frame is a four byte integer that contains the length of the frame,
+//! followed by that many bytes of string data. The decoder fails with an error
+//! if the string data is not valid utf-8 or is too long.
+//!
+//! Such a decoder can be written like this:
+//! ```
+//! use tokio_util::codec::Decoder;
+//! use bytes::{BytesMut, Buf};
+//!
+//! struct MyStringDecoder {}
+//!
+//! const MAX: usize = 8 * 1024 * 1024;
+//!
+//! impl Decoder for MyStringDecoder {
+//! type Item = String;
+//! type Error = std::io::Error;
+//!
+//! fn decode(
+//! &mut self,
+//! src: &mut BytesMut
+//! ) -> Result<Option<Self::Item>, Self::Error> {
+//! if src.len() < 4 {
+//! // Not enough data to read length marker.
+//! return Ok(None);
+//! }
+//!
+//! // Read length marker.
+//! let mut length_bytes = [0u8; 4];
+//! length_bytes.copy_from_slice(&src[..4]);
+//! let length = u32::from_le_bytes(length_bytes) as usize;
+//!
+//! // Check that the length is not too large to avoid a denial of
+//! // service attack where the server runs out of memory.
+//! if length > MAX {
+//! return Err(std::io::Error::new(
+//! std::io::ErrorKind::InvalidData,
+//! format!("Frame of length {} is too large.", length)
+//! ));
+//! }
+//!
+//! if src.len() < 4 + length {
+//! // The full string has not yet arrived.
+//! //
+//! // We reserve more space in the buffer. This is not strictly
+//! // necessary, but is a good idea performance-wise.
+//! src.reserve(4 + length - src.len());
+//!
+//! // We inform the Framed that we need more bytes to form the next
+//! // frame.
+//! return Ok(None);
+//! }
+//!
+//! // Use advance to modify src such that it no longer contains
+//! // this frame.
+//! let data = src[4..4 + length].to_vec();
+//! src.advance(4 + length);
+//!
+//! // Convert the data to a string, or fail if it is not valid utf-8.
+//! match String::from_utf8(data) {
+//! Ok(string) => Ok(Some(string)),
+//! Err(utf8_error) => {
+//! Err(std::io::Error::new(
+//! std::io::ErrorKind::InvalidData,
+//! utf8_error.utf8_error(),
+//! ))
+//! },
+//! }
+//! }
+//! }
+//! ```
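+//!
+//! A short usage sketch (an illustrative addition, not upstream text): any
+//! decoder can be plugged into a [`FramedRead`] to obtain a [`Stream`] of
+//! frames. `LinesCodec` is used here so the example is self-contained.
+//!
+//! ```
+//! # #[tokio::main(flavor = "current_thread")]
+//! # async fn main() {
+//! use tokio_stream::StreamExt;
+//! use tokio_util::codec::{FramedRead, LinesCodec};
+//!
+//! // Any `AsyncRead` works; an in-memory byte slice is used for brevity.
+//! let input: &[u8] = b"first\nsecond\n";
+//! let mut frames = FramedRead::new(input, LinesCodec::new());
+//!
+//! assert_eq!(frames.next().await.unwrap().unwrap(), "first");
+//! assert_eq!(frames.next().await.unwrap().unwrap(), "second");
+//! assert!(frames.next().await.is_none());
+//! # }
+//! ```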
+//!
+//! # The Encoder trait
+//!
+//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn
+//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to
+//! specify how frames are turned into a sequence of bytes. The job of the
+//! `FramedWrite` is to take the resulting sequence of bytes and write it to the
+//! IO resource.
+//!
+//! The main method on the `Encoder` trait is the [`encode`] method. This method
+//! takes an item that is being written, and a buffer to write the item to. The
+//! buffer may already contain data, and in this case, the encoder should append
+//! the new frame to the buffer rather than overwrite the existing data.
+//!
+//! It is guaranteed that, from one call to `encode` to another, the provided
+//! buffer will contain the exact same data as before, except that some of the
+//! data may have been removed from the front of the buffer. Writing to a
+//! `FramedWrite` is essentially equivalent to the following loop:
+//!
+//! ```no_run
+//! use tokio::io::AsyncWriteExt;
+//! use bytes::Buf; // for advance
+//! # use tokio_util::codec::Encoder;
+//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() }
+//! # async fn no_more_frames() { }
+//! # #[tokio::main] async fn main() -> std::io::Result<()> {
+//! # let mut io_resource = tokio::io::sink();
+//! # let mut encoder = tokio_util::codec::BytesCodec::new();
+//!
+//! const MAX: usize = 8192;
+//!
+//! let mut buf = bytes::BytesMut::new();
+//! loop {
+//! tokio::select! {
+//! num_written = io_resource.write(&buf), if !buf.is_empty() => {
+//! buf.advance(num_written?);
+//! },
+//! frame = next_frame(), if buf.len() < MAX => {
+//! encoder.encode(frame, &mut buf)?;
+//! },
+//! _ = no_more_frames() => {
+//! io_resource.write_all(&buf).await?;
+//! io_resource.shutdown().await?;
+//! return Ok(());
+//! },
+//! }
+//! }
+//! # }
+//! ```
+//! Here the `next_frame` method corresponds to any frames you write to the
+//! `FramedWrite`. The `no_more_frames` method corresponds to closing the
+//! `FramedWrite` with [`SinkExt::close`].
+//!
+//! ## Example encoder
+//!
+//! As an example, consider a protocol that can be used to send strings where
+//! each frame is a four byte integer that contains the length of the frame,
+//! followed by that many bytes of string data. The encoder will fail if the
+//! string is too long.
+//!
+//! Such an encoder can be written like this:
+//! ```
+//! use tokio_util::codec::Encoder;
+//! use bytes::BytesMut;
+//!
+//! struct MyStringEncoder {}
+//!
+//! const MAX: usize = 8 * 1024 * 1024;
+//!
+//! impl Encoder<String> for MyStringEncoder {
+//! type Error = std::io::Error;
+//!
+//! fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
+//! // Don't send a string if it is longer than the other end will
+//! // accept.
+//! if item.len() > MAX {
+//! return Err(std::io::Error::new(
+//! std::io::ErrorKind::InvalidData,
+//! format!("Frame of length {} is too large.", item.len())
+//! ));
+//! }
+//!
+//! // Convert the length into a byte array.
+//! // The cast to u32 cannot overflow due to the length check above.
+//! let len_slice = u32::to_le_bytes(item.len() as u32);
+//!
+//! // Reserve space in the buffer.
+//! dst.reserve(4 + item.len());
+//!
+//! // Write the length and string to the buffer.
+//! dst.extend_from_slice(&len_slice);
+//! dst.extend_from_slice(item.as_bytes());
+//! Ok(())
+//! }
+//! }
+//! ```
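+//!
+//! A matching usage sketch (an illustrative addition, not upstream text):
+//! wrap an [`AsyncWrite`] in a [`FramedWrite`] with an encoder, then send
+//! frames into it with `SinkExt`. `LinesCodec` is used here so the example
+//! is self-contained.
+//!
+//! ```
+//! # #[tokio::main(flavor = "current_thread")]
+//! # async fn main() -> Result<(), tokio_util::codec::LinesCodecError> {
+//! use futures_util::SinkExt;
+//! use tokio_util::codec::{FramedWrite, LinesCodec};
+//!
+//! // `tokio::io::sink()` simply discards everything written to it.
+//! let mut frames = FramedWrite::new(tokio::io::sink(), LinesCodec::new());
+//! frames.send("one frame").await?;
+//!
+//! // `LinesCodec` encodes any `AsRef<str>`, so the frame type must be
+//! // pinned down explicitly when closing.
+//! SinkExt::<&str>::close(&mut frames).await?;
+//! # Ok(())
+//! # }
+//! ```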
+//!
+//! [`AsyncRead`]: tokio::io::AsyncRead
+//! [`AsyncWrite`]: tokio::io::AsyncWrite
+//! [`Stream`]: futures_core::Stream
+//! [`Sink`]: futures_sink::Sink
+//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close
+//! [`FramedRead`]: struct@crate::codec::FramedRead
+//! [`FramedWrite`]: struct@crate::codec::FramedWrite
+//! [`Framed`]: struct@crate::codec::Framed
+//! [`Decoder`]: trait@crate::codec::Decoder
+//! [`decode`]: fn@crate::codec::Decoder::decode
+//! [`encode`]: fn@crate::codec::Encoder::encode
+//! [`split_to`]: fn@bytes::BytesMut::split_to
+//! [`advance`]: fn@bytes::Buf::advance
+
+mod bytes_codec;
+pub use self::bytes_codec::BytesCodec;
+
+mod decoder;
+pub use self::decoder::Decoder;
+
+mod encoder;
+pub use self::encoder::Encoder;
+
+mod framed_impl;
+#[allow(unused_imports)]
+pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};
+
+mod framed;
+pub use self::framed::{Framed, FramedParts};
+
+mod framed_read;
+pub use self::framed_read::FramedRead;
+
+mod framed_write;
+pub use self::framed_write::FramedWrite;
+
+pub mod length_delimited;
+pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError};
+
+mod lines_codec;
+pub use self::lines_codec::{LinesCodec, LinesCodecError};
+
+mod any_delimiter_codec;
+pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError};
diff --git a/src/compat.rs b/src/compat.rs
new file mode 100644
index 0000000..6a8802d
--- /dev/null
+++ b/src/compat.rs
@@ -0,0 +1,274 @@
+//! Compatibility between the `tokio::io` and `futures-io` versions of the
+//! `AsyncRead` and `AsyncWrite` traits.
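+//!
+//! A minimal usage sketch (an illustrative addition, not upstream text):
+//! `.compat()` adapts a `futures-io` reader to tokio-based APIs, and
+//! `.compat_write()` adapts a tokio writer to `futures-io` based APIs.
+//!
+//! ```
+//! use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncWriteCompatExt};
+//!
+//! // `futures-io` implements `AsyncRead` for byte slices, so the wrapper
+//! // returned here also implements `tokio::io::AsyncRead`.
+//! let futures_reader: &[u8] = b"hello";
+//! let tokio_reader = futures_reader.compat();
+//!
+//! // Conversely, expose a tokio writer to `futures-io` based code.
+//! let tokio_writer = tokio::io::sink();
+//! let futures_writer = tokio_writer.compat_write();
+//! ```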
+use futures_core::ready;
+use pin_project_lite::pin_project;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A compatibility layer that allows conversion between the
+ /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
+ #[derive(Copy, Clone, Debug)]
+ pub struct Compat<T> {
+ #[pin]
+ inner: T,
+ seek_pos: Option<io::SeekFrom>,
+ }
+}
+
+/// Extension trait that allows converting a type implementing
+/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
+pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `tokio::io::AsyncRead`.
+ fn compat(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
+pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `tokio::io::AsyncWrite`.
+ fn compat_write(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
+pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `futures_io::AsyncRead`.
+ fn compat(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
+pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
+ /// Wraps `self` with a compatibility layer that implements
+ /// `futures_io::AsyncWrite`.
+ fn compat_write(self) -> Compat<Self>
+ where
+ Self: Sized,
+ {
+ Compat::new(self)
+ }
+}
+
+impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
+
+// === impl Compat ===
+
+impl<T> Compat<T> {
+ fn new(inner: T) -> Self {
+ Self {
+ inner,
+ seek_pos: None,
+ }
+ }
+
+ /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
+ /// contained within.
+ pub fn get_ref(&self) -> &T {
+ &self.inner
+ }
+
+ /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
+ /// contained within.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.inner
+ }
+
+ /// Returns the wrapped item.
+ pub fn into_inner(self) -> T {
+ self.inner
+ }
+}
+
+impl<T> tokio::io::AsyncRead for Compat<T>
+where
+ T: futures_io::AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ // We can't trust the inner type to not peek at the bytes,
+ // so we must defensively initialize the buffer.
+ let slice = buf.initialize_unfilled();
+ let n = ready!(futures_io::AsyncRead::poll_read(
+ self.project().inner,
+ cx,
+ slice
+ ))?;
+ buf.advance(n);
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<T> futures_io::AsyncRead for Compat<T>
+where
+ T: tokio::io::AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ slice: &mut [u8],
+ ) -> Poll<io::Result<usize>> {
+ let mut buf = tokio::io::ReadBuf::new(slice);
+ ready!(tokio::io::AsyncRead::poll_read(
+ self.project().inner,
+ cx,
+ &mut buf
+ ))?;
+ Poll::Ready(Ok(buf.filled().len()))
+ }
+}
+
+impl<T> tokio::io::AsyncBufRead for Compat<T>
+where
+ T: futures_io::AsyncBufRead,
+{
+ fn poll_fill_buf<'a>(
+ self: Pin<&'a mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<&'a [u8]>> {
+ futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
+ }
+
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ futures_io::AsyncBufRead::consume(self.project().inner, amt)
+ }
+}
+
+impl<T> futures_io::AsyncBufRead for Compat<T>
+where
+ T: tokio::io::AsyncBufRead,
+{
+ fn poll_fill_buf<'a>(
+ self: Pin<&'a mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<io::Result<&'a [u8]>> {
+ tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
+ }
+
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ tokio::io::AsyncBufRead::consume(self.project().inner, amt)
+ }
+}
+
+impl<T> tokio::io::AsyncWrite for Compat<T>
+where
+ T: futures_io::AsyncWrite,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ futures_io::AsyncWrite::poll_close(self.project().inner, cx)
+ }
+}
+
+impl<T> futures_io::AsyncWrite for Compat<T>
+where
+ T: tokio::io::AsyncWrite,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
+ }
+}
+
+impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
+ fn poll_seek(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ pos: io::SeekFrom,
+ ) -> Poll<io::Result<u64>> {
+ if self.seek_pos != Some(pos) {
+ self.as_mut().project().inner.start_seek(pos)?;
+ *self.as_mut().project().seek_pos = Some(pos);
+ }
+ let res = ready!(self.as_mut().project().inner.poll_complete(cx));
+ *self.as_mut().project().seek_pos = None;
+ Poll::Ready(res.map(|p| p as u64))
+ }
+}
+
+impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
+ fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
+ *self.as_mut().project().seek_pos = Some(pos);
+ Ok(())
+ }
+
+ fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
+ let pos = match self.seek_pos {
+ None => {
+ // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
+ // We don't have to guarantee that the value returned by
+ // poll_complete called without start_seek is correct,
+ // so we'll return 0.
+ return Poll::Ready(Ok(0));
+ }
+ Some(pos) => pos,
+ };
+ let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
+ *self.as_mut().project().seek_pos = None;
+ Poll::Ready(res.map(|p| p as u64))
+ }
+}
+
+#[cfg(unix)]
+impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
+ fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
+ self.inner.as_raw_fd()
+ }
+}
+
+#[cfg(windows)]
+impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> {
+ fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
+ self.inner.as_raw_handle()
+ }
+}
diff --git a/src/context.rs b/src/context.rs
new file mode 100644
index 0000000..a7a5e02
--- /dev/null
+++ b/src/context.rs
@@ -0,0 +1,190 @@
+//! Tokio context aware futures utilities.
+//!
+//! This module includes utilities around integrating tokio with other runtimes
+//! by allowing the context to be attached to futures. This allows spawning
+//! futures on other executors while still using tokio to drive them. This
+//! can be useful if you need to use a tokio-based library in an executor/runtime
+//! that does not provide a tokio context.
+
+use pin_project_lite::pin_project;
+use std::{
+ future::Future,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tokio::runtime::{Handle, Runtime};
+
+pin_project! {
+ /// `TokioContext` allows running futures that must be inside Tokio's
+ /// context on a non-Tokio runtime.
+ ///
+ /// It contains a [`Handle`] to the runtime. A handle to the runtime can be
+ /// obtained by calling the [`Runtime::handle()`] method.
+ ///
+ /// Note that the `TokioContext` wrapper only works if the `Runtime` it is
+ /// connected to has not yet been destroyed. You must keep the `Runtime`
+ /// alive until the future has finished executing.
+ ///
+ /// **Warning:** If `TokioContext` is used together with a [current thread]
+ /// runtime, that runtime must be inside a call to `block_on` for the
+ /// wrapped future to work. For this reason, it is recommended to use a
+ /// [multi thread] runtime, even if you configure it to only spawn one
+ /// worker thread.
+ ///
+ /// # Examples
+ ///
+ /// This example creates two runtimes, but only [enables time] on one of
+ /// them. It then uses the context of the runtime with the timer enabled to
+ /// execute a [`sleep`] future on the runtime with timing disabled.
+ /// ```
+ /// use tokio::time::{sleep, Duration};
+ /// use tokio_util::context::RuntimeExt;
+ ///
+ /// // This runtime has timers enabled.
+ /// let rt = tokio::runtime::Builder::new_multi_thread()
+ /// .enable_all()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // This runtime has timers disabled.
+ /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // Wrap the sleep future in the context of rt.
+ /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
+ ///
+ /// // Execute the future on rt2.
+ /// rt2.block_on(fut);
+ /// ```
+ ///
+ /// [`Handle`]: struct@tokio::runtime::Handle
+ /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle
+ /// [`RuntimeExt`]: trait@crate::context::RuntimeExt
+ /// [`sleep`]: fn@tokio::time::sleep
+ /// [current thread]: fn@tokio::runtime::Builder::new_current_thread
+ /// [enables time]: fn@tokio::runtime::Builder::enable_time
+ /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread
+ pub struct TokioContext<F> {
+ #[pin]
+ inner: F,
+ handle: Handle,
+ }
+}
+
+impl<F> TokioContext<F> {
+ /// Associate the provided future with the context of the runtime behind
+ /// the provided `Handle`.
+ ///
+ /// Note that this constructor performs no check that the runtime behind
+ /// the provided `Handle` still exists; you must keep that runtime alive
+ /// until the future has finished executing.
+ ///
+ /// # Examples
+ ///
+ /// This is the same as the example above, but uses the `new` constructor
+ /// rather than [`RuntimeExt::wrap`].
+ ///
+ /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap
+ ///
+ /// ```
+ /// use tokio::time::{sleep, Duration};
+ /// use tokio_util::context::TokioContext;
+ ///
+ /// // This runtime has timers enabled.
+ /// let rt = tokio::runtime::Builder::new_multi_thread()
+ /// .enable_all()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // This runtime has timers disabled.
+ /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// let fut = TokioContext::new(
+ /// async { sleep(Duration::from_millis(2)).await },
+ /// rt.handle().clone(),
+ /// );
+ ///
+ /// // Execute the future on rt2.
+ /// rt2.block_on(fut);
+ /// ```
+ pub fn new(future: F, handle: Handle) -> TokioContext<F> {
+ TokioContext {
+ inner: future,
+ handle,
+ }
+ }
+
+ /// Obtain a reference to the handle inside this `TokioContext`.
+ pub fn handle(&self) -> &Handle {
+ &self.handle
+ }
+
+ /// Remove the association between the Tokio runtime and the wrapped future.
+ pub fn into_inner(self) -> F {
+ self.inner
+ }
+}
+
+impl<F: Future> Future for TokioContext<F> {
+ type Output = F::Output;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ let handle = me.handle;
+ let fut = me.inner;
+
+ let _enter = handle.enter();
+ fut.poll(cx)
+ }
+}
+
+/// Extension trait that simplifies bundling a `Handle` with a `Future`.
+pub trait RuntimeExt {
+ /// Create a [`TokioContext`] that wraps the provided future and runs it in
+ /// this runtime's context.
+ ///
+ /// # Examples
+ ///
+ /// This example creates two runtimes, but only [enables time] on one of
+ /// them. It then uses the context of the runtime with the timer enabled to
+ /// execute a [`sleep`] future on the runtime with timing disabled.
+ ///
+ /// ```
+ /// use tokio::time::{sleep, Duration};
+ /// use tokio_util::context::RuntimeExt;
+ ///
+ /// // This runtime has timers enabled.
+ /// let rt = tokio::runtime::Builder::new_multi_thread()
+ /// .enable_all()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // This runtime has timers disabled.
+ /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+ /// .build()
+ /// .unwrap();
+ ///
+ /// // Wrap the sleep future in the context of rt.
+ /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
+ ///
+ /// // Execute the future on rt2.
+ /// rt2.block_on(fut);
+ /// ```
+ ///
+ /// [`TokioContext`]: struct@crate::context::TokioContext
+ /// [`sleep`]: fn@tokio::time::sleep
+ /// [enables time]: fn@tokio::runtime::Builder::enable_time
+ fn wrap<F: Future>(&self, fut: F) -> TokioContext<F>;
+}
+
+impl RuntimeExt for Runtime {
+ fn wrap<F: Future>(&self, fut: F) -> TokioContext<F> {
+ TokioContext {
+ inner: fut,
+ handle: self.handle().clone(),
+ }
+ }
+}
diff --git a/src/either.rs b/src/either.rs
new file mode 100644
index 0000000..9225e53
--- /dev/null
+++ b/src/either.rs
@@ -0,0 +1,188 @@
+//! Module defining an Either type.
+use std::{
+ future::Future,
+ io::SeekFrom,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};
+
+/// Combines two different futures, streams, or sinks having the same associated types into a single type.
+///
+/// This type implements common asynchronous traits such as [`Future`] and those in Tokio.
+///
+/// [`Future`]: std::future::Future
+///
+/// # Example
+///
+/// The following code will not work:
+///
+/// ```compile_fail
+/// # fn some_condition() -> bool { true }
+/// # async fn some_async_function() -> u32 { 10 }
+/// # async fn other_async_function() -> u32 { 20 }
+/// #[tokio::main]
+/// async fn main() {
+/// let result = if some_condition() {
+/// some_async_function()
+/// } else {
+/// other_async_function() // <- Will print: "`if` and `else` have incompatible types"
+/// };
+///
+/// println!("Result is {}", result.await);
+/// }
+/// ```
+///
+/// This is because, although the output types of both futures are the same,
+/// the exact future types are different, and the compiler must be able to
+/// choose a single type for the `result` variable.
+///
+/// When the output type is the same, we can wrap each future in `Either` to avoid the
+/// issue:
+///
+/// ```
+/// use tokio_util::either::Either;
+/// # fn some_condition() -> bool { true }
+/// # async fn some_async_function() -> u32 { 10 }
+/// # async fn other_async_function() -> u32 { 20 }
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let result = if some_condition() {
+/// Either::Left(some_async_function())
+/// } else {
+/// Either::Right(other_async_function())
+/// };
+///
+/// let value = result.await;
+/// println!("Result is {}", value);
+/// # assert_eq!(value, 10);
+/// }
+/// ```
+#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
+#[derive(Debug, Clone)]
+pub enum Either<L, R> {
+ Left(L),
+ Right(R),
+}
+
+/// A small helper macro which reduces the amount of boilerplate in the trait method
+/// implementations below. It takes a method invocation as an argument (e.g. `self.poll(cx)`)
+/// and redirects it to whichever enum variant is currently held in `self`.
+macro_rules! delegate_call {
+ ($self:ident.$method:ident($($args:ident),+)) => {
+ unsafe {
+ match $self.get_unchecked_mut() {
+ Self::Left(l) => Pin::new_unchecked(l).$method($($args),+),
+ Self::Right(r) => Pin::new_unchecked(r).$method($($args),+),
+ }
+ }
+ }
+}
+
+impl<L, R, O> Future for Either<L, R>
+where
+ L: Future<Output = O>,
+ R: Future<Output = O>,
+{
+ type Output = O;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ delegate_call!(self.poll(cx))
+ }
+}
+
+impl<L, R> AsyncRead for Either<L, R>
+where
+ L: AsyncRead,
+ R: AsyncRead,
+{
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<Result<()>> {
+ delegate_call!(self.poll_read(cx, buf))
+ }
+}
+
+impl<L, R> AsyncBufRead for Either<L, R>
+where
+ L: AsyncBufRead,
+ R: AsyncBufRead,
+{
+ fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
+ delegate_call!(self.poll_fill_buf(cx))
+ }
+
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ delegate_call!(self.consume(amt))
+ }
+}
+
+impl<L, R> AsyncSeek for Either<L, R>
+where
+ L: AsyncSeek,
+ R: AsyncSeek,
+{
+ fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
+ delegate_call!(self.start_seek(position))
+ }
+
+ fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
+ delegate_call!(self.poll_complete(cx))
+ }
+}
+
+impl<L, R> AsyncWrite for Either<L, R>
+where
+ L: AsyncWrite,
+ R: AsyncWrite,
+{
+ fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
+ delegate_call!(self.poll_write(cx, buf))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
+ delegate_call!(self.poll_flush(cx))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
+ delegate_call!(self.poll_shutdown(cx))
+ }
+}
+
+impl<L, R> futures_core::stream::Stream for Either<L, R>
+where
+ L: futures_core::stream::Stream,
+ R: futures_core::stream::Stream<Item = L::Item>,
+{
+ type Item = L::Item;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ delegate_call!(self.poll_next(cx))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use tokio::io::{repeat, AsyncReadExt, Repeat};
+ use tokio_stream::{once, Once, StreamExt};
+
+ #[tokio::test]
+ async fn either_is_stream() {
+ let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1));
+
+ assert_eq!(Some(1u32), either.next().await);
+ }
+
+ #[tokio::test]
+ async fn either_is_async_read() {
+ let mut buffer = [0; 3];
+ let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101));
+
+ either.read_exact(&mut buffer).await.unwrap();
+ assert_eq!(buffer, [0b101, 0b101, 0b101]);
+ }
+}
diff --git a/src/io/copy_to_bytes.rs b/src/io/copy_to_bytes.rs
new file mode 100644
index 0000000..9509e71
--- /dev/null
+++ b/src/io/copy_to_bytes.rs
@@ -0,0 +1,68 @@
+use bytes::Bytes;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+ /// A helper that wraps a [`Sink`]`<`[`Bytes`]`>` and converts it into a
+ /// [`Sink`]`<&'a [u8]>` by copying each byte slice into an owned [`Bytes`].
+ ///
+ /// See the documentation for [`SinkWriter`] for an example.
+ ///
+ /// [`Bytes`]: bytes::Bytes
+ /// [`SinkWriter`]: crate::io::SinkWriter
+ /// [`Sink`]: futures_sink::Sink
+ #[derive(Debug)]
+ pub struct CopyToBytes<S> {
+ #[pin]
+ inner: S,
+ }
+}
+
+impl<S> CopyToBytes<S> {
+ /// Creates a new [`CopyToBytes`].
+ pub fn new(inner: S) -> Self {
+ Self { inner }
+ }
+
+ /// Gets a reference to the underlying sink.
+ pub fn get_ref(&self) -> &S {
+ &self.inner
+ }
+
+ /// Gets a mutable reference to the underlying sink.
+ pub fn get_mut(&mut self) -> &mut S {
+ &mut self.inner
+ }
+
+ /// Consumes this [`CopyToBytes`], returning the underlying sink.
+ pub fn into_inner(self) -> S {
+ self.inner
+ }
+}
+
+impl<'a, S> Sink<&'a [u8]> for CopyToBytes<S>
+where
+ S: Sink<Bytes>,
+{
+ type Error = S::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: &'a [u8]) -> Result<(), Self::Error> {
+ self.project()
+ .inner
+ .start_send(Bytes::copy_from_slice(item))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.project().inner.poll_close(cx)
+ }
+}
diff --git a/src/io/inspect.rs b/src/io/inspect.rs
new file mode 100644
index 0000000..ec5bb97
--- /dev/null
+++ b/src/io/inspect.rs
@@ -0,0 +1,134 @@
+use futures_core::ready;
+use pin_project_lite::pin_project;
+use std::io::{IoSlice, Result};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+pin_project! {
+ /// An adapter that lets you inspect the data that's being read.
+ ///
+ /// This is useful for things like hashing data as it's read in.
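+ ///
+ /// # Example
+ ///
+ /// A minimal sketch (an illustrative addition, not upstream text): count
+ /// the bytes that pass through the reader.
+ ///
+ /// ```
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() -> std::io::Result<()> {
+ /// use tokio::io::AsyncReadExt;
+ /// use tokio_util::io::InspectReader;
+ ///
+ /// let mut seen = 0;
+ /// let mut reader = InspectReader::new(&b"some data"[..], |chunk| seen += chunk.len());
+ ///
+ /// let mut out = Vec::new();
+ /// reader.read_to_end(&mut out).await?;
+ ///
+ /// // Drop the reader to release the closure's borrow of `seen`.
+ /// drop(reader);
+ /// assert_eq!(seen, out.len());
+ /// # Ok(())
+ /// # }
+ /// ```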
+ pub struct InspectReader<R, F> {
+ #[pin]
+ reader: R,
+ f: F,
+ }
+}
+
+impl<R, F> InspectReader<R, F> {
+ /// Create a new InspectReader, wrapping `reader` and calling `f` for the
+ /// new data supplied by each read call.
+ ///
+ /// The closure will only be called with an empty slice if the inner reader
+ /// returns without reading data into the buffer. This happens at EOF, or if
+ /// `poll_read` is called with a zero-size buffer.
+ pub fn new(reader: R, f: F) -> InspectReader<R, F>
+ where
+ R: AsyncRead,
+ F: FnMut(&[u8]),
+ {
+ InspectReader { reader, f }
+ }
+
+ /// Consumes the `InspectReader`, returning the wrapped reader
+ pub fn into_inner(self) -> R {
+ self.reader
+ }
+}
+
+impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<Result<()>> {
+ let me = self.project();
+ let filled_length = buf.filled().len();
+ ready!(me.reader.poll_read(cx, buf))?;
+ (me.f)(&buf.filled()[filled_length..]);
+ Poll::Ready(Ok(()))
+ }
+}
+
+pin_project! {
+ /// An adapter that lets you inspect the data that's being written.
+ ///
+ /// This is useful for things like hashing data as it's written out.
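+ ///
+ /// # Example
+ ///
+ /// A minimal sketch (an illustrative addition, not upstream text): mirror
+ /// every successfully written byte into a side buffer.
+ ///
+ /// ```
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() -> std::io::Result<()> {
+ /// use tokio::io::AsyncWriteExt;
+ /// use tokio_util::io::InspectWriter;
+ ///
+ /// let mut log: Vec<u8> = Vec::new();
+ /// let mut writer = InspectWriter::new(tokio::io::sink(), |chunk| log.extend(chunk));
+ ///
+ /// writer.write_all(b"hello").await?;
+ /// writer.shutdown().await?;
+ ///
+ /// // Drop the writer to release the closure's borrow of `log`.
+ /// drop(writer);
+ /// assert_eq!(log, b"hello");
+ /// # Ok(())
+ /// # }
+ /// ```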
+ pub struct InspectWriter<W, F> {
+ #[pin]
+ writer: W,
+ f: F,
+ }
+}
+
+impl<W, F> InspectWriter<W, F> {
+ /// Create a new InspectWriter, wrapping `writer` and calling `f` for the
+ /// data successfully written by each write call.
+ ///
+ /// The closure `f` will never be called with an empty slice. A vectored
+ /// write can result in multiple calls to `f` - at most one call to `f` per
+ /// buffer supplied to `poll_write_vectored`.
+ pub fn new(writer: W, f: F) -> InspectWriter<W, F>
+ where
+ W: AsyncWrite,
+ F: FnMut(&[u8]),
+ {
+ InspectWriter { writer, f }
+ }
+
+ /// Consumes the `InspectWriter`, returning the wrapped writer
+ pub fn into_inner(self) -> W {
+ self.writer
+ }
+}
+
+impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
+ fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
+ let me = self.project();
+ let res = me.writer.poll_write(cx, buf);
+ if let Poll::Ready(Ok(count)) = res {
+ if count != 0 {
+ (me.f)(&buf[..count]);
+ }
+ }
+ res
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ let me = self.project();
+ me.writer.poll_flush(cx)
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ let me = self.project();
+ me.writer.poll_shutdown(cx)
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[IoSlice<'_>],
+ ) -> Poll<Result<usize>> {
+ let me = self.project();
+ let res = me.writer.poll_write_vectored(cx, bufs);
+ if let Poll::Ready(Ok(mut count)) = res {
+ for buf in bufs {
+ if count == 0 {
+ break;
+ }
+ let size = count.min(buf.len());
+ if size != 0 {
+ (me.f)(&buf[..size]);
+ count -= size;
+ }
+ }
+ }
+ res
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ self.writer.is_write_vectored()
+ }
+}
diff --git a/src/io/mod.rs b/src/io/mod.rs
new file mode 100644
index 0000000..6c40d73
--- /dev/null
+++ b/src/io/mod.rs
@@ -0,0 +1,31 @@
+//! Helpers for IO related tasks.
+//!
+//! The stream types are often used in combination with hyper or reqwest, as they
+//! allow converting between a hyper [`Body`] and [`AsyncRead`].
+//!
+//! The [`SyncIoBridge`] type converts from the world of async I/O
+//! to synchronous I/O; this may often come up when using synchronous APIs
+//! inside [`tokio::task::spawn_blocking`].
+//!
+//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
+//! [`AsyncRead`]: tokio::io::AsyncRead
+
+mod copy_to_bytes;
+mod inspect;
+mod read_buf;
+mod reader_stream;
+mod sink_writer;
+mod stream_reader;
+
+cfg_io_util! {
+ mod sync_bridge;
+ pub use self::sync_bridge::SyncIoBridge;
+}
+
+pub use self::copy_to_bytes::CopyToBytes;
+pub use self::inspect::{InspectReader, InspectWriter};
+pub use self::read_buf::read_buf;
+pub use self::reader_stream::ReaderStream;
+pub use self::sink_writer::SinkWriter;
+pub use self::stream_reader::StreamReader;
+pub use crate::util::{poll_read_buf, poll_write_buf};
diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs
new file mode 100644
index 0000000..d7938a3
--- /dev/null
+++ b/src/io/read_buf.rs
@@ -0,0 +1,65 @@
+use bytes::BufMut;
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+
+/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
+///
+/// [`BufMut`]: bytes::BufMut
+///
+/// # Example
+///
+/// ```
+/// use bytes::{Bytes, BytesMut};
+/// use tokio_stream as stream;
+/// use tokio::io::Result;
+/// use tokio_util::io::{StreamReader, read_buf};
+/// # #[tokio::main]
+/// # async fn main() -> std::io::Result<()> {
+///
+/// // Create a reader from an iterator. This particular reader will always be
+/// // ready.
+/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
+///
+/// let mut buf = BytesMut::new();
+/// let mut reads = 0;
+///
+/// loop {
+/// reads += 1;
+/// let n = read_buf(&mut read, &mut buf).await?;
+///
+/// if n == 0 {
+/// break;
+/// }
+/// }
+///
+/// // one or more reads might be necessary.
+/// assert!(reads >= 1);
+/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
+/// # Ok(())
+/// # }
+/// ```
+pub async fn read_buf<R, B>(read: &mut R, buf: &mut B) -> io::Result<usize>
+where
+ R: AsyncRead + Unpin,
+ B: BufMut,
+{
+ return ReadBufFn(read, buf).await;
+
+ struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B);
+
+ impl<'a, R, B> Future for ReadBufFn<'a, R, B>
+ where
+ R: AsyncRead + Unpin,
+ B: BufMut,
+ {
+ type Output = io::Result<usize>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let this = &mut *self;
+ crate::util::poll_read_buf(Pin::new(this.0), cx, this.1)
+ }
+ }
+}
diff --git a/src/io/reader_stream.rs b/src/io/reader_stream.rs
new file mode 100644
index 0000000..866c114
--- /dev/null
+++ b/src/io/reader_stream.rs
@@ -0,0 +1,118 @@
+use bytes::{Bytes, BytesMut};
+use futures_core::stream::Stream;
+use pin_project_lite::pin_project;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncRead;
+
+const DEFAULT_CAPACITY: usize = 4096;
+
+pin_project! {
+ /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
+ ///
+ /// This stream is fused. It performs the inverse operation of
+ /// [`StreamReader`].
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// use tokio_stream::StreamExt;
+ /// use tokio_util::io::ReaderStream;
+ ///
+ /// // Create a stream of data.
+ /// let data = b"hello, world!";
+ /// let mut stream = ReaderStream::new(&data[..]);
+ ///
+ /// // Read all of the chunks into a vector.
+ /// let mut stream_contents = Vec::new();
+ /// while let Some(chunk) = stream.next().await {
+ /// stream_contents.extend_from_slice(&chunk?);
+ /// }
+ ///
+ /// // Once the chunks are concatenated, we should have the
+ /// // original data.
+ /// assert_eq!(stream_contents, data);
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`StreamReader`]: crate::io::StreamReader
+ /// [`Stream`]: futures_core::Stream
+ #[derive(Debug)]
+ pub struct ReaderStream<R> {
+ // Reader itself.
+ //
+ // This value is `None` if the stream has terminated.
+ #[pin]
+ reader: Option<R>,
+ // Working buffer, used to optimize allocations.
+ buf: BytesMut,
+ capacity: usize,
+ }
+}
+
+impl<R: AsyncRead> ReaderStream<R> {
+ /// Convert an [`AsyncRead`] into a [`Stream`] with item type
+ /// `Result<Bytes, std::io::Error>`.
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: futures_core::Stream
+ pub fn new(reader: R) -> Self {
+ ReaderStream {
+ reader: Some(reader),
+ buf: BytesMut::new(),
+ capacity: DEFAULT_CAPACITY,
+ }
+ }
+
+ /// Convert an [`AsyncRead`] into a [`Stream`] with item type
+ /// `Result<Bytes, std::io::Error>`,
+ /// with a specific initial capacity for the read buffer.
+ ///
+ /// [`AsyncRead`]: tokio::io::AsyncRead
+ /// [`Stream`]: futures_core::Stream
+ pub fn with_capacity(reader: R, capacity: usize) -> Self {
+ ReaderStream {
+ reader: Some(reader),
+ buf: BytesMut::with_capacity(capacity),
+ capacity,
+ }
+ }
+}
+
+impl<R: AsyncRead> Stream for ReaderStream<R> {
+ type Item = std::io::Result<Bytes>;
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ use crate::util::poll_read_buf;
+
+ let mut this = self.as_mut().project();
+
+ let reader = match this.reader.as_pin_mut() {
+ Some(r) => r,
+ None => return Poll::Ready(None),
+ };
+
+ if this.buf.capacity() == 0 {
+ this.buf.reserve(*this.capacity);
+ }
+
+ match poll_read_buf(reader, cx, &mut this.buf) {
+ Poll::Pending => Poll::Pending,
+ Poll::Ready(Err(err)) => {
+ self.project().reader.set(None);
+ Poll::Ready(Some(Err(err)))
+ }
+ Poll::Ready(Ok(0)) => {
+ self.project().reader.set(None);
+ Poll::Ready(None)
+ }
+ Poll::Ready(Ok(_)) => {
+ let chunk = this.buf.split();
+ Poll::Ready(Some(Ok(chunk.freeze())))
+ }
+ }
+ }
+}
diff --git a/src/io/sink_writer.rs b/src/io/sink_writer.rs
new file mode 100644
index 0000000..f2af262
--- /dev/null
+++ b/src/io/sink_writer.rs
@@ -0,0 +1,117 @@
+use futures_core::ready;
+use futures_sink::Sink;
+
+use pin_project_lite::pin_project;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::AsyncWrite;
+
+pin_project! {
+ /// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`].
+ ///
+ /// Whenever you write to this [`SinkWriter`], the supplied bytes are
+ /// forwarded to the inner [`Sink`]. When `shutdown` is called on this
+ /// [`SinkWriter`], the inner sink is closed.
+ ///
+ /// This adapter takes a `Sink<&[u8]>` and provides an [`AsyncWrite`] impl
+ /// for it. Because of the lifetime, this trait is relatively rarely
+ /// implemented. The main ways to get a `Sink<&[u8]>` that you can use with
+ /// this type are:
+ ///
+ /// * With the codec module by implementing the [`Encoder`]`<&[u8]>` trait.
+ /// * By wrapping a `Sink<Bytes>` in a [`CopyToBytes`].
+ /// * Manually implementing `Sink<&[u8]>` directly.
+ ///
+ /// The opposite conversion of implementing `Sink<_>` for an [`AsyncWrite`]
+ /// is done using the [`codec`] module.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use bytes::Bytes;
+ /// use futures_util::SinkExt;
+ /// use std::io::{Error, ErrorKind};
+ /// use tokio::io::AsyncWriteExt;
+ /// use tokio_util::io::{SinkWriter, CopyToBytes};
+ /// use tokio_util::sync::PollSender;
+ ///
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() -> Result<(), Error> {
+ /// // We use an mpsc channel as an example of a `Sink<Bytes>`.
+ /// let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
+ /// let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
+ ///
+ /// // Wrap it in `CopyToBytes` to get a `Sink<&[u8]>`.
+ /// let mut writer = SinkWriter::new(CopyToBytes::new(sink));
+ ///
+ /// // Write data to our interface...
+ /// let data: [u8; 4] = [1, 2, 3, 4];
+ /// let _ = writer.write(&data).await?;
+ ///
+ /// // ... and receive it.
+ /// assert_eq!(data.as_slice(), &*rx.recv().await.unwrap());
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`AsyncWrite`]: tokio::io::AsyncWrite
+ /// [`CopyToBytes`]: crate::io::CopyToBytes
+ /// [`Encoder`]: crate::codec::Encoder
+ /// [`Sink`]: futures_sink::Sink
+ /// [`codec`]: crate::codec
+ #[derive(Debug)]
+ pub struct SinkWriter<S> {
+ #[pin]
+ inner: S,
+ }
+}
+
+impl<S> SinkWriter<S> {
+ /// Creates a new [`SinkWriter`].
+ pub fn new(sink: S) -> Self {
+ Self { inner: sink }
+ }
+
+ /// Gets a reference to the underlying sink.
+ pub fn get_ref(&self) -> &S {
+ &self.inner
+ }
+
+ /// Gets a mutable reference to the underlying sink.
+ pub fn get_mut(&mut self) -> &mut S {
+ &mut self.inner
+ }
+
+ /// Consumes this [`SinkWriter`], returning the underlying sink.
+ pub fn into_inner(self) -> S {
+ self.inner
+ }
+}
+impl<S, E> AsyncWrite for SinkWriter<S>
+where
+ for<'a> S: Sink<&'a [u8], Error = E>,
+ E: Into<io::Error>,
+{
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ let mut this = self.project();
+
+ ready!(this.inner.as_mut().poll_ready(cx).map_err(Into::into))?;
+ match this.inner.as_mut().start_send(buf) {
+ Ok(()) => Poll::Ready(Ok(buf.len())),
+ Err(e) => Poll::Ready(Err(e.into())),
+ }
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.project().inner.poll_flush(cx).map_err(Into::into)
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.project().inner.poll_close(cx).map_err(Into::into)
+ }
+}
diff --git a/src/io/stream_reader.rs b/src/io/stream_reader.rs
new file mode 100644
index 0000000..3353722
--- /dev/null
+++ b/src/io/stream_reader.rs
@@ -0,0 +1,326 @@
+use bytes::Buf;
+use futures_core::stream::Stream;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
+
+/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
+///
+/// This type performs the inverse operation of [`ReaderStream`].
+///
+/// This type also implements the [`AsyncBufRead`] trait, so you can use it
+/// to read a `Stream` of byte chunks line-by-line. See the examples below.
+///
+/// # Example
+///
+/// ```
+/// use bytes::Bytes;
+/// use tokio::io::{AsyncReadExt, Result};
+/// use tokio_util::io::StreamReader;
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> std::io::Result<()> {
+///
+/// // Create a stream from an iterator.
+/// let stream = tokio_stream::iter(vec![
+/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
+/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
+/// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
+/// ]);
+///
+/// // Convert it to an AsyncRead.
+/// let mut read = StreamReader::new(stream);
+///
+/// // Read five bytes from the stream.
+/// let mut buf = [0; 5];
+/// read.read_exact(&mut buf).await?;
+/// assert_eq!(buf, [0, 1, 2, 3, 4]);
+///
+/// // Read the rest of the current chunk.
+/// assert_eq!(read.read(&mut buf).await?, 3);
+/// assert_eq!(&buf[..3], [5, 6, 7]);
+///
+/// // Read the next chunk.
+/// assert_eq!(read.read(&mut buf).await?, 4);
+/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
+///
+/// // We have now reached the end.
+/// assert_eq!(read.read(&mut buf).await?, 0);
+///
+/// # Ok(())
+/// # }
+/// ```
+///
+/// If the stream produces errors which are not [`std::io::Error`],
+/// the errors can be converted using [`StreamExt`] to map each
+/// element.
+///
+/// ```
+/// use bytes::Bytes;
+/// use tokio::io::AsyncReadExt;
+/// use tokio_util::io::StreamReader;
+/// use tokio_stream::StreamExt;
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> std::io::Result<()> {
+///
+/// // Create a stream from an iterator, including an error.
+/// let stream = tokio_stream::iter(vec![
+/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
+/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
+/// Result::Err("Something bad happened!")
+/// ]);
+///
+/// // Use StreamExt to map the stream and error to a std::io::Error
+/// let stream = stream.map(|result| result.map_err(|err| {
+/// std::io::Error::new(std::io::ErrorKind::Other, err)
+/// }));
+///
+/// // Convert it to an AsyncRead.
+/// let mut read = StreamReader::new(stream);
+///
+/// // Read five bytes from the stream.
+/// let mut buf = [0; 5];
+/// read.read_exact(&mut buf).await?;
+/// assert_eq!(buf, [0, 1, 2, 3, 4]);
+///
+/// // Read the rest of the current chunk.
+/// assert_eq!(read.read(&mut buf).await?, 3);
+/// assert_eq!(&buf[..3], [5, 6, 7]);
+///
+/// // Reading the next chunk will produce an error
+/// let error = read.read(&mut buf).await.unwrap_err();
+/// assert_eq!(error.kind(), std::io::ErrorKind::Other);
+/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
+///
+/// // We have now reached the end.
+/// assert_eq!(read.read(&mut buf).await?, 0);
+///
+/// # Ok(())
+/// # }
+/// ```
+///
+/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
+/// line-by-line. Note that you will usually also need to convert the error
+/// type when doing this. See the second example for an explanation of how
+/// to do this.
+///
+/// ```
+/// use tokio::io::{Result, AsyncBufReadExt};
+/// use tokio_util::io::StreamReader;
+/// # #[tokio::main(flavor = "current_thread")]
+/// # async fn main() -> std::io::Result<()> {
+///
+/// // Create a stream of byte chunks.
+/// let stream = tokio_stream::iter(vec![
+/// Result::Ok(b"The first line.\n".as_slice()),
+/// Result::Ok(b"The second line.".as_slice()),
+/// Result::Ok(b"\nThe third".as_slice()),
+/// Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
+/// ]);
+///
+/// // Convert it to an AsyncRead.
+/// let mut read = StreamReader::new(stream);
+///
+/// // Loop through the lines from the `StreamReader`.
+/// let mut line = String::new();
+/// let mut lines = Vec::new();
+/// loop {
+/// line.clear();
+/// let len = read.read_line(&mut line).await?;
+/// if len == 0 { break; }
+/// lines.push(line.clone());
+/// }
+///
+/// // Verify that we got the lines we expected.
+/// assert_eq!(
+/// lines,
+/// vec![
+/// "The first line.\n",
+/// "The second line.\n",
+/// "The third line.\n",
+/// "The fourth line.\n",
+/// "The fifth line.\n",
+/// ]
+/// );
+/// # Ok(())
+/// # }
+/// ```
+///
+/// [`AsyncRead`]: tokio::io::AsyncRead
+/// [`AsyncBufRead`]: tokio::io::AsyncBufRead
+/// [`Stream`]: futures_core::Stream
+/// [`ReaderStream`]: crate::io::ReaderStream
+/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
+#[derive(Debug)]
+pub struct StreamReader<S, B> {
+ // This field is pinned.
+ inner: S,
+ // This field is not pinned.
+ chunk: Option<B>,
+}
+
+impl<S, B, E> StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead).
+ ///
+ /// The item should be a [`Result`] with the ok variant being something that
+ /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
+ /// should be convertible into an [io error].
+ ///
+ /// [`Result`]: std::result::Result
+ /// [`Buf`]: bytes::Buf
+ /// [io error]: std::io::Error
+ pub fn new(stream: S) -> Self {
+ Self {
+ inner: stream,
+ chunk: None,
+ }
+ }
+
+ /// Do we have a chunk and is it non-empty?
+ fn has_chunk(&self) -> bool {
+ if let Some(ref chunk) = self.chunk {
+ chunk.remaining() > 0
+ } else {
+ false
+ }
+ }
+
+ /// Consumes this `StreamReader`, returning a tuple consisting of the
+ /// underlying stream and an `Option` holding the internal buffer, which is
+ /// `Some` if the buffer contains any unread bytes.
+ pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
+ if self.has_chunk() {
+ (self.inner, self.chunk)
+ } else {
+ (self.inner, None)
+ }
+ }
+}
+
+impl<S, B> StreamReader<S, B> {
+ /// Gets a reference to the underlying stream.
+ ///
+ /// It is inadvisable to directly read from the underlying stream.
+ pub fn get_ref(&self) -> &S {
+ &self.inner
+ }
+
+ /// Gets a mutable reference to the underlying stream.
+ ///
+ /// It is inadvisable to directly read from the underlying stream.
+ pub fn get_mut(&mut self) -> &mut S {
+ &mut self.inner
+ }
+
+ /// Gets a pinned mutable reference to the underlying stream.
+ ///
+ /// It is inadvisable to directly read from the underlying stream.
+ pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
+ self.project().inner
+ }
+
+ /// Consumes this `StreamReader`, returning the underlying stream.
+ ///
+ /// Note that any leftover data in the internal buffer is lost.
+ /// If you additionally want access to the internal buffer use
+ /// [`into_inner_with_chunk`].
+ ///
+ /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
+ pub fn into_inner(self) -> S {
+ self.inner
+ }
+}
+
+impl<S, B, E> AsyncRead for StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ if buf.remaining() == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let inner_buf = match self.as_mut().poll_fill_buf(cx) {
+ Poll::Ready(Ok(buf)) => buf,
+ Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+ Poll::Pending => return Poll::Pending,
+ };
+ let len = std::cmp::min(inner_buf.len(), buf.remaining());
+ buf.put_slice(&inner_buf[..len]);
+
+ self.consume(len);
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<S, B, E> AsyncBufRead for StreamReader<S, B>
+where
+ S: Stream<Item = Result<B, E>>,
+ B: Buf,
+ E: Into<std::io::Error>,
+{
+ fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+ loop {
+ if self.as_mut().has_chunk() {
+ // This unwrap is very sad, but it can't be avoided.
+ let buf = self.project().chunk.as_ref().unwrap().chunk();
+ return Poll::Ready(Ok(buf));
+ } else {
+ match self.as_mut().project().inner.poll_next(cx) {
+ Poll::Ready(Some(Ok(chunk))) => {
+ // Go around the loop in case the chunk is empty.
+ *self.as_mut().project().chunk = Some(chunk);
+ }
+ Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
+ Poll::Ready(None) => return Poll::Ready(Ok(&[])),
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+ }
+ }
+ fn consume(self: Pin<&mut Self>, amt: usize) {
+ if amt > 0 {
+ self.project()
+ .chunk
+ .as_mut()
+ .expect("No chunk present")
+ .advance(amt);
+ }
+ }
+}
+
+// The code below is a manual expansion of the code that pin-project-lite would
+// generate. This is done because pin-project-lite fails by hitting the recursion
+// limit on this struct. (Every line of documentation is handled recursively by
+// the macro.)
+
+impl<S: Unpin, B> Unpin for StreamReader<S, B> {}
+
+struct StreamReaderProject<'a, S, B> {
+ inner: Pin<&'a mut S>,
+ chunk: &'a mut Option<B>,
+}
+
+impl<S, B> StreamReader<S, B> {
+ #[inline]
+ fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
+ // SAFETY: We define that only `inner` should be pinned when `Self` is
+ // and have an appropriate `impl Unpin` for this.
+ let me = unsafe { Pin::into_inner_unchecked(self) };
+ StreamReaderProject {
+ inner: unsafe { Pin::new_unchecked(&mut me.inner) },
+ chunk: &mut me.chunk,
+ }
+ }
+}
diff --git a/src/io/sync_bridge.rs b/src/io/sync_bridge.rs
new file mode 100644
index 0000000..f87bfbb
--- /dev/null
+++ b/src/io/sync_bridge.rs
@@ -0,0 +1,143 @@
+use std::io::{BufRead, Read, Write};
+use tokio::io::{
+ AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
+};
+
+/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+/// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+#[derive(Debug)]
+pub struct SyncIoBridge<T> {
+ src: T,
+ rt: tokio::runtime::Handle,
+}
+
+impl<T: AsyncBufRead + Unpin> BufRead for SyncIoBridge<T> {
+ fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
+ let src = &mut self.src;
+ self.rt.block_on(AsyncBufReadExt::fill_buf(src))
+ }
+
+ fn consume(&mut self, amt: usize) {
+ let src = &mut self.src;
+ AsyncBufReadExt::consume(src, amt)
+ }
+
+ fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt
+ .block_on(AsyncBufReadExt::read_until(src, byte, buf))
+ }
+ fn read_line(&mut self, buf: &mut String) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(AsyncBufReadExt::read_line(src, buf))
+ }
+}
+
+impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> {
+ fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(AsyncReadExt::read(src, buf))
+ }
+
+ fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.read_to_end(buf))
+ }
+
+ fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.read_to_string(buf))
+ }
+
+ fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
+ let src = &mut self.src;
+ // The asynchronous `read_exact` returns the number of bytes read; the synchronous one doesn't.
+ let _n = self.rt.block_on(src.read_exact(buf))?;
+ Ok(())
+ }
+}
+
+impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> {
+ fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.write(buf))
+ }
+
+ fn flush(&mut self) -> std::io::Result<()> {
+ let src = &mut self.src;
+ self.rt.block_on(src.flush())
+ }
+
+ fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
+ let src = &mut self.src;
+ self.rt.block_on(src.write_all(buf))
+ }
+
+ fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
+ let src = &mut self.src;
+ self.rt.block_on(src.write_vectored(bufs))
+ }
+}
+
+// Because `std::io::Write::is_write_vectored` (https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored)
+// is still unstable at the time of writing, we expose this as a standalone method.
+impl<T: AsyncWrite> SyncIoBridge<T> {
+ /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes.
+ ///
+ /// See [`tokio::io::AsyncWrite::is_write_vectored`].
+ pub fn is_write_vectored(&self) -> bool {
+ self.src.is_write_vectored()
+ }
+}
+
+impl<T: AsyncWrite + Unpin> SyncIoBridge<T> {
+ /// Shutdown this writer. This method provides a way to call the [`AsyncWriteExt::shutdown`]
+ /// function of the inner [`tokio::io::AsyncWrite`] instance.
+ ///
+ /// # Errors
+ ///
+ /// This method returns the same errors as [`AsyncWriteExt::shutdown`].
+ ///
+ /// [`AsyncWriteExt::shutdown`]: tokio::io::AsyncWriteExt::shutdown
+ pub fn shutdown(&mut self) -> std::io::Result<()> {
+ let src = &mut self.src;
+ self.rt.block_on(src.shutdown())
+ }
+}
+
+impl<T: Unpin> SyncIoBridge<T> {
+ /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+ /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+ ///
+ /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`].
+ /// It is hence OK to move this struct into a separate thread outside the runtime, as created
+ /// by e.g. [`tokio::task::spawn_blocking`].
+ ///
+ /// Stated even more strongly: to make use of this bridge, you *must* move
+ /// it into a separate thread outside the runtime. The synchronous I/O will use the
+ /// underlying handle to block on the backing asynchronous source, via
+ /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that
+ /// function, an attempt to `block_on` from an asynchronous execution context
+ /// will panic.
+ ///
+ /// # Wrapping `!Unpin` types
+ ///
+ /// Use e.g. `SyncIoBridge::new(Box::pin(src))`.
+ ///
+ /// # Panics
+ ///
+ /// This will panic if called outside the context of a Tokio runtime.
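+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch of the intended pattern (the `read_all` helper is
+ /// illustrative, not part of this crate): wrap the async source, move the
+ /// bridge into [`tokio::task::spawn_blocking`], and drive it there through
+ /// the blocking `std::io` traits.
+ ///
+ /// ```no_run
+ /// use std::io::Read;
+ /// use tokio::io::AsyncRead;
+ /// use tokio_util::io::SyncIoBridge;
+ ///
+ /// async fn read_all<S>(src: S) -> std::io::Result<Vec<u8>>
+ /// where
+ ///     S: AsyncRead + Unpin + Send + 'static,
+ /// {
+ ///     tokio::task::spawn_blocking(move || {
+ ///         // On the blocking thread, the bridge may freely `block_on`
+ ///         // the wrapped reader via the captured runtime handle.
+ ///         let mut reader = SyncIoBridge::new(src);
+ ///         let mut buf = Vec::new();
+ ///         reader.read_to_end(&mut buf)?;
+ ///         Ok(buf)
+ ///     })
+ ///     .await
+ ///     .expect("spawn_blocking task panicked")
+ /// }
+ /// ```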
+ #[track_caller]
+ pub fn new(src: T) -> Self {
+ Self::new_with_handle(src, tokio::runtime::Handle::current())
+ }
+
+ /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+ /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+ ///
+ /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may
+ /// be initially invoked outside of an asynchronous context.
+ pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self {
+ Self { src, rt }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..524fc47
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,205 @@
+#![allow(clippy::needless_doctest_main)]
+#![warn(
+ missing_debug_implementations,
+ missing_docs,
+ rust_2018_idioms,
+ unreachable_pub
+)]
+#![doc(test(
+ no_crate_inject,
+ attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
+))]
+#![cfg_attr(docsrs, feature(doc_cfg))]
+
+//! Utilities for working with Tokio.
+//!
+//! This crate is not versioned in lockstep with the core
+//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's
+//! semantic versioning policy, especially with regard to breaking changes.
+//!
+//! [`tokio`]: https://docs.rs/tokio
+
+#[macro_use]
+mod cfg;
+
+mod loom;
+
+cfg_codec! {
+ pub mod codec;
+}
+
+cfg_net! {
+ #[cfg(not(target_arch = "wasm32"))]
+ pub mod udp;
+ pub mod net;
+}
+
+cfg_compat! {
+ pub mod compat;
+}
+
+cfg_io! {
+ pub mod io;
+}
+
+cfg_rt! {
+ pub mod context;
+ pub mod task;
+}
+
+cfg_time! {
+ pub mod time;
+}
+
+pub mod sync;
+
+pub mod either;
+
+#[cfg(any(feature = "io", feature = "codec"))]
+mod util {
+ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+ use bytes::{Buf, BufMut};
+ use futures_core::ready;
+ use std::io::{self, IoSlice};
+ use std::mem::MaybeUninit;
+ use std::pin::Pin;
+ use std::task::{Context, Poll};
+
+ /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
+ ///
+ /// [`BufMut`]: bytes::BufMut
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use bytes::{Bytes, BytesMut};
+ /// use tokio_stream as stream;
+ /// use tokio::io::Result;
+ /// use tokio_util::io::{StreamReader, poll_read_buf};
+ /// use futures::future::poll_fn;
+ /// use std::pin::Pin;
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ ///
+ /// // Create a reader from an iterator. This particular reader will always be
+ /// // ready.
+ /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
+ ///
+ /// let mut buf = BytesMut::new();
+ /// let mut reads = 0;
+ ///
+ /// loop {
+ /// reads += 1;
+ /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
+ ///
+ /// if n == 0 {
+ /// break;
+ /// }
+ /// }
+ ///
+ /// // one or more reads might be necessary.
+ /// assert!(reads >= 1);
+ /// assert_eq!(&buf[..], &[0, 1, 2, 3]);
+ /// # Ok(())
+ /// # }
+ /// ```
+ #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+ pub fn poll_read_buf<T: AsyncRead, B: BufMut>(
+ io: Pin<&mut T>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ if !buf.has_remaining_mut() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = {
+ let dst = buf.chunk_mut();
+
+ // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
+ // transparent wrapper around `[MaybeUninit<u8>]`.
+ let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut buf = ReadBuf::uninit(dst);
+ let ptr = buf.filled().as_ptr();
+ ready!(io.poll_read(cx, &mut buf)?);
+
+ // Ensure the pointer does not change from under us
+ assert_eq!(ptr, buf.filled().as_ptr());
+ buf.filled().len()
+ };
+
+ // Safety: This is guaranteed to be the number of initialized (and read)
+ // bytes due to the invariants provided by `ReadBuf::filled`.
+ unsafe {
+ buf.advance_mut(n);
+ }
+
+ Poll::Ready(Ok(n))
+ }
+
+ /// Try to write data from an implementer of the [`Buf`] trait to an
+ /// [`AsyncWrite`], advancing the buffer's internal cursor.
+ ///
+ /// This function will use [vectored writes] when the [`AsyncWrite`] supports
+ /// vectored writes.
+ ///
+ /// # Examples
+ ///
+ /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
+ /// [`Buf`]:
+ ///
+ /// ```no_run
+ /// use tokio_util::io::poll_write_buf;
+ /// use tokio::io;
+ /// use tokio::fs::File;
+ ///
+ /// use bytes::Buf;
+ /// use std::io::Cursor;
+ /// use std::pin::Pin;
+ /// use futures::future::poll_fn;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let mut file = File::create("foo.txt").await?;
+ /// let mut buf = Cursor::new(b"data to write");
+ ///
+ /// // Loop until the entire contents of the buffer are written to
+ /// // the file.
+ /// while buf.has_remaining() {
+ /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
+ /// }
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// [`Buf`]: bytes::Buf
+ /// [`AsyncWrite`]: tokio::io::AsyncWrite
+ /// [`File`]: tokio::fs::File
+ /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
+ #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
+ pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
+ io: Pin<&mut T>,
+ cx: &mut Context<'_>,
+ buf: &mut B,
+ ) -> Poll<io::Result<usize>> {
+ const MAX_BUFS: usize = 64;
+
+ if !buf.has_remaining() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = if io.is_write_vectored() {
+ let mut slices = [IoSlice::new(&[]); MAX_BUFS];
+ let cnt = buf.chunks_vectored(&mut slices);
+ ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
+ } else {
+ ready!(io.poll_write(cx, buf.chunk()))?
+ };
+
+ buf.advance(n);
+
+ Poll::Ready(Ok(n))
+ }
+}
diff --git a/src/loom.rs b/src/loom.rs
new file mode 100644
index 0000000..dd03fea
--- /dev/null
+++ b/src/loom.rs
@@ -0,0 +1 @@
+pub(crate) use std::sync;
diff --git a/src/net/mod.rs b/src/net/mod.rs
new file mode 100644
index 0000000..4817e10
--- /dev/null
+++ b/src/net/mod.rs
@@ -0,0 +1,97 @@
+//! TCP/UDP/Unix helpers for tokio.
+
+use crate::either::Either;
+use std::future::Future;
+use std::io::Result;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+#[cfg(unix)]
+pub mod unix;
+
+/// A trait for listeners, implemented for `TcpListener` and `UnixListener`.
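+///
+/// # Examples
+///
+/// A minimal sketch of an accept loop that is generic over the listener type
+/// (the bind address is illustrative):
+///
+/// ```no_run
+/// use tokio::net::TcpListener;
+/// use tokio_util::net::Listener;
+///
+/// async fn serve<L: Listener>(mut listener: L) -> std::io::Result<()> {
+///     loop {
+///         // `accept` is provided by the trait on top of `poll_accept`.
+///         let (_stream, _addr) = listener.accept().await?;
+///         // ... handle the connection ...
+///     }
+/// }
+///
+/// #[tokio::main]
+/// async fn main() -> std::io::Result<()> {
+///     let listener = TcpListener::bind("127.0.0.1:8080").await?;
+///     serve(listener).await
+/// }
+/// ```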
+pub trait Listener {
+ /// The type of the stream produced by this listener.
+ type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite;
+ /// The socket address type of this listener.
+ type Addr;
+
+ /// Polls to accept a new incoming connection to this listener.
+ fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>;
+
+ /// Accepts a new incoming connection from this listener.
+ fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
+ where
+ Self: Sized,
+ {
+ ListenerAcceptFut { listener: self }
+ }
+
+ /// Returns the local address that this listener is bound to.
+ fn local_addr(&self) -> Result<Self::Addr>;
+}
+
+impl Listener for tokio::net::TcpListener {
+ type Io = tokio::net::TcpStream;
+ type Addr = std::net::SocketAddr;
+
+ fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
+ Self::poll_accept(self, cx)
+ }
+
+ fn local_addr(&self) -> Result<Self::Addr> {
+ self.local_addr().map(Into::into)
+ }
+}
+
+/// Future for accepting a new connection from a listener.
+#[derive(Debug)]
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+pub struct ListenerAcceptFut<'a, L> {
+ listener: &'a mut L,
+}
+
+impl<'a, L> Future for ListenerAcceptFut<'a, L>
+where
+ L: Listener,
+{
+ type Output = Result<(L::Io, L::Addr)>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ self.listener.poll_accept(cx)
+ }
+}
+
+impl<L, R> Either<L, R>
+where
+ L: Listener,
+ R: Listener,
+{
+ /// Accepts a new incoming connection from this listener.
+ pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> {
+ match self {
+ Either::Left(listener) => {
+ let (stream, addr) = listener.accept().await?;
+ Ok(Either::Left((stream, addr)))
+ }
+ Either::Right(listener) => {
+ let (stream, addr) = listener.accept().await?;
+ Ok(Either::Right((stream, addr)))
+ }
+ }
+ }
+
+ /// Returns the local address that this listener is bound to.
+ pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> {
+ match self {
+ Either::Left(listener) => {
+ let addr = listener.local_addr()?;
+ Ok(Either::Left(addr))
+ }
+ Either::Right(listener) => {
+ let addr = listener.local_addr()?;
+ Ok(Either::Right(addr))
+ }
+ }
+ }
+}
diff --git a/src/net/unix/mod.rs b/src/net/unix/mod.rs
new file mode 100644
index 0000000..0b522c9
--- /dev/null
+++ b/src/net/unix/mod.rs
@@ -0,0 +1,18 @@
+//! Unix domain socket helpers.
+
+use super::Listener;
+use std::io::Result;
+use std::task::{Context, Poll};
+
+impl Listener for tokio::net::UnixListener {
+ type Io = tokio::net::UnixStream;
+ type Addr = tokio::net::unix::SocketAddr;
+
+ fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
+ Self::poll_accept(self, cx)
+ }
+
+ fn local_addr(&self) -> Result<Self::Addr> {
+ self.local_addr().map(Into::into)
+ }
+}
diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs
new file mode 100644
index 0000000..c44be69
--- /dev/null
+++ b/src/sync/cancellation_token.rs
@@ -0,0 +1,324 @@
+//! An asynchronously awaitable `CancellationToken`.
+//! The token allows signaling a cancellation request to one or more tasks.
+pub(crate) mod guard;
+mod tree_node;
+
+use crate::loom::sync::Arc;
+use core::future::Future;
+use core::pin::Pin;
+use core::task::{Context, Poll};
+
+use guard::DropGuard;
+use pin_project_lite::pin_project;
+
+/// A token which can be used to signal a cancellation request to one or more
+/// tasks.
+///
+/// Tasks can call [`CancellationToken::cancelled()`] in order to
+/// obtain a Future which will be resolved when cancellation is requested.
+///
+/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
+///
+/// # Examples
+///
+/// ```no_run
+/// use tokio::select;
+/// use tokio_util::sync::CancellationToken;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let token = CancellationToken::new();
+/// let cloned_token = token.clone();
+///
+/// let join_handle = tokio::spawn(async move {
+/// // Wait for either cancellation or a very long time
+/// select! {
+/// _ = cloned_token.cancelled() => {
+/// // The token was cancelled
+/// 5
+/// }
+/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
+/// 99
+/// }
+/// }
+/// });
+///
+/// tokio::spawn(async move {
+/// tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+/// token.cancel();
+/// });
+///
+/// assert_eq!(5, join_handle.await.unwrap());
+/// }
+/// ```
+pub struct CancellationToken {
+ inner: Arc<tree_node::TreeNode>,
+}
+
+impl std::panic::UnwindSafe for CancellationToken {}
+impl std::panic::RefUnwindSafe for CancellationToken {}
+
+pin_project! {
+ /// A Future that is resolved once the corresponding [`CancellationToken`]
+ /// is cancelled.
+ #[must_use = "futures do nothing unless polled"]
+ pub struct WaitForCancellationFuture<'a> {
+ cancellation_token: &'a CancellationToken,
+ #[pin]
+ future: tokio::sync::futures::Notified<'a>,
+ }
+}
+
+pin_project! {
+ /// A Future that is resolved once the corresponding [`CancellationToken`]
+ /// is cancelled.
+ ///
+ /// This is the counterpart to [`WaitForCancellationFuture`] that takes
+ /// [`CancellationToken`] by value instead of using a reference.
+ #[must_use = "futures do nothing unless polled"]
+ pub struct WaitForCancellationFutureOwned {
+ // Since `future` is the first field, it is dropped before the
+ // cancellation_token field. This ensures that the reference inside the
+ // `Notified` remains valid.
+ #[pin]
+ future: tokio::sync::futures::Notified<'static>,
+ cancellation_token: CancellationToken,
+ }
+}
+
+// ===== impl CancellationToken =====
+
+impl core::fmt::Debug for CancellationToken {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ f.debug_struct("CancellationToken")
+ .field("is_cancelled", &self.is_cancelled())
+ .finish()
+ }
+}
+
+impl Clone for CancellationToken {
+ fn clone(&self) -> Self {
+ tree_node::increase_handle_refcount(&self.inner);
+ CancellationToken {
+ inner: self.inner.clone(),
+ }
+ }
+}
+
+impl Drop for CancellationToken {
+ fn drop(&mut self) {
+ tree_node::decrease_handle_refcount(&self.inner);
+ }
+}
+
+impl Default for CancellationToken {
+ fn default() -> CancellationToken {
+ CancellationToken::new()
+ }
+}
+
+impl CancellationToken {
+ /// Creates a new CancellationToken in the non-cancelled state.
+ pub fn new() -> CancellationToken {
+ CancellationToken {
+ inner: Arc::new(tree_node::TreeNode::new()),
+ }
+ }
+
+ /// Creates a `CancellationToken` which will get cancelled whenever the
+ /// current token gets cancelled.
+ ///
+ /// If the current token is already cancelled, the child token will be
+ /// returned in the cancelled state.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::select;
+ /// use tokio_util::sync::CancellationToken;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let token = CancellationToken::new();
+ /// let child_token = token.child_token();
+ ///
+ /// let join_handle = tokio::spawn(async move {
+ /// // Wait for either cancellation or a very long time
+ /// select! {
+ /// _ = child_token.cancelled() => {
+ /// // The token was cancelled
+ /// 5
+ /// }
+ /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
+ /// 99
+ /// }
+ /// }
+ /// });
+ ///
+ /// tokio::spawn(async move {
+ /// tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+ /// token.cancel();
+ /// });
+ ///
+ /// assert_eq!(5, join_handle.await.unwrap());
+ /// }
+ /// ```
+ pub fn child_token(&self) -> CancellationToken {
+ CancellationToken {
+ inner: tree_node::child_node(&self.inner),
+ }
+ }
+
+ /// Cancel the [`CancellationToken`] and all child tokens which had been
+ /// derived from it.
+ ///
+ /// This will wake up all tasks which are waiting for cancellation.
+ ///
+ /// Be aware that cancellation is not an atomic operation. It is possible
+ /// for another thread running in parallel with a call to `cancel` to first
+ /// receive `true` from `is_cancelled` on one child node, and then receive
+ /// `false` from `is_cancelled` on another child node. However, once the
+ /// call to `cancel` returns, all child nodes have been fully cancelled.
+ pub fn cancel(&self) {
+ tree_node::cancel(&self.inner);
+ }
+
+ /// Returns `true` if the `CancellationToken` is cancelled.
+ pub fn is_cancelled(&self) -> bool {
+ tree_node::is_cancelled(&self.inner)
+ }
+
+ /// Returns a `Future` that gets fulfilled when cancellation is requested.
+ ///
+ /// The future will complete immediately if the token is already cancelled
+ /// when this method is called.
+ ///
+ /// # Cancel safety
+ ///
+ /// This method is cancel safe.
+ pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
+ WaitForCancellationFuture {
+ cancellation_token: self,
+ future: self.inner.notified(),
+ }
+ }
+
+ /// Returns a `Future` that gets fulfilled when cancellation is requested.
+ ///
+ /// The future will complete immediately if the token is already cancelled
+ /// when this method is called.
+ ///
+ /// The function takes self by value and returns a future that owns the
+ /// token.
+ ///
+ /// # Cancel safety
+ ///
+ /// This method is cancel safe.
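+ ///
+ /// # Examples
+ ///
+ /// A small sketch: because the returned future owns the token, it is
+ /// `'static` and can be moved into a spawned task as-is.
+ ///
+ /// ```
+ /// use tokio_util::sync::CancellationToken;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ ///     let token = CancellationToken::new();
+ ///
+ ///     // The owned future can be passed directly to `tokio::spawn`.
+ ///     let handle = tokio::spawn(token.clone().cancelled_owned());
+ ///
+ ///     token.cancel();
+ ///     handle.await.unwrap();
+ /// }
+ /// ```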
+ pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
+ WaitForCancellationFutureOwned::new(self)
+ }
+
+ /// Creates a `DropGuard` for this token.
+ ///
+ /// Returned guard will cancel this token (and all its children) on drop
+ /// unless disarmed.
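+ ///
+ /// # Examples
+ ///
+ /// A minimal sketch of cancel-on-drop:
+ ///
+ /// ```
+ /// use tokio_util::sync::CancellationToken;
+ ///
+ /// let token = CancellationToken::new();
+ /// let child = token.child_token();
+ ///
+ /// {
+ ///     // Dropping the guard cancels `token` and therefore `child`.
+ ///     let _guard = token.drop_guard();
+ /// }
+ ///
+ /// assert!(child.is_cancelled());
+ /// ```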
+ pub fn drop_guard(self) -> DropGuard {
+ DropGuard { inner: Some(self) }
+ }
+}
+
+// ===== impl WaitForCancellationFuture =====
+
+impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ f.debug_struct("WaitForCancellationFuture").finish()
+ }
+}
+
+impl<'a> Future for WaitForCancellationFuture<'a> {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+ let mut this = self.project();
+ loop {
+ if this.cancellation_token.is_cancelled() {
+ return Poll::Ready(());
+ }
+
+ // No wakeups can be lost here because there is always a call to
+ // `is_cancelled` between the creation of the future and the call to
+ // `poll`, and the code that sets the cancelled flag does so before
+ // waking the `Notified`.
+ if this.future.as_mut().poll(cx).is_pending() {
+ return Poll::Pending;
+ }
+
+ this.future.set(this.cancellation_token.inner.notified());
+ }
+ }
+}
+
+// ===== impl WaitForCancellationFutureOwned =====
+
+impl core::fmt::Debug for WaitForCancellationFutureOwned {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ f.debug_struct("WaitForCancellationFutureOwned").finish()
+ }
+}
+
+impl WaitForCancellationFutureOwned {
+ fn new(cancellation_token: CancellationToken) -> Self {
+ WaitForCancellationFutureOwned {
+ // cancellation_token holds a heap allocation and is guaranteed to have a
+ // stable deref, thus it would be ok to move the cancellation_token while
+ // the future holds a reference to it.
+ //
+ // # Safety
+ //
+ // cancellation_token is dropped after future due to the field ordering.
+ future: unsafe { Self::new_future(&cancellation_token) },
+ cancellation_token,
+ }
+ }
+
+ /// # Safety
+ /// The returned future must be destroyed before the cancellation token is
+ /// destroyed.
+ unsafe fn new_future(
+ cancellation_token: &CancellationToken,
+ ) -> tokio::sync::futures::Notified<'static> {
+ let inner_ptr = Arc::as_ptr(&cancellation_token.inner);
+ // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains
+ // valid until the strong count of the Arc drops to zero, and the caller
+ // guarantees that they will drop the future before that happens.
+ (*inner_ptr).notified()
+ }
+}
+
+impl Future for WaitForCancellationFutureOwned {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
+ let mut this = self.project();
+
+ loop {
+ if this.cancellation_token.is_cancelled() {
+ return Poll::Ready(());
+ }
+
+ // No wakeups can be lost here because there is always a call to
+ // `is_cancelled` between the creation of the future and the call to
+ // `poll`, and the code that sets the cancelled flag does so before
+ // waking the `Notified`.
+ if this.future.as_mut().poll(cx).is_pending() {
+ return Poll::Pending;
+ }
+
+ // # Safety
+ //
+ // cancellation_token is dropped after future due to the field ordering.
+ this.future
+ .set(unsafe { Self::new_future(this.cancellation_token) });
+ }
+ }
+}
diff --git a/src/sync/cancellation_token/guard.rs b/src/sync/cancellation_token/guard.rs
new file mode 100644
index 0000000..54ed7ea
--- /dev/null
+++ b/src/sync/cancellation_token/guard.rs
@@ -0,0 +1,27 @@
+use crate::sync::CancellationToken;
+
+/// A wrapper around a cancellation token which automatically cancels
+/// it on drop. It is created using the `drop_guard` method on the `CancellationToken`.
+#[derive(Debug)]
+pub struct DropGuard {
+ pub(super) inner: Option<CancellationToken>,
+}
+
+impl DropGuard {
+ /// Returns the stored cancellation token and consumes this drop guard instance
+ /// (i.e. it will no longer cancel the token). Other guards for this token
+ /// are not affected.
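+ ///
+ /// # Examples
+ ///
+ /// A short sketch of disarming a guard:
+ ///
+ /// ```
+ /// use tokio_util::sync::CancellationToken;
+ ///
+ /// let token = CancellationToken::new();
+ /// let guard = token.clone().drop_guard();
+ ///
+ /// // After `disarm`, dropping the returned token no longer cancels it.
+ /// drop(guard.disarm());
+ ///
+ /// assert!(!token.is_cancelled());
+ /// ```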
+ pub fn disarm(mut self) -> CancellationToken {
+ self.inner
+ .take()
+ .expect("`inner` can be only None in a destructor")
+ }
+}
+
+impl Drop for DropGuard {
+ fn drop(&mut self) {
+ if let Some(inner) = &self.inner {
+ inner.cancel();
+ }
+ }
+}
diff --git a/src/sync/cancellation_token/tree_node.rs b/src/sync/cancellation_token/tree_node.rs
new file mode 100644
index 0000000..8f97dee
--- /dev/null
+++ b/src/sync/cancellation_token/tree_node.rs
@@ -0,0 +1,373 @@
+//! This module provides the logic for the inner tree structure of the CancellationToken.
+//!
+//! CancellationTokens are only light handles with references to TreeNode.
+//! All the logic is actually implemented in the TreeNode.
+//!
+//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of
+//! children.
+//!
+//! A TreeNode can receive the request to perform a cancellation through a CancellationToken.
+//! This cancellation request will cancel the node and all of its descendants.
+//!
+//! As soon as a node can no longer be cancelled (because it was already cancelled or no
+//! CancellationTokens point to it any more), it gets removed from the tree to keep the
+//! tree as small as possible.
+//!
+//! # Invariants
+//!
+//! These invariants must hold at all times.
+//!
+//! 1. A node that has no parents and no handles can no longer be cancelled.
+//! This is important during both cancellation and refcounting.
+//!
+//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A.
+//! This is important for deadlock safety, as it is used for lock order.
+//! Node B can only become the child of node A in two ways:
+//! - being created with `child_node()`, in which case it is trivially true that
+//! node A already existed when node B was created
+//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()`
+//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C
+//! was younger than A, therefore B is also younger than A.
+//!
+//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of
+//! node A. It is important to always restore that invariant before dropping the lock of a node.
+//!
+//! # Deadlock safety
+//!
+//! We always lock in the order of creation time. We can prove this through invariant #2.
+//! Specifically, through invariant #2, we know that we always have to lock a parent
+//! before its child.
+//!
+use crate::loom::sync::{Arc, Mutex, MutexGuard};
+
+/// A node of the cancellation tree structure
+///
+/// The actual data it holds is wrapped inside a mutex for synchronization.
+pub(crate) struct TreeNode {
+ inner: Mutex<Inner>,
+ waker: tokio::sync::Notify,
+}
+impl TreeNode {
+ pub(crate) fn new() -> Self {
+ Self {
+ inner: Mutex::new(Inner {
+ parent: None,
+ parent_idx: 0,
+ children: vec![],
+ is_cancelled: false,
+ num_handles: 1,
+ }),
+ waker: tokio::sync::Notify::new(),
+ }
+ }
+
+ pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> {
+ self.waker.notified()
+ }
+}
+
+/// The data contained inside a TreeNode.
+///
+/// This struct exists so that the data of the node can be wrapped
+/// in a Mutex.
+struct Inner {
+ parent: Option<Arc<TreeNode>>,
+ parent_idx: usize,
+ children: Vec<Arc<TreeNode>>,
+ is_cancelled: bool,
+ num_handles: usize,
+}
+
+/// Returns whether or not the node is cancelled
+pub(crate) fn is_cancelled(node: &Arc<TreeNode>) -> bool {
+ node.inner.lock().unwrap().is_cancelled
+}
+
+/// Creates a child node
+pub(crate) fn child_node(parent: &Arc<TreeNode>) -> Arc<TreeNode> {
+ let mut locked_parent = parent.inner.lock().unwrap();
+
+ // Do not register as child if we are already cancelled.
+ // Cancelled trees can never be uncancelled and therefore
+ // need no connection to parents or children any more.
+ if locked_parent.is_cancelled {
+ return Arc::new(TreeNode {
+ inner: Mutex::new(Inner {
+ parent: None,
+ parent_idx: 0,
+ children: vec![],
+ is_cancelled: true,
+ num_handles: 1,
+ }),
+ waker: tokio::sync::Notify::new(),
+ });
+ }
+
+ let child = Arc::new(TreeNode {
+ inner: Mutex::new(Inner {
+ parent: Some(parent.clone()),
+ parent_idx: locked_parent.children.len(),
+ children: vec![],
+ is_cancelled: false,
+ num_handles: 1,
+ }),
+ waker: tokio::sync::Notify::new(),
+ });
+
+ locked_parent.children.push(child.clone());
+
+ child
+}
+
+/// Disconnects the given parent from all of its children.
+///
+/// Takes a reference to [Inner] to make sure the parent is already locked.
+fn disconnect_children(node: &mut Inner) {
+ for child in std::mem::take(&mut node.children) {
+ let mut locked_child = child.inner.lock().unwrap();
+ locked_child.parent_idx = 0;
+ locked_child.parent = None;
+ }
+}
+
+/// Figures out the parent of the node and locks the node and its parent atomically.
+///
+/// The basic principle of preventing deadlocks in the tree is
+/// that we always lock the parent first, and then the child.
+/// For more info look at *deadlock safety* and *invariant #2*.
+///
+/// Sadly, it's impossible to figure out the parent of a node without
+/// locking it. To then achieve locking order consistency, the node
+/// has to be unlocked before the parent gets locked.
+/// This leaves a small window where we already assume that we know the parent,
+/// but neither the parent nor the node is locked. Therefore, the parent could change.
+///
+/// To prevent this problem from leaking into the rest of the code, it is abstracted
+/// away in this function.
+///
+/// The locked child and optionally its locked parent, if a parent exists, get passed
+/// to the `func` argument via (node, None) or (node, Some(parent)).
+fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret
+where
+ F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret,
+{
+ let mut potential_parent = {
+ let locked_node = node.inner.lock().unwrap();
+ match locked_node.parent.clone() {
+ Some(parent) => parent,
+ // If we locked the node and its parent is `None`, we are in a valid state
+ // and can return.
+ None => return func(locked_node, None),
+ }
+ };
+
+ loop {
+ // Deadlock safety:
+ //
+ // Due to invariant #2, we know that we have to lock the parent first, and then the child.
+ // This is true even if the potential_parent is no longer the current parent or even its
+ // sibling, as the invariant still holds.
+ let locked_parent = potential_parent.inner.lock().unwrap();
+ let locked_node = node.inner.lock().unwrap();
+
+ let actual_parent = match locked_node.parent.clone() {
+ Some(parent) => parent,
+ // If we locked the node and its parent is `None`, we are in a valid state
+ // and can return.
+ None => {
+ // Was the wrong parent, so unlock it before calling `func`
+ drop(locked_parent);
+ return func(locked_node, None);
+ }
+ };
+
+ // Loop until we managed to lock both the node and its parent
+ if Arc::ptr_eq(&actual_parent, &potential_parent) {
+ return func(locked_node, Some(locked_parent));
+ }
+
+ // Drop locked_parent before reassigning to potential_parent,
+ // as potential_parent is borrowed in it
+ drop(locked_node);
+ drop(locked_parent);
+
+ potential_parent = actual_parent;
+ }
+}
+
+/// Moves all children from `node` to `parent`.
+///
+/// `parent` MUST have been a parent of the node when they both got locked,
+/// otherwise there is a potential for a deadlock as invariant #2 would be violated.
+///
+/// To acquire the locks for node and parent, use [with_locked_node_and_parent].
+fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) {
+ // Pre-allocate in the parent, for performance
+ parent.children.reserve(node.children.len());
+
+ for child in std::mem::take(&mut node.children) {
+ {
+ let mut child_locked = child.inner.lock().unwrap();
+ child_locked.parent = node.parent.clone();
+ child_locked.parent_idx = parent.children.len();
+ }
+ parent.children.push(child);
+ }
+}
+
+/// Removes a child from the parent.
+///
+/// `parent` MUST be the parent of `node`.
+/// To acquire the locks for node and parent, use [with_locked_node_and_parent].
+fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) {
+ // Query the position from where to remove a node
+ let pos = node.parent_idx;
+ node.parent = None;
+ node.parent_idx = 0;
+
+ // Unlock node, so that only one child at a time is locked.
+ // Otherwise we would violate the lock order (see 'deadlock safety') as we
+ // don't know the creation order of the child nodes
+ drop(node);
+
+ // If `node` is the last element in the list, we don't need any swapping
+ if parent.children.len() == pos + 1 {
+ parent.children.pop().unwrap();
+ } else {
+ // If `node` is not the last element in the list, we need to
+ // replace it with the last element
+ let replacement_child = parent.children.pop().unwrap();
+ replacement_child.inner.lock().unwrap().parent_idx = pos;
+ parent.children[pos] = replacement_child;
+ }
+
+ let len = parent.children.len();
+ if 4 * len <= parent.children.capacity() {
+ // equal to:
+ // parent.children.shrink_to(2 * len);
+ // but shrink_to was not yet stabilized in our minimal compatible version
+ let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len));
+ parent.children.extend(old_children);
+ }
+}
+
+/// Increases the reference count of handles.
+pub(crate) fn increase_handle_refcount(node: &Arc<TreeNode>) {
+ let mut locked_node = node.inner.lock().unwrap();
+
+ // Once no handles are left over, the node gets detached from the tree.
+ // There should never be a new handle once all handles are dropped.
+ assert!(locked_node.num_handles > 0);
+
+ locked_node.num_handles += 1;
+}
+
+/// Decreases the reference count of handles.
+///
+/// Once no handle is left, we can remove the node from the
+/// tree and connect its parent directly to its children.
+pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) {
+ let num_handles = {
+ let mut locked_node = node.inner.lock().unwrap();
+ locked_node.num_handles -= 1;
+ locked_node.num_handles
+ };
+
+ if num_handles == 0 {
+ with_locked_node_and_parent(node, |mut node, parent| {
+ // Remove the node from the tree
+ match parent {
+ Some(mut parent) => {
+ // As we want to remove ourselves from the tree,
+ // we have to move the children to the parent, so that
+ // they still receive the cancellation event without us.
+ // Moving them does not violate invariant #1.
+ move_children_to_parent(&mut node, &mut parent);
+
+ // Remove the node from the parent
+ remove_child(&mut parent, node);
+ }
+ None => {
+ // Due to invariant #1, we can assume that our
+ // children can no longer be cancelled through us.
+ // (as we now have neither a parent nor handles)
+ // Therefore we can disconnect them.
+ disconnect_children(&mut node);
+ }
+ }
+ });
+ }
+}
+
+/// Cancels a node and its children.
+pub(crate) fn cancel(node: &Arc<TreeNode>) {
+ let mut locked_node = node.inner.lock().unwrap();
+
+ if locked_node.is_cancelled {
+ return;
+ }
+
+ // One by one, adopt grandchildren and then cancel and detach the child
+ while let Some(child) = locked_node.children.pop() {
+ // This can't deadlock because the mutex we are already
+ // holding is the parent of child.
+ let mut locked_child = child.inner.lock().unwrap();
+
+ // Detach the child from node
+ // No need to modify node.children, as the child already got removed with `.pop`
+ locked_child.parent = None;
+ locked_child.parent_idx = 0;
+
+ // If child is already cancelled, detaching is enough
+ if locked_child.is_cancelled {
+ continue;
+ }
+
+ // Cancel or adopt grandchildren
+ while let Some(grandchild) = locked_child.children.pop() {
+ // This can't deadlock because the two mutexes we are already
+ // holding are the parent and grandparent of grandchild.
+ let mut locked_grandchild = grandchild.inner.lock().unwrap();
+
+ // Detach the grandchild
+ locked_grandchild.parent = None;
+ locked_grandchild.parent_idx = 0;
+
+ // If grandchild is already cancelled, detaching is enough
+ if locked_grandchild.is_cancelled {
+ continue;
+ }
+
+ // For performance reasons, only adopt grandchildren that have children.
+ // Otherwise, just cancel them right away, no need for another iteration.
+ if locked_grandchild.children.is_empty() {
+ // Cancel the grandchild
+ locked_grandchild.is_cancelled = true;
+ locked_grandchild.children = Vec::new();
+ drop(locked_grandchild);
+ grandchild.waker.notify_waiters();
+ } else {
+ // Otherwise, adopt grandchild
+ locked_grandchild.parent = Some(node.clone());
+ locked_grandchild.parent_idx = locked_node.children.len();
+ drop(locked_grandchild);
+ locked_node.children.push(grandchild);
+ }
+ }
+
+ // Cancel the child
+ locked_child.is_cancelled = true;
+ locked_child.children = Vec::new();
+ drop(locked_child);
+ child.waker.notify_waiters();
+
+ // Now the child is cancelled and detached and all its children are adopted.
+ // Just continue until all (including adopted) children are cancelled and detached.
+ }
+
+ // Cancel the node itself.
+ locked_node.is_cancelled = true;
+ locked_node.children = Vec::new();
+ drop(locked_node);
+ node.waker.notify_waiters();
+}
diff --git a/src/sync/mod.rs b/src/sync/mod.rs
new file mode 100644
index 0000000..1ae1b7e
--- /dev/null
+++ b/src/sync/mod.rs
@@ -0,0 +1,15 @@
+//! Synchronization primitives
+
+mod cancellation_token;
+pub use cancellation_token::{
+ guard::DropGuard, CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned,
+};
+
+mod mpsc;
+pub use mpsc::{PollSendError, PollSender};
+
+mod poll_semaphore;
+pub use poll_semaphore::PollSemaphore;
+
+mod reusable_box;
+pub use reusable_box::ReusableBoxFuture;
diff --git a/src/sync/mpsc.rs b/src/sync/mpsc.rs
new file mode 100644
index 0000000..55ed5c4
--- /dev/null
+++ b/src/sync/mpsc.rs
@@ -0,0 +1,284 @@
+use futures_sink::Sink;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{fmt, mem};
+use tokio::sync::mpsc::OwnedPermit;
+use tokio::sync::mpsc::Sender;
+
+use super::ReusableBoxFuture;
+
+/// Error returned by the `PollSender` when the channel is closed.
+#[derive(Debug)]
+pub struct PollSendError<T>(Option<T>);
+
+impl<T> PollSendError<T> {
+ /// Consumes the stored value, if any.
+ ///
+ /// If this error was encountered when calling `start_send`/`send_item`, this will be the item
+ /// that the caller attempted to send. Otherwise, it will be `None`.
+ pub fn into_inner(self) -> Option<T> {
+ self.0
+ }
+}
+
+impl<T> fmt::Display for PollSendError<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "channel closed")
+ }
+}
+
+impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
+
+#[derive(Debug)]
+enum State<T> {
+ Idle(Sender<T>),
+ Acquiring,
+ ReadyToSend(OwnedPermit<T>),
+ Closed,
+}
+
+/// A wrapper around [`mpsc::Sender`] that can be polled.
+///
+/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
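+///
+/// # Examples
+///
+/// A minimal sketch of the reserve-then-send protocol; `poll_fn` is used only
+/// to drive the poll-based API from async code:
+///
+/// ```
+/// use futures::future::poll_fn;
+/// use tokio::sync::mpsc;
+/// use tokio_util::sync::PollSender;
+///
+/// #[tokio::main]
+/// async fn main() {
+///     let (send, mut recv) = mpsc::channel(1);
+///     let mut send = PollSender::new(send);
+///
+///     // A slot must be reserved before every `send_item` call.
+///     poll_fn(|cx| send.poll_reserve(cx)).await.unwrap();
+///     send.send_item(42).unwrap();
+///
+///     assert_eq!(recv.recv().await, Some(42));
+/// }
+/// ```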
+#[derive(Debug)]
+pub struct PollSender<T> {
+ sender: Option<Sender<T>>,
+ state: State<T>,
+ acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>,
+}
+
+// Creates a future for acquiring a permit from the underlying channel. This is used to ensure
+// there's capacity for a send to complete.
+//
+// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to
+// ReusableBoxFuture has the same underlying type, and hence the same size and alignment.
+async fn make_acquire_future<T>(
+ data: Option<Sender<T>>,
+) -> Result<OwnedPermit<T>, PollSendError<T>> {
+ match data {
+ Some(sender) => sender
+ .reserve_owned()
+ .await
+ .map_err(|_| PollSendError(None)),
+ None => unreachable!("this future should not be pollable in this state"),
+ }
+}
+
+impl<T: Send + 'static> PollSender<T> {
+ /// Creates a new `PollSender`.
+ pub fn new(sender: Sender<T>) -> Self {
+ Self {
+ sender: Some(sender.clone()),
+ state: State::Idle(sender),
+ acquire: ReusableBoxFuture::new(make_acquire_future(None)),
+ }
+ }
+
+ fn take_state(&mut self) -> State<T> {
+ mem::replace(&mut self.state, State::Closed)
+ }
+
+ /// Attempts to prepare the sender to receive a value.
+ ///
+ /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
+ /// `send_item`.
+ ///
+ /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
+ /// by reserving a slot in the channel for the item to be sent. If this method returns
+ /// `Poll::Pending`, the current task is registered to be notified (via
+ /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
+ ///
+ /// # Errors
+ ///
+ /// If the channel is closed, an error will be returned. This is a permanent state.
+ pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
+ loop {
+ let (result, next_state) = match self.take_state() {
+ State::Idle(sender) => {
+ // Start trying to acquire a permit to reserve a slot for our send, and
+ // immediately loop back around to poll it the first time.
+ self.acquire.set(make_acquire_future(Some(sender)));
+ (None, State::Acquiring)
+ }
+ State::Acquiring => match self.acquire.poll(cx) {
+ // Channel has capacity.
+ Poll::Ready(Ok(permit)) => {
+ (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
+ }
+ // Channel is closed.
+ Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
+ // Channel doesn't have capacity yet, so we need to wait.
+ Poll::Pending => (Some(Poll::Pending), State::Acquiring),
+ },
+ // We're closed, either by choice or because the underlying sender was closed.
+ s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
+ // We're already ready to send an item.
+ s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
+ };
+
+ self.state = next_state;
+ if let Some(result) = result {
+ return result;
+ }
+ }
+ }
+
+ /// Sends an item to the channel.
+ ///
+ /// Before calling `send_item`, `poll_reserve` must be called with a successful return
+ /// value of `Poll::Ready(Ok(()))`.
+ ///
+ /// # Errors
+ ///
+ /// If the channel is closed, an error will be returned. This is a permanent state.
+ ///
+ /// # Panics
+ ///
+ /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
+ /// will panic.
+ #[track_caller]
+ pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
+ let (result, next_state) = match self.take_state() {
+ State::Idle(_) | State::Acquiring => {
+ panic!("`send_item` called without first calling `poll_reserve`")
+ }
+ // We have a permit to send our item, so go ahead, which gets us our sender back.
+ State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
+ // We're closed, either by choice or because the underlying sender was closed.
+ State::Closed => (Err(PollSendError(Some(value))), State::Closed),
+ };
+
+ // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
+ self.state = if self.sender.is_some() {
+ next_state
+ } else {
+ State::Closed
+ };
+ result
+ }
+
+ /// Checks whether this sender has been closed.
+ ///
+ /// The underlying channel that this sender was wrapping may still be open.
+ pub fn is_closed(&self) -> bool {
+ matches!(self.state, State::Closed) || self.sender.is_none()
+ }
+
+ /// Gets a reference to the `Sender` of the underlying channel.
+ ///
+ /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
+ /// was wrapping may still be open.
+ pub fn get_ref(&self) -> Option<&Sender<T>> {
+ self.sender.as_ref()
+ }
+
+ /// Closes this sender.
+ ///
+ /// No further messages can be sent from this sender, but the underlying channel will
+ /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
+ ///
+ /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
+ /// to `send_item` in order to consume the reserved slot. After that, no further sends will be
+ /// possible. If you do not intend to send another item, you can release the reserved slot back
+ /// to the underlying sender by calling [`abort_send`].
+ ///
+ /// [`abort_send`]: crate::sync::PollSender::abort_send
+ /// [`Receiver`]: tokio::sync::mpsc::Receiver
+ pub fn close(&mut self) {
+ // Mark ourselves officially closed by dropping our main sender.
+ self.sender = None;
+
+ // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
+ // transition to the closed state. Otherwise, leave the existing permit in place for the
+ // caller if they want to complete the send.
+ match self.state {
+ State::Idle(_) => self.state = State::Closed,
+ State::Acquiring => {
+ self.acquire.set(make_acquire_future(None));
+ self.state = State::Closed;
+ }
+ _ => {}
+ }
+ }
+
+ /// Aborts the current in-progress send, if any.
+ ///
+ /// Returns `true` if a send was aborted. If the sender was closed prior to calling
+ /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
+ /// ready to attempt another send.
+ pub fn abort_send(&mut self) -> bool {
+ // We may have been closed in the meantime, after a call to `poll_reserve` already
+ // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the
+ // closed state when we actually abort a send, rather than resetting ourselves back to idle.
+
+ let (result, next_state) = match self.take_state() {
+ // We're currently trying to reserve a slot to send into.
+ State::Acquiring => {
+ // Replacing the future drops the in-flight one.
+ self.acquire.set(make_acquire_future(None));
+
+ // If we haven't closed yet, we have to clone our stored sender since we have no way
+ // to get it back from the acquire future we just dropped.
+ let state = match self.sender.clone() {
+ Some(sender) => State::Idle(sender),
+ None => State::Closed,
+ };
+ (true, state)
+ }
+ // We got the permit. If we haven't closed yet, get the sender back.
+ State::ReadyToSend(permit) => {
+ let state = if self.sender.is_some() {
+ State::Idle(permit.release())
+ } else {
+ State::Closed
+ };
+ (true, state)
+ }
+ s => (false, s),
+ };
+
+ self.state = next_state;
+ result
+ }
+}
+
+impl<T> Clone for PollSender<T> {
+ /// Clones this `PollSender`.
+ ///
+ /// The resulting `PollSender` starts in the same state as one freshly created with `PollSender::new`.
+ fn clone(&self) -> PollSender<T> {
+ let (sender, state) = match self.sender.clone() {
+ Some(sender) => (Some(sender.clone()), State::Idle(sender)),
+ None => (None, State::Closed),
+ };
+
+ Self {
+ sender,
+ state,
+ // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
+ // compatible with the transitive bounds required by `Sender<T>`.
+ acquire: ReusableBoxFuture::new(async { unreachable!() }),
+ }
+ }
+}
+
+impl<T: Send + 'static> Sink<T> for PollSender<T> {
+ type Error = PollSendError<T>;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Pin::into_inner(self).poll_reserve(cx)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+ Pin::into_inner(self).send_item(item)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Pin::into_inner(self).close();
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/src/sync/poll_semaphore.rs b/src/sync/poll_semaphore.rs
new file mode 100644
index 0000000..6b44574
--- /dev/null
+++ b/src/sync/poll_semaphore.rs
@@ -0,0 +1,171 @@
+use futures_core::{ready, Stream};
+use std::fmt;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+
+use super::ReusableBoxFuture;
+
+/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
+///
+/// [`Semaphore`]: tokio::sync::Semaphore
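+///
+/// # Examples
+///
+/// A minimal sketch of polling for a permit; `poll_fn` merely adapts the
+/// poll-based API to async code:
+///
+/// ```
+/// use std::sync::Arc;
+/// use futures::future::poll_fn;
+/// use tokio::sync::Semaphore;
+/// use tokio_util::sync::PollSemaphore;
+///
+/// #[tokio::main]
+/// async fn main() {
+///     let mut sem = PollSemaphore::new(Arc::new(Semaphore::new(1)));
+///
+///     // Take the only permit; dropping it makes it available again.
+///     let permit = poll_fn(|cx| sem.poll_acquire(cx)).await.unwrap();
+///     assert_eq!(sem.available_permits(), 0);
+///
+///     drop(permit);
+///     assert_eq!(sem.available_permits(), 1);
+/// }
+/// ```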
+pub struct PollSemaphore {
+ semaphore: Arc<Semaphore>,
+ permit_fut: Option<(
+ u32, // The number of permits requested.
+ ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
+ )>,
+}
+
+impl PollSemaphore {
+ /// Create a new `PollSemaphore`.
+ pub fn new(semaphore: Arc<Semaphore>) -> Self {
+ Self {
+ semaphore,
+ permit_fut: None,
+ }
+ }
+
+ /// Closes the semaphore.
+ pub fn close(&self) {
+ self.semaphore.close()
+ }
+
+ /// Obtain a clone of the inner semaphore.
+ pub fn clone_inner(&self) -> Arc<Semaphore> {
+ self.semaphore.clone()
+ }
+
+ /// Get back the inner semaphore.
+ pub fn into_inner(self) -> Arc<Semaphore> {
+ self.semaphore
+ }
+
+ /// Poll to acquire a permit from the semaphore.
+ ///
+ /// This can return the following values:
+ ///
+ /// - `Poll::Pending` if a permit is not currently available.
+ /// - `Poll::Ready(Some(permit))` if a permit was acquired.
+ /// - `Poll::Ready(None)` if the semaphore has been closed.
+ ///
+ /// When this method returns `Poll::Pending`, the current task is scheduled
+ /// to receive a wakeup when a permit becomes available, or when the
+ /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
+ /// the `Waker` from the `Context` passed to the most recent call is
+ /// scheduled to receive a wakeup.
+ pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
+ self.poll_acquire_many(cx, 1)
+ }
+
+ /// Poll to acquire many permits from the semaphore.
+ ///
+ /// This can return the following values:
+ ///
+ /// - `Poll::Pending` if a permit is not currently available.
+ /// - `Poll::Ready(Some(permit))` if a permit was acquired.
+ /// - `Poll::Ready(None)` if the semaphore has been closed.
+ ///
+ /// When this method returns `Poll::Pending`, the current task is scheduled
+ /// to receive a wakeup when the permits become available, or when the
+ /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
+ /// the `Waker` from the `Context` passed to the most recent call is
+ /// scheduled to receive a wakeup.
+ pub fn poll_acquire_many(
+ &mut self,
+ cx: &mut Context<'_>,
+ permits: u32,
+ ) -> Poll<Option<OwnedSemaphorePermit>> {
+ let permit_future = match self.permit_fut.as_mut() {
+ Some((prev_permits, fut)) if *prev_permits == permits => fut,
+ Some((old_permits, fut_box)) => {
+ // We're requesting a different number of permits, so replace the future
+ // and record the new amount.
+ let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
+ fut_box.set(fut);
+ *old_permits = permits;
+ fut_box
+ }
+ None => {
+ // avoid allocations completely if we can grab a permit immediately
+ match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
+ Ok(permit) => return Poll::Ready(Some(permit)),
+ Err(TryAcquireError::Closed) => return Poll::Ready(None),
+ Err(TryAcquireError::NoPermits) => {}
+ }
+
+ let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
+ &mut self
+ .permit_fut
+ .get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
+ .1
+ }
+ };
+
+ let result = ready!(permit_future.poll(cx));
+
+ // Assume we'll request the same amount of permits in a subsequent call.
+ let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
+ permit_future.set(next_fut);
+
+ match result {
+ Ok(permit) => Poll::Ready(Some(permit)),
+ Err(_closed) => {
+ self.permit_fut = None;
+ Poll::Ready(None)
+ }
+ }
+ }
+
+ /// Returns the current number of available permits.
+ ///
+ /// This is equivalent to the [`Semaphore::available_permits`] method on the
+ /// `tokio::sync::Semaphore` type.
+ ///
+ /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
+ pub fn available_permits(&self) -> usize {
+ self.semaphore.available_permits()
+ }
+
+ /// Adds `n` new permits to the semaphore.
+ ///
+ /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function
+ /// will panic if the limit is exceeded.
+ ///
+ /// This is equivalent to the [`Semaphore::add_permits`] method on the
+ /// `tokio::sync::Semaphore` type.
+ ///
+ /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
+ pub fn add_permits(&self, n: usize) {
+ self.semaphore.add_permits(n);
+ }
+}
+
+impl Stream for PollSemaphore {
+ type Item = OwnedSemaphorePermit;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
+ Pin::into_inner(self).poll_acquire(cx)
+ }
+}
+
+impl Clone for PollSemaphore {
+ fn clone(&self) -> PollSemaphore {
+ PollSemaphore::new(self.clone_inner())
+ }
+}
+
+impl fmt::Debug for PollSemaphore {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("PollSemaphore")
+ .field("semaphore", &self.semaphore)
+ .finish()
+ }
+}
+
+impl AsRef<Semaphore> for PollSemaphore {
+ fn as_ref(&self) -> &Semaphore {
+ &self.semaphore
+ }
+}
diff --git a/src/sync/reusable_box.rs b/src/sync/reusable_box.rs
new file mode 100644
index 0000000..1b8ef60
--- /dev/null
+++ b/src/sync/reusable_box.rs
@@ -0,0 +1,171 @@
+use std::alloc::Layout;
+use std::fmt;
+use std::future::Future;
+use std::marker::PhantomData;
+use std::mem::{self, ManuallyDrop};
+use std::pin::Pin;
+use std::ptr;
+use std::task::{Context, Poll};
+
+/// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`.
+///
+/// This type lets you replace the future stored in the box without
+/// reallocating when the size and alignment permit it.
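+///
+/// # Examples
+///
+/// A minimal sketch: swapping in a future of the same layout avoids a new
+/// allocation.
+///
+/// ```
+/// use futures::future::poll_fn;
+/// use tokio_util::sync::ReusableBoxFuture;
+///
+/// #[tokio::main]
+/// async fn main() {
+///     let mut fut = ReusableBoxFuture::new(async { 10 });
+///     assert_eq!(poll_fn(|cx| fut.poll(cx)).await, 10);
+///
+///     // Replace the stored future; the box is reused when layouts match.
+///     fut.set(async { 20 });
+///     assert_eq!(poll_fn(|cx| fut.poll(cx)).await, 20);
+/// }
+/// ```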
+pub struct ReusableBoxFuture<'a, T> {
+ boxed: Pin<Box<dyn Future<Output = T> + Send + 'a>>,
+}
+
+impl<'a, T> ReusableBoxFuture<'a, T> {
+ /// Create a new `ReusableBoxFuture<T>` containing the provided future.
+ pub fn new<F>(future: F) -> Self
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ Self {
+ boxed: Box::pin(future),
+ }
+ }
+
+ /// Replace the future currently stored in this box.
+ ///
+ /// This reallocates if and only if the layout of the provided future is
+ /// different from the layout of the currently stored future.
+ pub fn set<F>(&mut self, future: F)
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ if let Err(future) = self.try_set(future) {
+ *self = Self::new(future);
+ }
+ }
+
+ /// Replace the future currently stored in this box.
+ ///
+ /// This function never reallocates, but returns an error if the provided
+ /// future has a different size or alignment from the currently stored
+ /// future.
+ pub fn try_set<F>(&mut self, future: F) -> Result<(), F>
+ where
+ F: Future<Output = T> + Send + 'a,
+ {
+ // If we try to inline the contents of this function, the type checker complains because
+ // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an
+ // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound
+ // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`.
+ #[inline(always)]
+ fn real_try_set<'a, F>(
+ this: &mut ReusableBoxFuture<'a, F::Output>,
+ future: F,
+ ) -> Result<(), F>
+ where
+ F: Future + Send + 'a,
+ {
+ // future::Pending<T> is a ZST so this never allocates.
+ let boxed = mem::replace(&mut this.boxed, Box::pin(Pending(PhantomData)));
+ reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed))
+ }
+
+ real_try_set(self, future)
+ }
+
+ /// Get a pinned reference to the underlying future.
+ pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> {
+ self.boxed.as_mut()
+ }
+
+ /// Poll the future stored inside this box.
+ pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
+ self.get_pin().poll(cx)
+ }
+}
+
+impl<T> Future for ReusableBoxFuture<'_, T> {
+ type Output = T;
+
+ /// Poll the future stored inside this box.
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
+ Pin::into_inner(self).get_pin().poll(cx)
+ }
+}
+
+// The only method called on self.boxed is poll, which takes &mut self, so this
+// struct being Sync does not permit any invalid access to the Future, even if
+// the future is not Sync.
+unsafe impl<T> Sync for ReusableBoxFuture<'_, T> {}
+
+impl<T> fmt::Debug for ReusableBoxFuture<'_, T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ReusableBoxFuture").finish()
+ }
+}
+
+fn reuse_pin_box<T: ?Sized, U, O, F>(boxed: Pin<Box<T>>, new_value: U, callback: F) -> Result<O, U>
+where
+ F: FnOnce(Box<U>) -> O,
+{
+ let layout = Layout::for_value::<T>(&*boxed);
+ if layout != Layout::new::<U>() {
+ return Err(new_value);
+ }
+
+ // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we
+ // always drop the `T`.
+ let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) });
+
+ // When dropping the old value panics, we still want to call `callback` — so move the rest of
+ // the code into a guard type.
+ let guard = CallOnDrop::new(|| {
+ let raw: *mut U = raw.cast::<U>();
+ unsafe { raw.write(new_value) };
+
+ // SAFETY:
+ // - `T` and `U` have the same layout.
+ // - `raw` comes from a `Box` that uses the same allocator as this one.
+ // - `raw` points to a valid instance of `U` (we just wrote it in).
+ let boxed = unsafe { Box::from_raw(raw) };
+
+ callback(boxed)
+ });
+
+ // Drop the old value.
+ unsafe { ptr::drop_in_place(raw) };
+
+ // Run the rest of the code.
+ Ok(guard.call())
+}
+
+struct CallOnDrop<O, F: FnOnce() -> O> {
+ f: ManuallyDrop<F>,
+}
+
+impl<O, F: FnOnce() -> O> CallOnDrop<O, F> {
+ fn new(f: F) -> Self {
+ let f = ManuallyDrop::new(f);
+ Self { f }
+ }
+ fn call(self) -> O {
+ let mut this = ManuallyDrop::new(self);
+ let f = unsafe { ManuallyDrop::take(&mut this.f) };
+ f()
+ }
+}
+
+impl<O, F: FnOnce() -> O> Drop for CallOnDrop<O, F> {
+ fn drop(&mut self) {
+ let f = unsafe { ManuallyDrop::take(&mut self.f) };
+ f();
+ }
+}
+
+/// The same as `std::future::Pending<T>`; we can't use that type directly because on rustc
+/// versions <1.60 it didn't unconditionally implement `Send`.
+// FIXME: use `std::future::Pending<T>` once the MSRV is >=1.60
+struct Pending<T>(PhantomData<fn() -> T>);
+
+impl<T> Future for Pending<T> {
+ type Output = T;
+
+ fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
+ Poll::Pending
+ }
+}
diff --git a/src/sync/tests/loom_cancellation_token.rs b/src/sync/tests/loom_cancellation_token.rs
new file mode 100644
index 0000000..b57503d
--- /dev/null
+++ b/src/sync/tests/loom_cancellation_token.rs
@@ -0,0 +1,176 @@
+use crate::sync::CancellationToken;
+
+use loom::{future::block_on, thread};
+use tokio_test::assert_ok;
+
+#[test]
+fn cancel_token() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ token1.cancelled().await;
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ token.cancel();
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ });
+}
+
+#[test]
+fn cancel_token_owned() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ token1.cancelled_owned().await;
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ token.cancel();
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ });
+}
+
+#[test]
+fn cancel_with_child() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+ let token2 = token.clone();
+ let child_token = token.child_token();
+
+ let th1 = thread::spawn(move || {
+ block_on(async {
+ token1.cancelled().await;
+ });
+ });
+
+ let th2 = thread::spawn(move || {
+ token2.cancel();
+ });
+
+ let th3 = thread::spawn(move || {
+ block_on(async {
+ child_token.cancelled().await;
+ });
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn drop_token_no_child() {
+ loom::model(|| {
+ let token = CancellationToken::new();
+ let token1 = token.clone();
+ let token2 = token.clone();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ drop(token2);
+ });
+
+ let th3 = thread::spawn(move || {
+ drop(token);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn drop_token_with_children() {
+ loom::model(|| {
+ let token1 = CancellationToken::new();
+ let child_token1 = token1.child_token();
+ let child_token2 = token1.child_token();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ drop(child_token1);
+ });
+
+ let th3 = thread::spawn(move || {
+ drop(child_token2);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn drop_and_cancel_token() {
+ loom::model(|| {
+ let token1 = CancellationToken::new();
+ let token2 = token1.clone();
+ let child_token = token1.child_token();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ token2.cancel();
+ });
+
+ let th3 = thread::spawn(move || {
+ drop(child_token);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn cancel_parent_and_child() {
+ loom::model(|| {
+ let token1 = CancellationToken::new();
+ let token2 = token1.clone();
+ let child_token = token1.child_token();
+
+ let th1 = thread::spawn(move || {
+ drop(token1);
+ });
+
+ let th2 = thread::spawn(move || {
+ token2.cancel();
+ });
+
+ let th3 = thread::spawn(move || {
+ child_token.cancel();
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/src/sync/tests/mod.rs
@@ -0,0 +1 @@
+
diff --git a/src/task/join_map.rs b/src/task/join_map.rs
new file mode 100644
index 0000000..c6bf5bc
--- /dev/null
+++ b/src/task/join_map.rs
@@ -0,0 +1,807 @@
+use hashbrown::hash_map::RawEntryMut;
+use hashbrown::HashMap;
+use std::borrow::Borrow;
+use std::collections::hash_map::RandomState;
+use std::fmt;
+use std::future::Future;
+use std::hash::{BuildHasher, Hash, Hasher};
+use tokio::runtime::Handle;
+use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
+
+/// A collection of tasks spawned on a Tokio runtime, associated with hash map
+/// keys.
+///
+/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
+/// addition of a set of keys associated with each task. These keys allow
+/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
+/// `JoinMap` based on their keys, or [testing whether a task corresponding to a
+/// given key exists][contains] in the `JoinMap`.
+///
+/// In addition, when tasks in the `JoinMap` complete, they will return the
+/// associated key along with the value returned by the task, if any.
+///
+/// A `JoinMap` can be used to await the completion of some or all of the tasks
+/// in the map. The map is not ordered, and the tasks will be returned in the
+/// order they complete.
+///
+/// All of the tasks must have the same return type `V`.
+///
+/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
+///
+/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
+/// documentation on unstable features][unstable] for details on how to enable
+/// Tokio's unstable features.
+///
+/// # Examples
+///
+/// Spawn multiple tasks and wait for them:
+///
+/// ```
+/// use tokio_util::task::JoinMap;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let mut map = JoinMap::new();
+///
+/// for i in 0..10 {
+/// // Spawn a task on the `JoinMap` with `i` as its key.
+/// map.spawn(i, async move { /* ... */ });
+/// }
+///
+/// let mut seen = [false; 10];
+///
+/// // When a task completes, `join_next` returns the task's key along
+/// // with its output.
+/// while let Some((key, res)) = map.join_next().await {
+/// seen[key] = true;
+/// assert!(res.is_ok(), "task {} failed", key);
+/// }
+///
+/// for i in 0..10 {
+/// assert!(seen[i]);
+/// }
+/// }
+/// ```
+///
+/// Cancel tasks based on their keys:
+///
+/// ```
+/// use tokio_util::task::JoinMap;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let mut map = JoinMap::new();
+///
+/// map.spawn("hello world", async move { /* ... */ });
+/// map.spawn("goodbye world", async move { /* ... */});
+///
+/// // Look up the "goodbye world" task in the map and abort it.
+/// let aborted = map.abort("goodbye world");
+///
+/// // `JoinMap::abort` returns `true` if a task existed for the
+/// // provided key.
+/// assert!(aborted);
+///
+/// while let Some((key, res)) = map.join_next().await {
+/// if key == "goodbye world" {
+/// // The aborted task should complete with a cancelled `JoinError`.
+/// assert!(res.unwrap_err().is_cancelled());
+/// } else {
+/// // Other tasks should complete normally.
+/// assert!(res.is_ok());
+/// }
+/// }
+/// }
+/// ```
+///
+/// [`JoinSet`]: tokio::task::JoinSet
+/// [unstable]: tokio#unstable-features
+/// [abort]: fn@Self::abort
+/// [abort_matching]: fn@Self::abort_matching
+/// [contains]: fn@Self::contains_key
+#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
+pub struct JoinMap<K, V, S = RandomState> {
+ /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
+ /// indexed by their keys and task IDs.
+ ///
+ /// The [`Key`] type contains both the task's `K`-typed key provided when
+ /// spawning tasks, and the task's IDs. The IDs are stored here to resolve
+ /// hash collisions when looking up tasks based on their pre-computed hash
+ /// (as stored in the `hashes_by_task` map).
+ tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
+
+ /// A map from task IDs to the hash of the key associated with that task.
+ ///
+ /// This map is used to perform reverse lookups of tasks in the
+ /// `tasks_by_key` map based on their task IDs. When a task terminates, the
+ /// ID is provided to us by the `JoinSet`, so we can look up the hash value
+ /// of that task's key, and then remove it from the `tasks_by_key` map using
+ /// the raw hash code, resolving collisions by comparing task IDs.
+ hashes_by_task: HashMap<Id, u64, S>,
+
+ /// The [`JoinSet`] that awaits the completion of tasks spawned on this
+ /// `JoinMap`.
+ tasks: JoinSet<V>,
+}
+
+/// A [`JoinMap`] key.
+///
+/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
+/// a task ID, so that hash collisions between `K`-typed keys can be resolved
+/// using either `K`'s `Eq` impl *or* by checking the task IDs.
+///
+/// This allows looking up a task using either an actual key (such as when the
+/// user queries the map with a key), *or* using a task ID and a hash (such as
+/// when removing completed tasks from the map).
+#[derive(Debug)]
+struct Key<K> {
+ key: K,
+ id: Id,
+}
+
+impl<K, V> JoinMap<K, V> {
+ /// Creates a new empty `JoinMap`.
+ ///
+ /// The `JoinMap` is initially created with a capacity of 0, so it will not
+ /// allocate until a task is first spawned on it.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ /// let map: JoinMap<&str, i32> = JoinMap::new();
+ /// ```
+ #[inline]
+ #[must_use]
+ pub fn new() -> Self {
+ Self::with_hasher(RandomState::new())
+ }
+
+ /// Creates an empty `JoinMap` with the specified capacity.
+ ///
+ /// The `JoinMap` will be able to hold at least `capacity` tasks without
+ /// reallocating.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
+ /// ```
+ #[inline]
+ #[must_use]
+ pub fn with_capacity(capacity: usize) -> Self {
+ JoinMap::with_capacity_and_hasher(capacity, Default::default())
+ }
+}
+
+impl<K, V, S: Clone> JoinMap<K, V, S> {
+ /// Creates an empty `JoinMap` which will use the given hash builder to hash
+ /// keys.
+ ///
+ /// The created map has the default initial capacity.
+ ///
+ /// Warning: `hash_builder` is normally randomly generated, and
+ /// is designed to allow `JoinMap` to be resistant to attacks that
+ /// cause many collisions and very poor performance. Setting it
+ /// manually using this function can expose a DoS attack vector.
+ ///
+ /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
+ /// the `JoinMap` to be useful; see its documentation for details.
+ #[inline]
+ #[must_use]
+ pub fn with_hasher(hash_builder: S) -> Self {
+ Self::with_capacity_and_hasher(0, hash_builder)
+ }
+
+ /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
+ /// to hash the keys.
+ ///
+ /// The `JoinMap` will be able to hold at least `capacity` elements without
+ /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
+ ///
+ /// Warning: `hash_builder` is normally randomly generated, and
+ /// is designed to allow the `JoinMap` to be resistant to attacks that
+ /// cause many collisions and very poor performance. Setting it
+ /// manually using this function can expose a DoS attack vector.
+ ///
+ /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
+ /// the `JoinMap` to be useful; see its documentation for details.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// use tokio_util::task::JoinMap;
+ /// use std::collections::hash_map::RandomState;
+ ///
+ /// let s = RandomState::new();
+ /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
+ /// map.spawn(1, async move { "hello world!" });
+ /// # }
+ /// ```
+ #[inline]
+ #[must_use]
+ pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
+ Self {
+ tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
+ hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
+ tasks: JoinSet::new(),
+ }
+ }
+
+ /// Returns the number of tasks currently in the `JoinMap`.
+ pub fn len(&self) -> usize {
+ let len = self.tasks_by_key.len();
+ debug_assert_eq!(len, self.hashes_by_task.len());
+ len
+ }
+
+ /// Returns whether the `JoinMap` is empty.
+ pub fn is_empty(&self) -> bool {
+ let empty = self.tasks_by_key.is_empty();
+ debug_assert_eq!(empty, self.hashes_by_task.is_empty());
+ empty
+ }
+
+ /// Returns the number of tasks the map can hold without reallocating.
+ ///
+ /// This number is a lower bound; the `JoinMap` might be able to hold
+ /// more, but is guaranteed to be able to hold at least this many.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
+ /// assert!(map.capacity() >= 100);
+ /// ```
+ #[inline]
+ pub fn capacity(&self) -> usize {
+ let capacity = self.tasks_by_key.capacity();
+ debug_assert_eq!(capacity, self.hashes_by_task.capacity());
+ capacity
+ }
+}
+
+impl<K, V, S> JoinMap<K, V, S>
+where
+ K: Hash + Eq,
+ V: 'static,
+ S: BuildHasher,
+{
+ /// Spawn the provided task and store it in this `JoinMap` with the provided
+ /// key.
+ ///
+ /// If a task previously existed in the `JoinMap` for this key, that task
+ /// will be cancelled and replaced with the new one. The previous task will
+ /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
+ /// *not* return a cancelled [`JoinError`] for that task.
+ ///
+ /// # Panics
+ ///
+ /// This method panics if called outside of a Tokio runtime.
+ ///
+ /// [`join_next`]: Self::join_next
+ #[track_caller]
+ pub fn spawn<F>(&mut self, key: K, task: F)
+ where
+ F: Future<Output = V>,
+ F: Send + 'static,
+ V: Send,
+ {
+ let task = self.tasks.spawn(task);
+ self.insert(key, task)
+ }
+
+ /// Spawn the provided task on the provided runtime and store it in this
+ /// `JoinMap` with the provided key.
+ ///
+ /// If a task previously existed in the `JoinMap` for this key, that task
+ /// will be cancelled and replaced with the new one. The previous task will
+ /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
+ /// *not* return a cancelled [`JoinError`] for that task.
+ ///
+ /// [`join_next`]: Self::join_next
+ #[track_caller]
+ pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
+ where
+ F: Future<Output = V>,
+ F: Send + 'static,
+ V: Send,
+ {
+ let task = self.tasks.spawn_on(task, handle);
+ self.insert(key, task);
+ }
+
+ /// Spawn the provided task on the current [`LocalSet`] and store it in this
+ /// `JoinMap` with the provided key.
+ ///
+ /// If a task previously existed in the `JoinMap` for this key, that task
+ /// will be cancelled and replaced with the new one. The previous task will
+ /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
+ /// *not* return a cancelled [`JoinError`] for that task.
+ ///
+ /// # Panics
+ ///
+ /// This method panics if it is called outside of a `LocalSet`.
+ ///
+ /// [`LocalSet`]: tokio::task::LocalSet
+ /// [`join_next`]: Self::join_next
+ #[track_caller]
+ pub fn spawn_local<F>(&mut self, key: K, task: F)
+ where
+ F: Future<Output = V>,
+ F: 'static,
+ {
+ let task = self.tasks.spawn_local(task);
+ self.insert(key, task);
+ }
+
+ /// Spawn the provided task on the provided [`LocalSet`] and store it in
+ /// this `JoinMap` with the provided key.
+ ///
+ /// If a task previously existed in the `JoinMap` for this key, that task
+ /// will be cancelled and replaced with the new one. The previous task will
+ /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
+ /// *not* return a cancelled [`JoinError`] for that task.
+ ///
+ /// [`LocalSet`]: tokio::task::LocalSet
+ /// [`join_next`]: Self::join_next
+ #[track_caller]
+ pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
+ where
+ F: Future<Output = V>,
+ F: 'static,
+ {
+ let task = self.tasks.spawn_local_on(task, local_set);
+ self.insert(key, task)
+ }
+
+ fn insert(&mut self, key: K, abort: AbortHandle) {
+ let hash = self.hash(&key);
+ let id = abort.id();
+ let map_key = Key { id, key };
+
+ // Insert the new key into the map of tasks by keys.
+ let entry = self
+ .tasks_by_key
+ .raw_entry_mut()
+ .from_hash(hash, |k| k.key == map_key.key);
+ match entry {
+ RawEntryMut::Occupied(mut occ) => {
+ // There was a previous task spawned with the same key! Cancel
+ // that task, and remove its ID from the map of hashes by task IDs.
+ let Key { id: prev_id, .. } = occ.insert_key(map_key);
+ occ.insert(abort).abort();
+ let _prev_hash = self.hashes_by_task.remove(&prev_id);
+ debug_assert_eq!(Some(hash), _prev_hash);
+ }
+ RawEntryMut::Vacant(vac) => {
+ vac.insert(map_key, abort);
+ }
+ };
+
+ // Associate the key's hash with this task's ID, for looking up tasks by ID.
+ let _prev = self.hashes_by_task.insert(id, hash);
+ debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
+ }
+
+ /// Waits until one of the tasks in the map completes and returns its
+ /// output, along with the key corresponding to that task.
+ ///
+ /// Returns `None` if the map is empty.
+ ///
+ /// # Cancel Safety
+ ///
+ /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
+ /// statement and some other branch completes first, it is guaranteed that no tasks were
+ /// removed from this `JoinMap`.
+ ///
+ /// # Returns
+ ///
+ /// This function returns:
+ ///
+ /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
+ /// completed. The `value` is the return value of that task, and `key` is
+ /// the key associated with the task.
+ /// * `Some((key, Err(err)))` if one of the tasks in this `JoinMap` has
+ /// panicked or been aborted. `key` is the key associated with the task
+ /// that panicked or was aborted.
+ /// * `None` if the `JoinMap` is empty.
+ ///
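+ /// # Examples
+ ///
+ /// A sketch of awaiting the next completion inside [`tokio::select!`]
+ /// (illustrative only; the timeout branch is a stand-in for other work):
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut map = JoinMap::new();
+ /// map.spawn("task", async move { 42 });
+ ///
+ /// tokio::select! {
+ ///     // Cancel safe: if the timeout wins, no task is removed from the map.
+ ///     Some((key, res)) = map.join_next() => {
+ ///         assert_eq!(key, "task");
+ ///         assert_eq!(res.unwrap(), 42);
+ ///     }
+ ///     _ = tokio::time::sleep(Duration::from_secs(60)) => {}
+ /// }
+ /// # }
+ /// ```
+ ///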
+ /// [`tokio::select!`]: tokio::select
+ pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
+ let (res, id) = match self.tasks.join_next_with_id().await {
+ Some(Ok((id, output))) => (Ok(output), id),
+ Some(Err(e)) => {
+ let id = e.id();
+ (Err(e), id)
+ }
+ None => return None,
+ };
+ let key = self.remove_by_id(id)?;
+ Some((key, res))
+ }
+
+ /// Aborts all tasks and waits for them to finish shutting down.
+ ///
+ /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
+ /// a loop until it returns `None`.
+ ///
+ /// This method ignores any panics in the tasks shutting down. When this call returns, the
+ /// `JoinMap` will be empty.
+ ///
+ /// [`abort_all`]: fn@Self::abort_all
+ /// [`join_next`]: fn@Self::join_next
+ pub async fn shutdown(&mut self) {
+ self.abort_all();
+ while self.join_next().await.is_some() {}
+ }
+
+ /// Abort the task corresponding to the provided `key`.
+ ///
+ /// If this `JoinMap` contains a task corresponding to `key`, this method
+ /// will abort that task and return `true`. Otherwise, if no task exists for
+ /// `key`, this method returns `false`.
+ ///
+ /// # Examples
+ ///
+ /// Aborting a task by key:
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut map = JoinMap::new();
+ ///
+ /// map.spawn("hello world", async move { /* ... */ });
+ /// map.spawn("goodbye world", async move { /* ... */});
+ ///
+ /// // Look up the "goodbye world" task in the map and abort it.
+ /// map.abort("goodbye world");
+ ///
+ /// while let Some((key, res)) = map.join_next().await {
+ /// if key == "goodbye world" {
+ /// // The aborted task should complete with a cancelled `JoinError`.
+ /// assert!(res.unwrap_err().is_cancelled());
+ /// } else {
+ /// // Other tasks should complete normally.
+ /// assert!(res.is_ok());
+ /// }
+ /// }
+ /// # }
+ /// ```
+ ///
+ /// `abort` returns `true` if a task was aborted:
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut map = JoinMap::new();
+ ///
+ /// map.spawn("hello world", async move { /* ... */ });
+ /// map.spawn("goodbye world", async move { /* ... */});
+ ///
+ /// // A task for the key "goodbye world" should exist in the map:
+ /// assert!(map.abort("goodbye world"));
+ ///
+ /// // Aborting a key that does not exist will return `false`:
+ /// assert!(!map.abort("goodbye universe"));
+ /// # }
+ /// ```
+ pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
+ where
+ Q: Hash + Eq,
+ K: Borrow<Q>,
+ {
+ match self.get_by_key(key) {
+ Some((_, handle)) => {
+ handle.abort();
+ true
+ }
+ None => false,
+ }
+ }
+
+ /// Aborts all tasks with keys matching `predicate`.
+ ///
+ /// `predicate` is a function called with a reference to each key in the
+ /// map. If it returns `true` for a given key, the corresponding task will
+ /// be cancelled.
+ ///
+ /// # Examples
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// # // use the current thread rt so that spawned tasks don't
+ /// # // complete in the background before they can be aborted.
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() {
+ /// let mut map = JoinMap::new();
+ ///
+ /// map.spawn("hello world", async move {
+ /// // ...
+ /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
+ /// });
+ /// map.spawn("goodbye world", async move {
+ /// // ...
+ /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
+ /// });
+ /// map.spawn("hello san francisco", async move {
+ /// // ...
+ /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
+ /// });
+ /// map.spawn("goodbye universe", async move {
+ /// // ...
+ /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
+ /// });
+ ///
+ /// // Abort all tasks whose keys begin with "goodbye"
+ /// map.abort_matching(|key| key.starts_with("goodbye"));
+ ///
+ /// let mut seen = 0;
+ /// while let Some((key, res)) = map.join_next().await {
+ /// seen += 1;
+ /// if key.starts_with("goodbye") {
+ /// // The aborted task should complete with a cancelled `JoinError`.
+ /// assert!(res.unwrap_err().is_cancelled());
+ /// } else {
+ /// // Other tasks should complete normally.
+ /// assert!(key.starts_with("hello"));
+ /// assert!(res.is_ok());
+ /// }
+ /// }
+ ///
+ /// // All spawned tasks should have completed.
+ /// assert_eq!(seen, 4);
+ /// # }
+ /// ```
+ pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
+ // Note: this method iterates over the tasks and keys *without* removing
+ // any entries, so that the keys from aborted tasks can still be
+ // returned when calling `join_next` in the future.
+ for (Key { ref key, .. }, task) in &self.tasks_by_key {
+ if predicate(key) {
+ task.abort();
+ }
+ }
+ }
+
+ /// Returns `true` if this `JoinMap` contains a task for the provided key.
+ ///
+ /// If the task has completed, but its output hasn't yet been consumed by a
+ /// call to [`join_next`], this method will still return `true`.
+ ///
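+ /// # Examples
+ ///
+ /// A small sketch (illustrative only):
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut map = JoinMap::new();
+ /// map.spawn("hello", async move { /* ... */ });
+ ///
+ /// assert!(map.contains_key("hello"));
+ /// assert!(!map.contains_key("goodbye"));
+ /// # }
+ /// ```
+ ///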
+ /// [`join_next`]: fn@Self::join_next
+ pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
+ where
+ Q: Hash + Eq,
+ K: Borrow<Q>,
+ {
+ self.get_by_key(key).is_some()
+ }
+
+ /// Returns `true` if this `JoinMap` contains a task with the provided
+ /// [task ID].
+ ///
+ /// If the task has completed, but its output hasn't yet been consumed by a
+ /// call to [`join_next`], this method will still return `true`.
+ ///
+ /// [`join_next`]: fn@Self::join_next
+ /// [task ID]: tokio::task::Id
+ pub fn contains_task(&self, task: &Id) -> bool {
+ self.get_by_id(task).is_some()
+ }
+
+ /// Reserves capacity for at least `additional` more tasks to be spawned
+ /// on this `JoinMap` without reallocating for the map of task keys. The
+ /// collection may reserve more space to avoid frequent reallocations.
+ ///
+ /// Note that spawning a task will still cause an allocation for the task
+ /// itself.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the new allocation size overflows [`usize`].
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// let mut map: JoinMap<&str, i32> = JoinMap::new();
+ /// map.reserve(10);
+ /// ```
+ #[inline]
+ pub fn reserve(&mut self, additional: usize) {
+ self.tasks_by_key.reserve(additional);
+ self.hashes_by_task.reserve(additional);
+ }
+
+ /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
+ /// down as much as possible while maintaining the internal rules
+ /// and possibly leaving some space in accordance with the resize policy.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
+ /// map.spawn(1, async move { 2 });
+ /// map.spawn(3, async move { 4 });
+ /// assert!(map.capacity() >= 100);
+ /// map.shrink_to_fit();
+ /// assert!(map.capacity() >= 2);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn shrink_to_fit(&mut self) {
+ self.hashes_by_task.shrink_to_fit();
+ self.tasks_by_key.shrink_to_fit();
+ }
+
+ /// Shrinks the capacity of the map with a lower limit. It will drop
+ /// down no lower than the supplied limit while maintaining the internal rules
+ /// and possibly leaving some space in accordance with the resize policy.
+ ///
+ /// If the current capacity is less than the lower limit, this is a no-op.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// use tokio_util::task::JoinMap;
+ ///
+ /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
+ /// map.spawn(1, async move { 2 });
+ /// map.spawn(3, async move { 4 });
+ /// assert!(map.capacity() >= 100);
+ /// map.shrink_to(10);
+ /// assert!(map.capacity() >= 10);
+ /// map.shrink_to(0);
+ /// assert!(map.capacity() >= 2);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn shrink_to(&mut self, min_capacity: usize) {
+ self.hashes_by_task.shrink_to(min_capacity);
+ self.tasks_by_key.shrink_to(min_capacity)
+ }
+
+ /// Look up a task in the map by its key, returning the key and abort handle.
+ fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
+ where
+ Q: Hash + Eq,
+ K: Borrow<Q>,
+ {
+ let hash = self.hash(key);
+ self.tasks_by_key
+ .raw_entry()
+ .from_hash(hash, |k| k.key.borrow() == key)
+ }
+
+ /// Look up a task in the map by its task ID, returning the key and abort handle.
+ fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
+ let hash = self.hashes_by_task.get(id)?;
+ self.tasks_by_key
+ .raw_entry()
+ .from_hash(*hash, |k| &k.id == id)
+ }
+
+ /// Remove a task from the map by ID, returning the key for that task.
+ fn remove_by_id(&mut self, id: Id) -> Option<K> {
+ // Get the hash for the given ID.
+ let hash = self.hashes_by_task.remove(&id)?;
+
+ // Remove the entry for that hash.
+ let entry = self
+ .tasks_by_key
+ .raw_entry_mut()
+ .from_hash(hash, |k| k.id == id);
+ let (Key { id: _key_id, key }, handle) = match entry {
+ RawEntryMut::Occupied(entry) => entry.remove_entry(),
+ _ => return None,
+ };
+ debug_assert_eq!(_key_id, id);
+ debug_assert_eq!(id, handle.id());
+ self.hashes_by_task.remove(&id);
+ Some(key)
+ }
+
+ /// Returns the hash for a given key.
+ #[inline]
+ fn hash<Q: ?Sized>(&self, key: &Q) -> u64
+ where
+ Q: Hash,
+ {
+ let mut hasher = self.tasks_by_key.hasher().build_hasher();
+ key.hash(&mut hasher);
+ hasher.finish()
+ }
+}
+
+impl<K, V, S> JoinMap<K, V, S>
+where
+ V: 'static,
+{
+ /// Aborts all tasks on this `JoinMap`.
+ ///
+ /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
+ /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
+ pub fn abort_all(&mut self) {
+ self.tasks.abort_all()
+ }
+
+ /// Removes all tasks from this `JoinMap` without aborting them.
+ ///
+ /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
+ /// is dropped. They may still be aborted by key.
+ pub fn detach_all(&mut self) {
+ self.tasks.detach_all();
+ self.tasks_by_key.clear();
+ self.hashes_by_task.clear();
+ }
+}
+
+// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
+// Debug`, since no value is ever actually stored in the map.
+impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ // format the task keys and abort handles a little nicer by just
+ // printing the key and task ID pairs, without formatting the `Key` struct
+ // itself or the `AbortHandle`, which would just format the task's ID
+ // again.
+ struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
+ impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_map()
+ .entries(self.0.keys().map(|Key { key, id }| (key, id)))
+ .finish()
+ }
+ }
+
+ f.debug_struct("JoinMap")
+ // The `tasks_by_key` map is the only one that contains information
+ // that's really worth formatting for the user, since it contains
+ // the tasks' keys and IDs. The other fields are basically
+ // implementation details.
+ .field("tasks", &KeySet(&self.tasks_by_key))
+ .finish()
+ }
+}
+
+impl<K, V> Default for JoinMap<K, V> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// === impl Key ===
+
+impl<K: Hash> Hash for Key<K> {
+ // Don't include the task ID in the hash.
+ #[inline]
+ fn hash<H: Hasher>(&self, hasher: &mut H) {
+ self.key.hash(hasher);
+ }
+}
+
+// Because we override `Hash` for this type, we must also override the
+// `PartialEq` impl, so that all instances with the same hash are equal.
+impl<K: PartialEq> PartialEq for Key<K> {
+ #[inline]
+ fn eq(&self, other: &Self) -> bool {
+ self.key == other.key
+ }
+}
+
+impl<K: Eq> Eq for Key<K> {}
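+
+// Taken together, these impls mean two `Key`s hash and compare equal
+// whenever their user-visible `K` keys are equal, regardless of task ID;
+// lookups that must distinguish colliding keys (such as removal by task ID)
+// compare IDs explicitly through the raw-entry API instead.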
diff --git a/src/task/mod.rs b/src/task/mod.rs
new file mode 100644
index 0000000..de41dd5
--- /dev/null
+++ b/src/task/mod.rs
@@ -0,0 +1,12 @@
+//! Extra utilities for spawning tasks
+
+#[cfg(tokio_unstable)]
+mod join_map;
+#[cfg(not(target_os = "wasi"))]
+mod spawn_pinned;
+#[cfg(not(target_os = "wasi"))]
+pub use spawn_pinned::LocalPoolHandle;
+
+#[cfg(tokio_unstable)]
+#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))]
+pub use join_map::JoinMap;
diff --git a/src/task/spawn_pinned.rs b/src/task/spawn_pinned.rs
new file mode 100644
index 0000000..b4102ec
--- /dev/null
+++ b/src/task/spawn_pinned.rs
@@ -0,0 +1,436 @@
+use futures_util::future::{AbortHandle, Abortable};
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::future::Future;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::runtime::Builder;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::sync::oneshot;
+use tokio::task::{spawn_local, JoinHandle, LocalSet};
+
+/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
+///
+/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
+/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
+/// execute on the same thread) inside the Future you supply to the various spawn methods
+/// of `LocalPoolHandle`.
+///
+/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
+/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
+///
+/// # Examples
+///
+/// ```
+/// use std::rc::Rc;
+/// use tokio::{self, task };
+/// use tokio_util::task::LocalPoolHandle;
+///
+/// #[tokio::main(flavor = "current_thread")]
+/// async fn main() {
+/// let pool = LocalPoolHandle::new(5);
+///
+/// let output = pool.spawn_pinned(|| {
+/// // `data` is !Send + !Sync
+/// let data = Rc::new("local data");
+/// let data_clone = data.clone();
+///
+/// async move {
+/// task::spawn_local(async move {
+/// println!("{}", data_clone);
+/// });
+///
+/// data.to_string()
+/// }
+/// }).await.unwrap();
+/// println!("output: {}", output);
+/// }
+/// ```
+///
+#[derive(Clone)]
+pub struct LocalPoolHandle {
+ pool: Arc<LocalPool>,
+}
+
+impl LocalPoolHandle {
+ /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
+ /// pool via [`LocalPoolHandle::spawn_pinned`].
+ ///
+ /// # Panics
+ ///
+ /// Panics if the pool size is less than one.
+ #[track_caller]
+ pub fn new(pool_size: usize) -> LocalPoolHandle {
+ assert!(pool_size > 0);
+
+ let workers = (0..pool_size)
+ .map(|_| LocalWorkerHandle::new_worker())
+ .collect();
+
+ let pool = Arc::new(LocalPool { workers });
+
+ LocalPoolHandle { pool }
+ }
+
+ /// Returns the number of worker threads in the pool.
+ #[inline]
+ pub fn num_threads(&self) -> usize {
+ self.pool.workers.len()
+ }
+
+ /// Returns the number of tasks scheduled on each worker. The indices of the
+ /// worker threads correspond to the indices of the returned `Vec`.
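+ ///
+ /// A minimal sketch (illustrative only; a freshly created pool reports a
+ /// zero count for every worker):
+ ///
+ /// ```
+ /// use tokio_util::task::LocalPoolHandle;
+ ///
+ /// let pool = LocalPoolHandle::new(3);
+ /// let loads = pool.get_task_loads_for_each_worker();
+ /// // One entry per worker thread, all zero before any tasks are spawned.
+ /// assert_eq!(loads, vec![0, 0, 0]);
+ /// ```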
+ pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
+ self.pool
+ .workers
+ .iter()
+ .map(|worker| worker.task_count.load(Ordering::SeqCst))
+ .collect::<Vec<_>>()
+ }
+
+ /// Spawn a task onto a worker thread and pin it there so it can't be moved
+ /// off of the thread. Note that the future is not [`Send`], but the
+ /// [`FnOnce`] which creates it is.
+ ///
+ /// # Examples
+ /// ```
+ /// use std::rc::Rc;
+ /// use tokio_util::task::LocalPoolHandle;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// // Create the local pool
+ /// let pool = LocalPoolHandle::new(1);
+ ///
+ /// // Spawn a !Send future onto the pool and await it
+ /// let output = pool
+ /// .spawn_pinned(|| {
+ /// // Rc is !Send + !Sync
+ /// let local_data = Rc::new("test");
+ ///
+ /// // This future holds an Rc, so it is !Send
+ /// async move { local_data.to_string() }
+ /// })
+ /// .await
+ /// .unwrap();
+ ///
+ /// assert_eq!(output, "test");
+ /// }
+ /// ```
+ pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+ where
+ F: FnOnce() -> Fut,
+ F: Send + 'static,
+ Fut: Future + 'static,
+ Fut::Output: Send + 'static,
+ {
+ self.pool
+ .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
+ }
+
+ /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
+ /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
+ /// number of tasks scheduled.
+ ///
+ /// A worker thread is chosen by index. Indices are 0 based and the largest index
+ /// is given by `num_threads() - 1`.
+ ///
+ /// # Panics
+ ///
+ /// This method panics if the index is out of bounds.
+ ///
+ /// # Examples
+ ///
+ /// This method can be used to spawn a task on all worker threads of the pool:
+ ///
+ /// ```
+ /// use tokio_util::task::LocalPoolHandle;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// const NUM_WORKERS: usize = 3;
+ /// let pool = LocalPoolHandle::new(NUM_WORKERS);
+ /// let handles = (0..pool.num_threads())
+ /// .map(|worker_idx| {
+ /// pool.spawn_pinned_by_idx(
+ /// || {
+ /// async {
+ /// "test"
+ /// }
+ /// },
+ /// worker_idx,
+ /// )
+ /// })
+ /// .collect::<Vec<_>>();
+ ///
+ /// for handle in handles {
+ /// handle.await.unwrap();
+ /// }
+ /// }
+ /// ```
+ ///
+ #[track_caller]
+ pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
+ where
+ F: FnOnce() -> Fut,
+ F: Send + 'static,
+ Fut: Future + 'static,
+ Fut::Output: Send + 'static,
+ {
+ self.pool
+ .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
+ }
+}
+
+impl Debug for LocalPoolHandle {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ f.write_str("LocalPoolHandle")
+ }
+}
+
+enum WorkerChoice {
+ LeastBurdened,
+ ByIdx(usize),
+}
+
+struct LocalPool {
+ workers: Vec<LocalWorkerHandle>,
+}
+
+impl LocalPool {
+ /// Spawn a `?Send` future onto a worker
+ #[track_caller]
+ fn spawn_pinned<F, Fut>(
+ &self,
+ create_task: F,
+ worker_choice: WorkerChoice,
+ ) -> JoinHandle<Fut::Output>
+ where
+ F: FnOnce() -> Fut,
+ F: Send + 'static,
+ Fut: Future + 'static,
+ Fut::Output: Send + 'static,
+ {
+ let (sender, receiver) = oneshot::channel();
+ let (worker, job_guard) = match worker_choice {
+ WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
+ WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
+ };
+ let worker_spawner = worker.spawner.clone();
+
+ // Spawn a future onto the worker's runtime so we can immediately return
+ // a join handle.
+ worker.runtime_handle.spawn(async move {
+ // Move the job guard into the task
+ let _job_guard = job_guard;
+
+ // Propagate aborts via Abortable/AbortHandle
+ let (abort_handle, abort_registration) = AbortHandle::new_pair();
+ let _abort_guard = AbortGuard(abort_handle);
+
+ // Inside the future we can't run spawn_local yet because we're not
+ // in the context of a LocalSet. We need to send create_task to the
+ // LocalSet task for spawning.
+ let spawn_task = Box::new(move || {
+ // Once we're in the LocalSet context we can call spawn_local
+ let join_handle =
+ spawn_local(
+ async move { Abortable::new(create_task(), abort_registration).await },
+ );
+
+ // Send the join handle back to the spawner. If sending fails,
+ // we assume the parent task was canceled, so cancel this task
+ // as well.
+ if let Err(join_handle) = sender.send(join_handle) {
+ join_handle.abort()
+ }
+ });
+
+ // Send the callback to the LocalSet task
+ if let Err(e) = worker_spawner.send(spawn_task) {
+ // Propagate the error as a panic in the join handle.
+ panic!("Failed to send job to worker: {}", e);
+ }
+
+ // Wait for the task's join handle
+ let join_handle = match receiver.await {
+ Ok(handle) => handle,
+ Err(e) => {
+ // We sent the task successfully, but failed to get its
+ // join handle... We assume something happened to the worker
+ // and the task was not spawned. Propagate the error as a
+ // panic in the join handle.
+ panic!("Worker failed to send join handle: {}", e);
+ }
+ };
+
+ // Wait for the task to complete
+ let join_result = join_handle.await;
+
+ match join_result {
+ Ok(Ok(output)) => output,
+ Ok(Err(_)) => {
+ // Pinned task was aborted. But that only happens if this
+ // task is aborted. So this is an impossible branch.
+ unreachable!(
+ "Reaching this branch means this task was previously \
+ aborted but it continued running anyways"
+ )
+ }
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else if e.is_cancelled() {
+ // No one else should have the join handle, so this is
+ // unexpected. Forward this error as a panic in the join
+ // handle.
+ panic!("spawn_pinned task was canceled: {}", e);
+ } else {
+ // Something unknown happened (not a panic or
+ // cancellation). Forward this error as a panic in the
+ // join handle.
+ panic!("spawn_pinned task failed: {}", e);
+ }
+ }
+ }
+ })
+ }
+
+ /// Find the worker with the least number of tasks, increment its task
+ /// count, and return its handle. Make sure to actually spawn a task on
+ /// the worker so the task count is kept consistent with load.
+ ///
+ /// A job count guard is also returned to ensure the task count gets
+ /// decremented when the job is done.
+ fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
+ loop {
+ let (worker, task_count) = self
+ .workers
+ .iter()
+ .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
+ .min_by_key(|&(_, count)| count)
+ .expect("There must be at least one worker");
+
+ // Make sure the task count hasn't changed since we chose this
+ // worker. Otherwise, restart the search.
+ if worker
+ .task_count
+ .compare_exchange(
+ task_count,
+ task_count + 1,
+ Ordering::SeqCst,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
+ }
+ }
+ }
+
+ #[track_caller]
+ fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
+ let worker = &self.workers[idx];
+ worker.task_count.fetch_add(1, Ordering::SeqCst);
+
+ (worker, JobCountGuard(Arc::clone(&worker.task_count)))
+ }
+}
+
+/// Automatically decrements a worker's job count when a job finishes (when
+/// this gets dropped).
+struct JobCountGuard(Arc<AtomicUsize>);
+
+impl Drop for JobCountGuard {
+ fn drop(&mut self) {
+ // Decrement the job count
+ let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
+ debug_assert!(previous_value >= 1);
+ }
+}
+
+/// Calls abort on the handle when dropped.
+struct AbortGuard(AbortHandle);
+
+impl Drop for AbortGuard {
+ fn drop(&mut self) {
+ self.0.abort();
+ }
+}
+
+type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
+
+struct LocalWorkerHandle {
+ runtime_handle: tokio::runtime::Handle,
+ spawner: UnboundedSender<PinnedFutureSpawner>,
+ task_count: Arc<AtomicUsize>,
+}
+
+impl LocalWorkerHandle {
+ /// Create a new worker for executing pinned tasks
+ fn new_worker() -> LocalWorkerHandle {
+ let (sender, receiver) = unbounded_channel();
+ let runtime = Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .expect("Failed to start a pinned worker thread runtime");
+ let runtime_handle = runtime.handle().clone();
+ let task_count = Arc::new(AtomicUsize::new(0));
+ let task_count_clone = Arc::clone(&task_count);
+
+ std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
+
+ LocalWorkerHandle {
+ runtime_handle,
+ spawner: sender,
+ task_count,
+ }
+ }
+
+ fn run(
+ runtime: tokio::runtime::Runtime,
+ mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
+ task_count: Arc<AtomicUsize>,
+ ) {
+ let local_set = LocalSet::new();
+ local_set.block_on(&runtime, async {
+ while let Some(spawn_task) = task_receiver.recv().await {
+ // Calls spawn_local(future)
+ (spawn_task)();
+ }
+ });
+
+ // If there are any tasks on the runtime associated with a LocalSet task
+ // that has already completed, but whose output has not yet been
+ // reported, let that task complete.
+ //
+ // Since the task_count is decremented when the runtime task exits,
+ // reading that counter lets us know if any such tasks completed during
+ // the call to `block_on`.
+ //
+ // Tasks on the LocalSet can't complete during this loop since they're
+ // stored on the LocalSet and we aren't accessing it.
+ let mut previous_task_count = task_count.load(Ordering::SeqCst);
+ loop {
+ // This call will also run tasks spawned on the runtime.
+ runtime.block_on(tokio::task::yield_now());
+ let new_task_count = task_count.load(Ordering::SeqCst);
+ if new_task_count == previous_task_count {
+ break;
+ } else {
+ previous_task_count = new_task_count;
+ }
+ }
+
+ // It's now no longer possible for a task on the runtime to be
+ // associated with a LocalSet task that has completed. Drop both the
+ // LocalSet and runtime to let tasks on the runtime be cancelled if and
+ // only if they are still on the LocalSet.
+ //
+ // Drop the LocalSet task first so that anyone awaiting the runtime
+ // JoinHandle will see the cancelled error after the LocalSet task
+ // destructor has completed.
+ drop(local_set);
+ drop(runtime);
+ }
+}
diff --git a/src/time/delay_queue.rs b/src/time/delay_queue.rs
new file mode 100644
index 0000000..ee66adb
--- /dev/null
+++ b/src/time/delay_queue.rs
@@ -0,0 +1,1271 @@
+//! A queue of delayed elements.
+//!
+//! See [`DelayQueue`] for more details.
+//!
+//! [`DelayQueue`]: struct@DelayQueue
+
+use crate::time::wheel::{self, Wheel};
+
+use futures_core::ready;
+use tokio::time::{sleep_until, Duration, Instant, Sleep};
+
+use core::ops::{Index, IndexMut};
+use slab::Slab;
+use std::cmp;
+use std::collections::HashMap;
+use std::convert::From;
+use std::fmt;
+use std::fmt::Debug;
+use std::future::Future;
+use std::marker::PhantomData;
+use std::pin::Pin;
+use std::task::{self, Poll, Waker};
+
+/// A queue of delayed elements.
+///
+/// Once an element is inserted into the `DelayQueue`, it is yielded once the
+/// specified deadline has been reached.
+///
+/// # Usage
+///
+/// Elements are inserted into `DelayQueue` using the [`insert`] or
+/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is
+/// returned. The key is used to remove the entry or to change the deadline at
+/// which it should be yielded back.
+///
+/// Once delays have been configured, the `DelayQueue` is used via its
+/// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its
+/// deadline, it is returned. If not, `Poll::Pending` is returned, indicating that the
+/// current task will be notified once the deadline has been reached.
+///
+/// # `Stream` implementation
+///
+/// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have
+/// expired, no items are returned. In this case, `Poll::Pending` is returned and the
+/// current task is registered to be notified once the next item's delay has
+/// expired.
+///
+/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll`
+/// returns `Poll::Ready(None)`. This indicates that the stream has reached an end.
+/// However, if a new item is inserted *afterwards*, `poll` will once again start
+/// returning items or `Poll::Pending`.
+///
+/// Items are returned ordered by their expirations. Items that are configured
+/// to expire first will be returned first. There are no ordering guarantees
+/// for items configured to expire at the same instant. Also note that delays are
+/// rounded to the closest millisecond.
+///
+/// # Implementation
+///
+/// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally
+/// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same
+/// performance and scalability benefits.
+///
+/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation,
+/// and allows reuse of the memory allocated for expired entries.
+///
+/// Capacity can be checked using [`capacity`] and allocated preemptively by using
+/// the [`reserve`] method.
+///
+/// # Examples
+///
+/// Using `DelayQueue` to manage cache entries.
+///
+/// ```rust,no_run
+/// use tokio_util::time::{DelayQueue, delay_queue};
+///
+/// use futures::ready;
+/// use std::collections::HashMap;
+/// use std::task::{Context, Poll};
+/// use std::time::Duration;
+/// # type CacheKey = String;
+/// # type Value = String;
+///
+/// struct Cache {
+/// entries: HashMap<CacheKey, (Value, delay_queue::Key)>,
+/// expirations: DelayQueue<CacheKey>,
+/// }
+///
+/// const TTL_SECS: u64 = 30;
+///
+/// impl Cache {
+/// fn insert(&mut self, key: CacheKey, value: Value) {
+/// let delay = self.expirations
+/// .insert(key.clone(), Duration::from_secs(TTL_SECS));
+///
+/// self.entries.insert(key, (value, delay));
+/// }
+///
+/// fn get(&self, key: &CacheKey) -> Option<&Value> {
+/// self.entries.get(key)
+/// .map(|&(ref v, _)| v)
+/// }
+///
+/// fn remove(&mut self, key: &CacheKey) {
+/// if let Some((_, cache_key)) = self.entries.remove(key) {
+/// self.expirations.remove(&cache_key);
+/// }
+/// }
+///
+/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+/// while let Some(entry) = ready!(self.expirations.poll_expired(cx)) {
+/// self.entries.remove(entry.get_ref());
+/// }
+///
+/// Poll::Ready(())
+/// }
+/// }
+/// ```
+///
+/// [`insert`]: method@Self::insert
+/// [`insert_at`]: method@Self::insert_at
+/// [`Key`]: struct@Key
+/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html
+/// [`poll_expired`]: method@Self::poll_expired
+/// [`Stream::poll_expired`]: method@Self::poll_expired
+/// [`DelayQueue`]: struct@DelayQueue
+/// [`sleep`]: fn@tokio::time::sleep
+/// [`slab`]: slab
+/// [`capacity`]: method@Self::capacity
+/// [`reserve`]: method@Self::reserve
+#[derive(Debug)]
+pub struct DelayQueue<T> {
+ /// Stores data associated with entries
+ slab: SlabStorage<T>,
+
+ /// Lookup structure tracking all delays in the queue
+ wheel: Wheel<Stack<T>>,
+
+ /// Delays that were inserted when already expired. These cannot be stored
+ /// in the wheel
+ expired: Stack<T>,
+
+ /// Delay expiring when the *first* item in the queue expires
+ delay: Option<Pin<Box<Sleep>>>,
+
+ /// Wheel polling state
+ wheel_now: u64,
+
+ /// Instant at which the timer starts
+ start: Instant,
+
+ /// Waker that is invoked when we potentially need to reset the timer.
+ /// Because we lazily create the timer when the first entry is created, we
+ /// need to awaken any poller that polled us before that point.
+ waker: Option<Waker>,
+}
+
+#[derive(Default)]
+struct SlabStorage<T> {
+ inner: Slab<Data<T>>,
+
+ // A `compact` call requires a re-mapping of the `Key`s that were changed
+ // during the `compact` call of the `slab`. Since the keys that were given out
+ // cannot be changed retroactively we need to keep track of these re-mappings.
+ // The keys of `key_map` correspond to the old keys that were given out and
+ // the values to the `Key`s that were re-mapped by the `compact` call.
+ key_map: HashMap<Key, KeyInternal>,
+
+ // Index used to create new keys to hand out.
+ next_key_index: usize,
+
+ // Whether `compact` has been called, necessary in order to decide whether
+ // to include keys in `key_map`.
+ compact_called: bool,
+}
+
+impl<T> SlabStorage<T> {
+ pub(crate) fn with_capacity(capacity: usize) -> SlabStorage<T> {
+ SlabStorage {
+ inner: Slab::with_capacity(capacity),
+ key_map: HashMap::new(),
+ next_key_index: 0,
+ compact_called: false,
+ }
+ }
+
+ // Inserts data into the inner slab and re-maps keys if necessary
+ pub(crate) fn insert(&mut self, val: Data<T>) -> Key {
+ let mut key = KeyInternal::new(self.inner.insert(val));
+ let key_contained = self.key_map.contains_key(&key.into());
+
+ if key_contained {
+ // It's possible that a `compact` call creates capacity in `self.inner` in
+ // such a way that a `self.inner.insert` call creates a `key` which was
+ // previously given out during an `insert` call prior to the `compact` call.
+ // If `key` is contained in `self.key_map`, we have encountered this exact situation.
+ // We need to create a new key `key_to_give_out` and include the relation
+ // `key_to_give_out` -> `key` in `self.key_map`.
+ let key_to_give_out = self.create_new_key();
+ assert!(!self.key_map.contains_key(&key_to_give_out.into()));
+ self.key_map.insert(key_to_give_out.into(), key);
+ key = key_to_give_out;
+ } else if self.compact_called {
+ // Include an identity mapping in `self.key_map` in order to allow us to
+ // panic if a key that was handed out is removed more than once.
+ self.key_map.insert(key.into(), key);
+ }
+
+ key.into()
+ }
+
+ // Re-map the key in case compact was previously called.
+ // Note: Since we include identity mappings in key_map after compact was called,
+ // we have information about all keys that were handed out. In the case in which
+ // compact was called and we try to remove a Key that was previously removed
+ // we can detect invalid keys if no key is found in `key_map`. This is necessary
+ // in order to prevent situations in which a previously removed key
+ // corresponds to a re-mapped key internally and which would then be incorrectly
+ // removed from the slab.
+ //
+ // Example to illuminate this problem:
+ //
+ // Let's assume our `key_map` is {1 -> 2, 2 -> 1} and we call remove(1). If we
+ // were to remove 1 again, we would not find it inside `key_map` anymore.
+ // If we were to imply from this that no re-mapping was necessary, we would
+ // incorrectly remove 1 from `self.slab.inner`, which corresponds to the
+ // handed-out key 2.
+ pub(crate) fn remove(&mut self, key: &Key) -> Data<T> {
+ let remapped_key = if self.compact_called {
+ match self.key_map.remove(key) {
+ Some(key_internal) => key_internal,
+ None => panic!("invalid key"),
+ }
+ } else {
+ (*key).into()
+ };
+
+ self.inner.remove(remapped_key.index)
+ }
+
+ pub(crate) fn shrink_to_fit(&mut self) {
+ self.inner.shrink_to_fit();
+ self.key_map.shrink_to_fit();
+ }
+
+ pub(crate) fn compact(&mut self) {
+ if !self.compact_called {
+ for (key, _) in self.inner.iter() {
+ self.key_map.insert(Key::new(key), KeyInternal::new(key));
+ }
+ }
+
+ let mut remapping = HashMap::new();
+ self.inner.compact(|_, from, to| {
+ remapping.insert(from, to);
+ true
+ });
+
+ // At this point `key_map` contains a mapping for every element.
+ for internal_key in self.key_map.values_mut() {
+ if let Some(new_internal_key) = remapping.get(&internal_key.index) {
+ *internal_key = KeyInternal::new(*new_internal_key);
+ }
+ }
+
+ if self.key_map.capacity() > 2 * self.key_map.len() {
+ self.key_map.shrink_to_fit();
+ }
+
+ self.compact_called = true;
+ }
+
+ // Tries to re-map a `Key` that was given out to the user to its
+ // corresponding internal key.
+ fn remap_key(&self, key: &Key) -> Option<KeyInternal> {
+ let key_map = &self.key_map;
+ if self.compact_called {
+ key_map.get(key).copied()
+ } else {
+ Some((*key).into())
+ }
+ }
+
+ fn create_new_key(&mut self) -> KeyInternal {
+ while self.key_map.contains_key(&Key::new(self.next_key_index)) {
+ self.next_key_index = self.next_key_index.wrapping_add(1);
+ }
+
+ KeyInternal::new(self.next_key_index)
+ }
+
+ pub(crate) fn len(&self) -> usize {
+ self.inner.len()
+ }
+
+ pub(crate) fn capacity(&self) -> usize {
+ self.inner.capacity()
+ }
+
+ pub(crate) fn clear(&mut self) {
+ self.inner.clear();
+ self.key_map.clear();
+ self.compact_called = false;
+ }
+
+ pub(crate) fn reserve(&mut self, additional: usize) {
+ self.inner.reserve(additional);
+
+ if self.compact_called {
+ self.key_map.reserve(additional);
+ }
+ }
+
+ pub(crate) fn is_empty(&self) -> bool {
+ self.inner.is_empty()
+ }
+
+ pub(crate) fn contains(&self, key: &Key) -> bool {
+ let remapped_key = self.remap_key(key);
+
+ match remapped_key {
+ Some(internal_key) => self.inner.contains(internal_key.index),
+ None => false,
+ }
+ }
+}
+
+impl<T> fmt::Debug for SlabStorage<T>
+where
+ T: fmt::Debug,
+{
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ if fmt.alternate() {
+ fmt.debug_map().entries(self.inner.iter()).finish()
+ } else {
+ fmt.debug_struct("Slab")
+ .field("len", &self.len())
+ .field("cap", &self.capacity())
+ .finish()
+ }
+ }
+}
+
+impl<T> Index<Key> for SlabStorage<T> {
+ type Output = Data<T>;
+
+ fn index(&self, key: Key) -> &Self::Output {
+ let remapped_key = self.remap_key(&key);
+
+ match remapped_key {
+ Some(internal_key) => &self.inner[internal_key.index],
+ None => panic!("Invalid index {}", key.index),
+ }
+ }
+}
+
+impl<T> IndexMut<Key> for SlabStorage<T> {
+ fn index_mut(&mut self, key: Key) -> &mut Data<T> {
+ let remapped_key = self.remap_key(&key);
+
+ match remapped_key {
+ Some(internal_key) => &mut self.inner[internal_key.index],
+ None => panic!("Invalid index {}", key.index),
+ }
+ }
+}
+
+/// An entry in `DelayQueue` that has expired and been removed.
+///
+/// Values are returned by [`DelayQueue::poll_expired`].
+///
+/// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired
+#[derive(Debug)]
+pub struct Expired<T> {
+ /// The data stored in the queue
+ data: T,
+
+ /// The expiration time
+ deadline: Instant,
+
+ /// The key associated with the entry
+ key: Key,
+}
+
+/// Token to a value stored in a `DelayQueue`.
+///
+/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`]
+/// documentation for more details.
+///
+/// [`DelayQueue`]: struct@DelayQueue
+/// [`DelayQueue::insert`]: method@DelayQueue::insert
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Key {
+ index: usize,
+}
+
+// Whereas `Key` is given out to users that use `DelayQueue`, internally we use
+// `KeyInternal` as the key type in order to make the logic of mapping between keys
+// as a result of `compact` calls clearer.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+struct KeyInternal {
+ index: usize,
+}
+
+#[derive(Debug)]
+struct Stack<T> {
+ /// Head of the stack
+ head: Option<Key>,
+ _p: PhantomData<fn() -> T>,
+}
+
+#[derive(Debug)]
+struct Data<T> {
+ /// The data stored in the queue; it will be returned at the requested
+ /// instant.
+ inner: T,
+
+ /// The instant at which the item is returned.
+ when: u64,
+
+ /// Set to true when stored in the `expired` queue
+ expired: bool,
+
+ /// Next entry in the stack
+ next: Option<Key>,
+
+ /// Previous entry in the stack
+ prev: Option<Key>,
+}
+
+/// Maximum number of entries the queue can handle
+const MAX_ENTRIES: usize = (1 << 30) - 1;
+
+impl<T> DelayQueue<T> {
+ /// Creates a new, empty, `DelayQueue`.
+ ///
+ /// The queue will not allocate storage until items are inserted into it.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// # use tokio_util::time::DelayQueue;
+ /// let delay_queue: DelayQueue<u32> = DelayQueue::new();
+ /// ```
+ pub fn new() -> DelayQueue<T> {
+ DelayQueue::with_capacity(0)
+ }
+
+ /// Creates a new, empty, `DelayQueue` with the specified capacity.
+ ///
+ /// The queue will be able to hold at least `capacity` elements without
+ /// reallocating. If `capacity` is 0, the queue will not allocate for
+ /// storage.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// # use tokio_util::time::DelayQueue;
+ /// # use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::with_capacity(10);
+ ///
+ /// // These insertions are done without further allocation
+ /// for i in 0..10 {
+ /// delay_queue.insert(i, Duration::from_secs(i));
+ /// }
+ ///
+ /// // This will make the queue allocate additional storage
+ /// delay_queue.insert(11, Duration::from_secs(11));
+ /// # }
+ /// ```
+ pub fn with_capacity(capacity: usize) -> DelayQueue<T> {
+ DelayQueue {
+ wheel: Wheel::new(),
+ slab: SlabStorage::with_capacity(capacity),
+ expired: Stack::default(),
+ delay: None,
+ wheel_now: 0,
+ start: Instant::now(),
+ waker: None,
+ }
+ }
+
+ /// Inserts `value` into the queue set to expire at a specific instant in
+ /// time.
+ ///
+ /// This function is identical to `insert`, but takes an `Instant` instead
+ /// of a `Duration`.
+ ///
+ /// `value` is stored in the queue until `when` is reached, at which point
+ /// `value` will be returned from [`poll_expired`]. If `when` has already been
+ /// reached, then `value` is immediately made available to poll.
+ ///
+ /// The return value represents the insertion and is used as an argument to
+ /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once
+ /// `value` is removed from the queue either by calling [`poll_expired`] after
+ /// `when` is reached or by calling [`remove`]. At this point, the caller
+ /// must take care to not use the returned [`Key`] again as it may reference
+ /// a different item in the queue.
+ ///
+ /// See [type] level documentation for more details.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `when` is too far in the future.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio::time::{Duration, Instant};
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert_at(
+ /// "foo", Instant::now() + Duration::from_secs(5));
+ ///
+ /// // Remove the entry
+ /// let item = delay_queue.remove(&key);
+ /// assert_eq!(*item.get_ref(), "foo");
+ /// # }
+ /// ```
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ /// [`remove`]: method@Self::remove
+ /// [`reset`]: method@Self::reset
+ /// [`Key`]: struct@Key
+ /// [type]: #
+ #[track_caller]
+ pub fn insert_at(&mut self, value: T, when: Instant) -> Key {
+ assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded");
+
+ // Normalize the deadline. Values cannot be set to expire in the past.
+ let when = self.normalize_deadline(when);
+
+ // Insert the value in the store
+ let key = self.slab.insert(Data {
+ inner: value,
+ when,
+ expired: false,
+ next: None,
+ prev: None,
+ });
+
+ self.insert_idx(when, key);
+
+ // Set a new delay if the current delay's deadline is later than that of the new item
+ let should_set_delay = if let Some(ref delay) = self.delay {
+ let current_exp = self.normalize_deadline(delay.deadline());
+ current_exp > when
+ } else {
+ true
+ };
+
+ if should_set_delay {
+ if let Some(waker) = self.waker.take() {
+ waker.wake();
+ }
+
+ let delay_time = self.start + Duration::from_millis(when);
+ if let Some(ref mut delay) = &mut self.delay {
+ delay.as_mut().reset(delay_time);
+ } else {
+ self.delay = Some(Box::pin(sleep_until(delay_time)));
+ }
+ }
+
+ key
+ }
+
+ /// Attempts to pull out the next value of the delay queue, registering the
+ /// current task for wakeup if the value is not yet available, and returning
+ /// `None` if the queue is exhausted.
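+    ///
+    /// # Examples
+    ///
+    /// A minimal manual-polling sketch (an editor's addition, not upstream
+    /// documentation); it assumes `std::future::poll_fn` (Rust 1.64+) is
+    /// available:
+    ///
+    /// ```rust
+    /// use std::future::poll_fn;
+    /// use std::time::Duration;
+    /// use tokio_util::time::DelayQueue;
+    ///
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// # async fn main() {
+    /// let mut delay_queue = DelayQueue::new();
+    /// delay_queue.insert("foo", Duration::from_millis(5));
+    ///
+    /// // Drive `poll_expired` until the queue reports exhaustion with `None`.
+    /// while let Some(expired) = poll_fn(|cx| delay_queue.poll_expired(cx)).await {
+    ///     assert_eq!(expired.into_inner(), "foo");
+    /// }
+    /// # }
+    /// ```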
+ pub fn poll_expired(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Expired<T>>> {
+ if !self
+ .waker
+ .as_ref()
+ .map(|w| w.will_wake(cx.waker()))
+ .unwrap_or(false)
+ {
+ self.waker = Some(cx.waker().clone());
+ }
+
+ let item = ready!(self.poll_idx(cx));
+ Poll::Ready(item.map(|key| {
+ let data = self.slab.remove(&key);
+ debug_assert!(data.next.is_none());
+ debug_assert!(data.prev.is_none());
+
+ Expired {
+ key,
+ data: data.inner,
+ deadline: self.start + Duration::from_millis(data.when),
+ }
+ }))
+ }
+
+ /// Inserts `value` into the queue set to expire after the requested duration
+ /// elapses.
+ ///
+ /// This function is identical to `insert_at`, but takes a `Duration`
+ /// instead of an `Instant`.
+ ///
+ /// `value` is stored in the queue until `timeout` duration has
+ /// elapsed after `insert` was called. At that point, `value` will
+ /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of
+ /// zero, then `value` is immediately made available to poll.
+ ///
+ /// The return value represents the insertion and is used as an
+ /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a
+ /// token and is reused once `value` is removed from the queue
+ /// either by calling [`poll_expired`] after `timeout` has elapsed
+ /// or by calling [`remove`]. At this point, the caller must not
+ /// use the returned [`Key`] again as it may reference a different
+ /// item in the queue.
+ ///
+ /// See [type] level documentation for more details.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `timeout` is greater than the maximum
+ /// duration supported by the timer in the current `Runtime`.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // Remove the entry
+ /// let item = delay_queue.remove(&key);
+ /// assert_eq!(*item.get_ref(), "foo");
+ /// # }
+ /// ```
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ /// [`remove`]: method@Self::remove
+ /// [`reset`]: method@Self::reset
+ /// [`Key`]: struct@Key
+ /// [type]: #
+ #[track_caller]
+ pub fn insert(&mut self, value: T, timeout: Duration) -> Key {
+ self.insert_at(value, Instant::now() + timeout)
+ }
+
+ #[track_caller]
+ fn insert_idx(&mut self, when: u64, key: Key) {
+ use self::wheel::{InsertError, Stack};
+
+ // Register the deadline with the timer wheel
+ match self.wheel.insert(when, key, &mut self.slab) {
+ Ok(_) => {}
+ Err((_, InsertError::Elapsed)) => {
+ self.slab[key].expired = true;
+ // The delay is already expired, store it in the expired queue
+ self.expired.push(key, &mut self.slab);
+ }
+ Err((_, err)) => panic!("invalid deadline; err={:?}", err),
+ }
+ }
+
+ /// Removes the key from the expired queue or the timer wheel
+ /// depending on its expiration status.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the key is not contained in the expired queue or the wheel.
+ #[track_caller]
+ fn remove_key(&mut self, key: &Key) {
+ use crate::time::wheel::Stack;
+
+ // Special case the `expired` queue
+ if self.slab[*key].expired {
+ self.expired.remove(key, &mut self.slab);
+ } else {
+ self.wheel.remove(key, &mut self.slab);
+ }
+ }
+
+ /// Removes the item associated with `key` from the queue.
+ ///
+    /// There must be an item associated with `key`. The function returns the
+    /// removed item as well as the `Instant` at which its delay would have
+    /// expired.
+ ///
+ /// # Panics
+ ///
+ /// The function panics if `key` is not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // Remove the entry
+ /// let item = delay_queue.remove(&key);
+ /// assert_eq!(*item.get_ref(), "foo");
+ /// # }
+ /// ```
+ #[track_caller]
+ pub fn remove(&mut self, key: &Key) -> Expired<T> {
+ let prev_deadline = self.next_deadline();
+
+ self.remove_key(key);
+ let data = self.slab.remove(key);
+
+ let next_deadline = self.next_deadline();
+ if prev_deadline != next_deadline {
+ match (next_deadline, &mut self.delay) {
+ (None, _) => self.delay = None,
+ (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline),
+ (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))),
+ }
+ }
+
+ Expired {
+ key: Key::new(key.index),
+ data: data.inner,
+ deadline: self.start + Duration::from_millis(data.when),
+ }
+ }
+
+ /// Attempts to remove the item associated with `key` from the queue.
+ ///
+ /// Removes the item associated with `key`, and returns it along with the
+ /// `Instant` at which it would have expired, if it exists.
+ ///
+ /// Returns `None` if `key` is not in the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main(flavor = "current_thread")]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // The item is in the queue, `try_remove` returns `Some(Expired("foo"))`.
+ /// let item = delay_queue.try_remove(&key);
+ /// assert_eq!(item.unwrap().into_inner(), "foo");
+ ///
+ /// // The item is not in the queue anymore, `try_remove` returns `None`.
+ /// let item = delay_queue.try_remove(&key);
+ /// assert!(item.is_none());
+ /// # }
+ /// ```
+ pub fn try_remove(&mut self, key: &Key) -> Option<Expired<T>> {
+ if self.slab.contains(key) {
+ Some(self.remove(key))
+ } else {
+ None
+ }
+ }
+
+ /// Sets the delay of the item associated with `key` to expire at `when`.
+ ///
+ /// This function is identical to `reset` but takes an `Instant` instead of
+ /// a `Duration`.
+ ///
+ /// The item remains in the queue but the delay is set to expire at `when`.
+ /// If `when` is in the past, then the item is immediately made available to
+ /// the caller.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `when` is too far in the future or if `key` is
+ /// not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio::time::{Duration, Instant};
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // "foo" is scheduled to be returned in 5 seconds
+ ///
+ /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10));
+ ///
+ /// // "foo" is now scheduled to be returned in 10 seconds
+ /// # }
+ /// ```
+ #[track_caller]
+ pub fn reset_at(&mut self, key: &Key, when: Instant) {
+ self.remove_key(key);
+
+ // Normalize the deadline. Values cannot be set to expire in the past.
+ let when = self.normalize_deadline(when);
+
+ self.slab[*key].when = when;
+ self.slab[*key].expired = false;
+
+ self.insert_idx(when, *key);
+
+ let next_deadline = self.next_deadline();
+ if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) {
+            // This should awaken us if necessary (i.e., if already expired)
+ delay.as_mut().reset(deadline);
+ }
+ }
+
+    /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation.
+    /// This function is not guaranteed to, and in most cases won't, decrease the capacity of the slab
+    /// to the number of elements still contained in it, because elements cannot be moved to a different
+    /// index. To decrease the capacity to the number of elements in the slab, use [`compact`].
+    ///
+    /// This function can take O(n) time even when the capacity cannot be reduced or the allocation is
+    /// shrunk in place. Repeated calls, however, run in O(1).
+ ///
+ /// [`compact`]: method@Self::compact
+ pub fn shrink_to_fit(&mut self) {
+ self.slab.shrink_to_fit();
+ }
+
+ /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation,
+ /// to the number of elements that are contained in it.
+ ///
+    /// This method runs in O(n).
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::with_capacity(10);
+ ///
+ /// let key1 = delay_queue.insert(5, Duration::from_secs(5));
+ /// let key2 = delay_queue.insert(10, Duration::from_secs(10));
+ /// let key3 = delay_queue.insert(15, Duration::from_secs(15));
+ ///
+ /// delay_queue.remove(&key2);
+ ///
+ /// delay_queue.compact();
+ /// assert_eq!(delay_queue.capacity(), 2);
+ /// # }
+ /// ```
+ pub fn compact(&mut self) {
+ self.slab.compact();
+ }
+
+ /// Returns the next time to poll as determined by the wheel
+ fn next_deadline(&mut self) -> Option<Instant> {
+ self.wheel
+ .poll_at()
+ .map(|poll_at| self.start + Duration::from_millis(poll_at))
+ }
+
+ /// Sets the delay of the item associated with `key` to expire after
+ /// `timeout`.
+ ///
+ /// This function is identical to `reset_at` but takes a `Duration` instead
+ /// of an `Instant`.
+ ///
+ /// The item remains in the queue but the delay is set to expire after
+ /// `timeout`. If `timeout` is zero, then the item is immediately made
+ /// available to the caller.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `timeout` is greater than the maximum supported
+ /// duration or if `key` is not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // "foo" is scheduled to be returned in 5 seconds
+ ///
+ /// delay_queue.reset(&key, Duration::from_secs(10));
+ ///
+ /// // "foo"is now scheduled to be returned in 10 seconds
+ /// # }
+ /// ```
+ #[track_caller]
+ pub fn reset(&mut self, key: &Key, timeout: Duration) {
+ self.reset_at(key, Instant::now() + timeout);
+ }
+
+ /// Clears the queue, removing all items.
+ ///
+    /// After calling `clear`, [`poll_expired`] will return `Poll::Ready(None)`.
+ ///
+ /// Note that this method has no effect on the allocated capacity.
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// assert!(!delay_queue.is_empty());
+ ///
+ /// delay_queue.clear();
+ ///
+ /// assert!(delay_queue.is_empty());
+ /// # }
+ /// ```
+ pub fn clear(&mut self) {
+ self.slab.clear();
+ self.expired = Stack::default();
+ self.wheel = Wheel::new();
+ self.delay = None;
+ }
+
+ /// Returns the number of elements the queue can hold without reallocating.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
+ /// assert_eq!(delay_queue.capacity(), 10);
+ /// ```
+ pub fn capacity(&self) -> usize {
+ self.slab.capacity()
+ }
+
+ /// Returns the number of elements currently in the queue.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
+ /// assert_eq!(delay_queue.len(), 0);
+ /// delay_queue.insert(3, Duration::from_secs(5));
+ /// assert_eq!(delay_queue.len(), 1);
+ /// # }
+ /// ```
+ pub fn len(&self) -> usize {
+ self.slab.len()
+ }
+
+ /// Reserves capacity for at least `additional` more items to be queued
+ /// without allocating.
+ ///
+ /// `reserve` does nothing if the queue already has sufficient capacity for
+ /// `additional` more values. If more capacity is required, a new segment of
+ /// memory will be allocated and all existing values will be copied into it.
+ /// As such, if the queue is already very large, a call to `reserve` can end
+ /// up being expensive.
+ ///
+ /// The queue may reserve more than `additional` extra space in order to
+ /// avoid frequent reallocations.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the new capacity exceeds the maximum number of entries the
+ /// queue can contain.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// delay_queue.insert("hello", Duration::from_secs(10));
+ /// delay_queue.reserve(10);
+ ///
+ /// assert!(delay_queue.capacity() >= 11);
+ /// # }
+ /// ```
+ #[track_caller]
+ pub fn reserve(&mut self, additional: usize) {
+ assert!(
+ self.slab.capacity() + additional <= MAX_ENTRIES,
+ "max queue capacity exceeded"
+ );
+ self.slab.reserve(additional);
+ }
+
+ /// Returns `true` if there are no items in the queue.
+ ///
+    /// Note that this function returns `false` even when none of the items have
+    /// expired yet, in which case a call to `poll_expired` will return
+    /// `Poll::Pending`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// assert!(delay_queue.is_empty());
+ ///
+ /// delay_queue.insert("hello", Duration::from_secs(5));
+ /// assert!(!delay_queue.is_empty());
+ /// # }
+ /// ```
+ pub fn is_empty(&self) -> bool {
+ self.slab.is_empty()
+ }
+
+ /// Polls the queue, returning the index of the next slot in the slab that
+ /// should be returned.
+ ///
+ /// A slot should be returned when the associated deadline has been reached.
+ fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Key>> {
+ use self::wheel::Stack;
+
+ let expired = self.expired.pop(&mut self.slab);
+
+ if expired.is_some() {
+ return Poll::Ready(expired);
+ }
+
+ loop {
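+            // If a timer is armed, wait for it to fire: `ready!` registers the
+            // waker and returns `Pending` until the deadline has been reached.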
+ if let Some(ref mut delay) = self.delay {
+ if !delay.is_elapsed() {
+ ready!(Pin::new(&mut *delay).poll(cx));
+ }
+
+ let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down);
+
+ self.wheel_now = now;
+ }
+
+ // We poll the wheel to get the next value out before finding the next deadline.
+ let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab);
+
+ self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when)));
+
+ if let Some(idx) = wheel_idx {
+ return Poll::Ready(Some(idx));
+ }
+
+ if self.delay.is_none() {
+ return Poll::Ready(None);
+ }
+ }
+ }
+
+ fn normalize_deadline(&self, when: Instant) -> u64 {
+ let when = if when < self.start {
+ 0
+ } else {
+ crate::time::ms(when - self.start, crate::time::Round::Up)
+ };
+
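+        // Worked example (editor's note): a deadline 1.2 ms after `start`
+        // normalizes to 2 ms because of `Round::Up`, and the `max` below keeps
+        // deadlines from landing behind the wheel's elapsed time.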
+ cmp::max(when, self.wheel.elapsed())
+ }
+}
+
+// We never put `T` in a `Pin`...
+impl<T> Unpin for DelayQueue<T> {}
+
+impl<T> Default for DelayQueue<T> {
+ fn default() -> DelayQueue<T> {
+ DelayQueue::new()
+ }
+}
+
+impl<T> futures_core::Stream for DelayQueue<T> {
+ // DelayQueue seems much more specific, where a user may care that it
+ // has reached capacity, so return those errors instead of panicking.
+ type Item = Expired<T>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
+ DelayQueue::poll_expired(self.get_mut(), cx)
+ }
+}
+
+impl<T> wheel::Stack for Stack<T> {
+ type Owned = Key;
+ type Borrowed = Key;
+ type Store = SlabStorage<T>;
+
+ fn is_empty(&self) -> bool {
+ self.head.is_none()
+ }
+
+ fn push(&mut self, item: Self::Owned, store: &mut Self::Store) {
+ // Ensure the entry is not already in a stack.
+ debug_assert!(store[item].next.is_none());
+ debug_assert!(store[item].prev.is_none());
+
+ // Remove the old head entry
+ let old = self.head.take();
+
+ if let Some(idx) = old {
+ store[idx].prev = Some(item);
+ }
+
+ store[item].next = old;
+ self.head = Some(item);
+ }
+
+ fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> {
+ if let Some(key) = self.head {
+ self.head = store[key].next;
+
+ if let Some(idx) = self.head {
+ store[idx].prev = None;
+ }
+
+ store[key].next = None;
+ debug_assert!(store[key].prev.is_none());
+
+ Some(key)
+ } else {
+ None
+ }
+ }
+
+ #[track_caller]
+ fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) {
+ let key = *item;
+ assert!(store.contains(item));
+
+ // Ensure that the entry is in fact contained by the stack
+ debug_assert!({
+ // This walks the full linked list even if an entry is found.
+ let mut next = self.head;
+ let mut contains = false;
+
+ while let Some(idx) = next {
+ let data = &store[idx];
+
+ if idx == *item {
+ debug_assert!(!contains);
+ contains = true;
+ }
+
+ next = data.next;
+ }
+
+ contains
+ });
+
+ if let Some(next) = store[key].next {
+ store[next].prev = store[key].prev;
+ }
+
+ if let Some(prev) = store[key].prev {
+ store[prev].next = store[key].next;
+ } else {
+ self.head = store[key].next;
+ }
+
+ store[key].next = None;
+ store[key].prev = None;
+ }
+
+ fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 {
+ store[*item].when
+ }
+}
+
+impl<T> Default for Stack<T> {
+ fn default() -> Stack<T> {
+ Stack {
+ head: None,
+ _p: PhantomData,
+ }
+ }
+}
+
+impl Key {
+ pub(crate) fn new(index: usize) -> Key {
+ Key { index }
+ }
+}
+
+impl KeyInternal {
+ pub(crate) fn new(index: usize) -> KeyInternal {
+ KeyInternal { index }
+ }
+}
+
+impl From<Key> for KeyInternal {
+ fn from(item: Key) -> Self {
+ KeyInternal::new(item.index)
+ }
+}
+
+impl From<KeyInternal> for Key {
+ fn from(item: KeyInternal) -> Self {
+ Key::new(item.index)
+ }
+}
+
+impl<T> Expired<T> {
+ /// Returns a reference to the inner value.
+ pub fn get_ref(&self) -> &T {
+ &self.data
+ }
+
+ /// Returns a mutable reference to the inner value.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.data
+ }
+
+ /// Consumes `self` and returns the inner value.
+ pub fn into_inner(self) -> T {
+ self.data
+ }
+
+ /// Returns the deadline that the expiration was set to.
+ pub fn deadline(&self) -> Instant {
+ self.deadline
+ }
+
+ /// Returns the key that the expiration is indexed by.
+ pub fn key(&self) -> Key {
+ self.key
+ }
+}
diff --git a/src/time/mod.rs b/src/time/mod.rs
new file mode 100644
index 0000000..2d34008
--- /dev/null
+++ b/src/time/mod.rs
@@ -0,0 +1,47 @@
+//! Additional utilities for tracking time.
+//!
+//! This module provides additional utilities for executing code after a set period
+//! of time. Currently there is only one:
+//!
+//! * `DelayQueue`: A queue where items are returned once the requested delay
+//! has expired.
+//!
+//! This type must be used from within the context of the `Runtime`.
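+//!
+//! A rough end-to-end sketch (an editor's addition; it assumes
+//! `tokio_stream::StreamExt` is available to iterate the `Stream`):
+//!
+//! ```rust
+//! use std::time::Duration;
+//! use tokio_stream::StreamExt;
+//! use tokio_util::time::DelayQueue;
+//!
+//! # #[tokio::main(flavor = "current_thread")]
+//! # async fn main() {
+//! let mut queue = DelayQueue::new();
+//! queue.insert("first", Duration::from_millis(5));
+//! queue.insert("second", Duration::from_millis(10));
+//!
+//! // `DelayQueue` implements `Stream`, yielding items as their delays expire.
+//! while let Some(expired) = queue.next().await {
+//!     println!("{}", expired.into_inner());
+//! }
+//! # }
+//! ```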
+
+use std::time::Duration;
+
+mod wheel;
+
+pub mod delay_queue;
+
+#[doc(inline)]
+pub use delay_queue::DelayQueue;
+
+// ===== Internal utils =====
+
+enum Round {
+ Up,
+ Down,
+}
+
+/// Convert a `Duration` to milliseconds, rounding in the requested direction
+/// and saturating at `u64::MAX`.
+///
+/// Saturating is fine because `u64::MAX` milliseconds is still many millions
+/// of years.
+#[inline]
+fn ms(duration: Duration, round: Round) -> u64 {
+ const NANOS_PER_MILLI: u32 = 1_000_000;
+ const MILLIS_PER_SEC: u64 = 1_000;
+
+    // Round the sub-second part in the requested direction.
+ let millis = match round {
+ Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI,
+ Round::Down => duration.subsec_millis(),
+ };
+
+ duration
+ .as_secs()
+ .saturating_mul(MILLIS_PER_SEC)
+ .saturating_add(u64::from(millis))
+}
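+
+// Worked example (editor's note): `ms(Duration::new(1, 1), Round::Up)` yields
+// 1001 because the lone nanosecond rounds up to a whole millisecond, while
+// `Round::Down` drops it and yields 1000.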
diff --git a/src/time/wheel/level.rs b/src/time/wheel/level.rs
new file mode 100644
index 0000000..8ea30af
--- /dev/null
+++ b/src/time/wheel/level.rs
@@ -0,0 +1,253 @@
+use crate::time::wheel::Stack;
+
+use std::fmt;
+
+/// Wheel for a single level in the timer. This wheel contains 64 slots.
+pub(crate) struct Level<T> {
+ level: usize,
+
+ /// Bit field tracking which slots currently contain entries.
+ ///
+ /// Using a bit field to track slots that contain entries allows avoiding a
+ /// scan to find entries. This field is updated when entries are added or
+ /// removed from a slot.
+ ///
+ /// The least-significant bit represents slot zero.
+ occupied: u64,
+
+ /// Slots
+ slot: [T; LEVEL_MULT],
+}
+
+/// Indicates when a slot must be processed next.
+#[derive(Debug)]
+pub(crate) struct Expiration {
+ /// The level containing the slot.
+ pub(crate) level: usize,
+
+ /// The slot index.
+ pub(crate) slot: usize,
+
+ /// The instant at which the slot needs to be processed.
+ pub(crate) deadline: u64,
+}
+
+/// Level multiplier.
+///
+/// Being a power of 2 is very important: `slot_for` computes positions with a
+/// shift by `6 * level` and a modulo by 64, which assumes `LEVEL_MULT == 2^6`.
+const LEVEL_MULT: usize = 64;
+
+impl<T: Stack> Level<T> {
+ pub(crate) fn new(level: usize) -> Level<T> {
+ // Rust's derived implementations for arrays require that the value
+ // contained by the array be `Copy`. So, here we have to manually
+ // initialize every single slot.
+ macro_rules! s {
+ () => {
+ T::default()
+ };
+ }
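+        // Editor's note: on Rust 1.63+ this could be written as
+        // `std::array::from_fn(|_| T::default())`; the explicit expansion is
+        // kept for compatibility with older toolchains.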
+
+ Level {
+ level,
+ occupied: 0,
+            slot: [
+                // It does not look like the necessary traits are
+                // derived for [T; 64], so each slot is written out explicitly.
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+                s!(), s!(), s!(), s!(), s!(), s!(), s!(), s!(),
+            ],
+ }
+ }
+
+ /// Finds the slot that needs to be processed next and returns the slot and
+ /// `Instant` at which this slot must be processed.
+ pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> {
+ // Use the `occupied` bit field to get the index of the next slot that
+ // needs to be processed.
+ let slot = match self.next_occupied_slot(now) {
+ Some(slot) => slot,
+ None => return None,
+ };
+
+ // From the slot index, calculate the `Instant` at which it needs to be
+ // processed. This value *must* be in the future with respect to `now`.
+
+ let level_range = level_range(self.level);
+ let slot_range = slot_range(self.level);
+
+ // TODO: This can probably be simplified w/ power of 2 math
+ let level_start = now - (now % level_range);
+ let deadline = level_start + slot as u64 * slot_range;
+
+ debug_assert!(
+ deadline >= now,
+ "deadline={}; now={}; level={}; slot={}; occupied={:b}",
+ deadline,
+ now,
+ self.level,
+ slot,
+ self.occupied
+ );
+
+ Some(Expiration {
+ level: self.level,
+ slot,
+ deadline,
+ })
+ }
+
+ fn next_occupied_slot(&self, now: u64) -> Option<usize> {
+ if self.occupied == 0 {
+ return None;
+ }
+
+        // Get the slot index corresponding to `now`, using modular arithmetic
+ let now_slot = (now / slot_range(self.level)) as usize;
+ let occupied = self.occupied.rotate_right(now_slot as u32);
+ let zeros = occupied.trailing_zeros() as usize;
+ let slot = (zeros + now_slot) % 64;
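+
+        // Example (editor's note): occupied = 0b1000 (slot 3) with now_slot = 2
+        // rotates to 0b10, so trailing_zeros = 1 and (1 + 2) % 64 == 3 recovers
+        // the occupied slot.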
+
+ Some(slot)
+ }
+
+ pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) {
+ let slot = slot_for(when, self.level);
+
+ self.slot[slot].push(item, store);
+ self.occupied |= occupied_bit(slot);
+ }
+
+ pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) {
+ let slot = slot_for(when, self.level);
+
+ self.slot[slot].remove(item, store);
+
+ if self.slot[slot].is_empty() {
+ // The bit is currently set
+ debug_assert!(self.occupied & occupied_bit(slot) != 0);
+
+ // Unset the bit
+ self.occupied ^= occupied_bit(slot);
+ }
+ }
+
+ pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> {
+ let ret = self.slot[slot].pop(store);
+
+ if ret.is_some() && self.slot[slot].is_empty() {
+ // The bit is currently set
+ debug_assert!(self.occupied & occupied_bit(slot) != 0);
+
+ self.occupied ^= occupied_bit(slot);
+ }
+
+ ret
+ }
+}
+
+impl<T> fmt::Debug for Level<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Level")
+ .field("occupied", &self.occupied)
+ .finish()
+ }
+}
+
+fn occupied_bit(slot: usize) -> u64 {
+ 1 << slot
+}
+
+fn slot_range(level: usize) -> u64 {
+ LEVEL_MULT.pow(level as u32) as u64
+}
+
+fn level_range(level: usize) -> u64 {
+ LEVEL_MULT as u64 * slot_range(level)
+}
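+
+// For instance (editor's note): slot_range(1) == 64 (one level-1 slot spans
+// 64 ms) and level_range(1) == 64 * 64 == 4096 ms, matching the ~4 s entry in
+// the level table in the `Wheel` docs.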
+
+/// Convert a duration (milliseconds) and a level to a slot position
+fn slot_for(duration: u64, level: usize) -> usize {
+ ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize
+}
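+
+// For instance (editor's note): slot_for(130, 1) == (130 >> 6) % 64 == 2,
+// i.e. a deadline 130 ms out lands in slot 2 of level 1.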
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_slot_for() {
+ for pos in 0..64 {
+ assert_eq!(pos as usize, slot_for(pos, 0));
+ }
+
+ for level in 1..5 {
+ for pos in level..64 {
+ let a = pos * 64_usize.pow(level as u32);
+ assert_eq!(pos as usize, slot_for(a as u64, level));
+ }
+ }
+ }
+}
diff --git a/src/time/wheel/mod.rs b/src/time/wheel/mod.rs
new file mode 100644
index 0000000..ffa05ab
--- /dev/null
+++ b/src/time/wheel/mod.rs
@@ -0,0 +1,315 @@
+mod level;
+pub(crate) use self::level::Expiration;
+use self::level::Level;
+
+mod stack;
+pub(crate) use self::stack::Stack;
+
+use std::borrow::Borrow;
+use std::fmt::Debug;
+use std::usize;
+
+/// Timing wheel implementation.
+///
+/// This type provides the hashed timing wheel implementation that backs `Timer`
+/// and `DelayQueue`.
+///
+/// The structure is generic over `T: Stack`. This allows handling timeout data
+/// being stored on the heap or in a slab. In order to support the latter case,
+/// the slab must be passed into each function, allowing the implementation to
+/// look up timer entries.
+///
+/// See `Timer` documentation for some implementation notes.
+#[derive(Debug)]
+pub(crate) struct Wheel<T> {
+ /// The number of milliseconds elapsed since the wheel started.
+ elapsed: u64,
+
+ /// Timer wheel.
+ ///
+ /// Levels:
+ ///
+ /// * 1 ms slots / 64 ms range
+ /// * 64 ms slots / ~ 4 sec range
+ /// * ~ 4 sec slots / ~ 4 min range
+ /// * ~ 4 min slots / ~ 4 hr range
+ /// * ~ 4 hr slots / ~ 12 day range
+ /// * ~ 12 day slots / ~ 2 yr range
+ levels: Vec<Level<T>>,
+}
+
+/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots
+/// each, the timer is able to track time up to 2 years into the future with a
+/// precision of 1 millisecond.
+const NUM_LEVELS: usize = 6;
+
+/// The maximum duration of a delay
+const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1;
+
+#[derive(Debug)]
+pub(crate) enum InsertError {
+ Elapsed,
+ Invalid,
+}
+
+impl<T> Wheel<T>
+where
+ T: Stack,
+{
+ /// Create a new timing wheel
+ pub(crate) fn new() -> Wheel<T> {
+ let levels = (0..NUM_LEVELS).map(Level::new).collect();
+
+ Wheel { elapsed: 0, levels }
+ }
+
+ /// Return the number of milliseconds that have elapsed since the timing
+ /// wheel's creation.
+ pub(crate) fn elapsed(&self) -> u64 {
+ self.elapsed
+ }
+
+ /// Insert an entry into the timing wheel.
+ ///
+ /// # Arguments
+ ///
+    /// * `when`: The instant at which the entry should be fired, represented
+    ///           as the number of milliseconds since the creation of the
+    ///           timing wheel.
+ ///
+ /// * `item`: The item to insert into the wheel.
+ ///
+ /// * `store`: The slab or `()` when using heap storage.
+ ///
+ /// # Return
+ ///
+ /// Returns `Ok` when the item is successfully inserted, `Err` otherwise.
+ ///
+ /// `Err(Elapsed)` indicates that `when` represents an instant that has
+ /// already passed. In this case, the caller should fire the timeout
+ /// immediately.
+ ///
+    /// `Err(Invalid)` indicates that an invalid `when` argument has been supplied.
+ pub(crate) fn insert(
+ &mut self,
+ when: u64,
+ item: T::Owned,
+ store: &mut T::Store,
+ ) -> Result<(), (T::Owned, InsertError)> {
+ if when <= self.elapsed {
+ return Err((item, InsertError::Elapsed));
+ } else if when - self.elapsed > MAX_DURATION {
+ return Err((item, InsertError::Invalid));
+ }
+
+ // Get the level at which the entry should be stored
+ let level = self.level_for(when);
+
+ self.levels[level].add_entry(when, item, store);
+
+ debug_assert!({
+ self.levels[level]
+ .next_expiration(self.elapsed)
+ .map(|e| e.deadline >= self.elapsed)
+ .unwrap_or(true)
+ });
+
+ Ok(())
+ }
+
+ /// Remove `item` from the timing wheel.
+ #[track_caller]
+ pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) {
+ let when = T::when(item, store);
+
+ assert!(
+ self.elapsed <= when,
+ "elapsed={}; when={}",
+ self.elapsed,
+ when
+ );
+
+ let level = self.level_for(when);
+
+ self.levels[level].remove_entry(when, item, store);
+ }
+
+ /// Instant at which to poll
+ pub(crate) fn poll_at(&self) -> Option<u64> {
+ self.next_expiration().map(|expiration| expiration.deadline)
+ }
+
+ /// Advances the timer up to the instant represented by `now`.
+ pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> {
+ loop {
+ let expiration = self.next_expiration().and_then(|expiration| {
+ if expiration.deadline > now {
+ None
+ } else {
+ Some(expiration)
+ }
+ });
+
+ match expiration {
+ Some(ref expiration) => {
+ if let Some(item) = self.poll_expiration(expiration, store) {
+ return Some(item);
+ }
+
+ self.set_elapsed(expiration.deadline);
+ }
+ None => {
+                    // In this case the poll did not indicate an expiration
+                    // _and_ we were not able to find a next expiration in
+                    // the current list of timers. Advance to the poll's
+                    // current time and do nothing else.
+ self.set_elapsed(now);
+ return None;
+ }
+ }
+ }
+ }
+
+ /// Returns the instant at which the next timeout expires.
+ fn next_expiration(&self) -> Option<Expiration> {
+ // Check all levels
+ for level in 0..NUM_LEVELS {
+ if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) {
+ // There cannot be any expirations at a higher level that happen
+ // before this one.
+ debug_assert!(self.no_expirations_before(level + 1, expiration.deadline));
+
+ return Some(expiration);
+ }
+ }
+
+ None
+ }
+
+ /// Used for debug assertions
+ fn no_expirations_before(&self, start_level: usize, before: u64) -> bool {
+ let mut res = true;
+
+ for l2 in start_level..NUM_LEVELS {
+ if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) {
+ if e2.deadline < before {
+ res = false;
+ }
+ }
+ }
+
+ res
+ }
+
+    /// Iteratively finds entries that fall between the wheel's current
+    /// time and the expiration time. Each entry in that population is either
+    /// returned for notification (when the expiration is at level 0) or
+    /// moved down to the next lower level (in all other cases).
+ pub(crate) fn poll_expiration(
+ &mut self,
+ expiration: &Expiration,
+ store: &mut T::Store,
+ ) -> Option<T::Owned> {
+ while let Some(item) = self.pop_entry(expiration, store) {
+ if expiration.level == 0 {
+ debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline);
+
+ return Some(item);
+ } else {
+ let when = T::when(item.borrow(), store);
+
+ let next_level = expiration.level - 1;
+
+ self.levels[next_level].add_entry(when, item, store);
+ }
+ }
+
+ None
+ }
+
+ fn set_elapsed(&mut self, when: u64) {
+ assert!(
+ self.elapsed <= when,
+ "elapsed={:?}; when={:?}",
+ self.elapsed,
+ when
+ );
+
+ if when > self.elapsed {
+ self.elapsed = when;
+ }
+ }
+
+ fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> {
+ self.levels[expiration.level].pop_entry_slot(expiration.slot, store)
+ }
+
+ fn level_for(&self, when: u64) -> usize {
+ level_for(self.elapsed, when)
+ }
+}
+
+fn level_for(elapsed: u64, when: u64) -> usize {
+ const SLOT_MASK: u64 = (1 << 6) - 1;
+
+ // Mask in the trailing bits ignored by the level calculation in order to cap
+ // the possible leading zeros
+ let masked = elapsed ^ when | SLOT_MASK;
+
+ let leading_zeros = masked.leading_zeros() as usize;
+ let significant = 63 - leading_zeros;
+ significant / 6
+}
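+
+// Worked example (editor's note): level_for(0, 64) masks to 0b111_1111, whose
+// highest set bit is bit 6, so significant = 6 and the result is 6 / 6 = 1,
+// agreeing with `test_level_for` below.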
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_level_for() {
+ for pos in 0..64 {
+ assert_eq!(
+ 0,
+ level_for(0, pos),
+ "level_for({}) -- binary = {:b}",
+ pos,
+ pos
+ );
+ }
+
+ for level in 1..5 {
+ for pos in level..64 {
+ let a = pos * 64_usize.pow(level as u32);
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+
+ if pos > level {
+ let a = a - 1;
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+ }
+
+ if pos < 64 {
+ let a = a + 1;
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+ }
+ }
+ }
+ }
+}
diff --git a/src/time/wheel/stack.rs b/src/time/wheel/stack.rs
new file mode 100644
index 0000000..c87adca
--- /dev/null
+++ b/src/time/wheel/stack.rs
@@ -0,0 +1,28 @@
+use std::borrow::Borrow;
+use std::cmp::Eq;
+use std::hash::Hash;
+
+/// Abstracts the stack operations needed to track timeouts.
+pub(crate) trait Stack: Default {
+ /// Type of the item stored in the stack
+ type Owned: Borrow<Self::Borrowed>;
+
+ /// Borrowed item
+ type Borrowed: Eq + Hash;
+
+    /// Item storage; this allows a slab to be used instead of just the heap
+ type Store;
+
+ /// Returns `true` if the stack is empty
+ fn is_empty(&self) -> bool;
+
+ /// Push an item onto the stack
+ fn push(&mut self, item: Self::Owned, store: &mut Self::Store);
+
+ /// Pop an item from the stack
+ fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>;
+
+ fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);
+
+ fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
+}
diff --git a/src/udp/frame.rs b/src/udp/frame.rs
new file mode 100644
index 0000000..d094c04
--- /dev/null
+++ b/src/udp/frame.rs
@@ -0,0 +1,249 @@
+use crate::codec::{Decoder, Encoder};
+
+use futures_core::Stream;
+use tokio::{io::ReadBuf, net::UdpSocket};
+
+use bytes::{BufMut, BytesMut};
+use futures_core::ready;
+use futures_sink::Sink;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{
+ borrow::Borrow,
+ net::{Ipv4Addr, SocketAddr, SocketAddrV4},
+};
+use std::{io, mem::MaybeUninit};
+
+/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
+/// the `Encoder` and `Decoder` traits to encode and decode frames.
+///
+/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
+/// batch these into meaningful chunks, called "frames". This adapter layers
+/// framing on top of the socket by using the `Encoder` and `Decoder` traits to
+/// handle encoding and decoding of message frames. Note that the incoming and
+/// outgoing frame types may be distinct.
+///
+/// `UdpFramed` is a *single* object that is both [`Stream`] and [`Sink`];
+/// grouping this into a single object is often useful for layering things which
+/// require both read and write access to the underlying object.
+///
+/// If you want to work more directly with the stream and sink, consider
+/// calling [`split`] on the `UdpFramed`, which will break them into separate
+/// objects, allowing them to interact more easily.
+///
+/// [`Stream`]: futures_core::Stream
+/// [`Sink`]: futures_sink::Sink
+/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
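+///
+/// A rough usage sketch (an editor's addition; it assumes the `codec` and
+/// `net` features are enabled and uses `BytesCodec` plus `futures::SinkExt`):
+///
+/// ```rust,no_run
+/// use bytes::Bytes;
+/// use futures::SinkExt;
+/// use tokio::net::UdpSocket;
+/// use tokio_util::codec::BytesCodec;
+/// use tokio_util::udp::UdpFramed;
+///
+/// # async fn docs() -> Result<(), Box<dyn std::error::Error>> {
+/// let socket = UdpSocket::bind("127.0.0.1:0").await?;
+/// let addr = socket.local_addr()?;
+/// let mut framed = UdpFramed::new(socket, BytesCodec::new());
+///
+/// // Each item sent through the sink becomes one encoded datagram.
+/// framed.send((Bytes::from("hello"), addr)).await?;
+/// # Ok(())
+/// # }
+/// ```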
+#[must_use = "sinks do nothing unless polled"]
+#[derive(Debug)]
+pub struct UdpFramed<C, T = UdpSocket> {
+ socket: T,
+ codec: C,
+ rd: BytesMut,
+ wr: BytesMut,
+ out_addr: SocketAddr,
+ flushed: bool,
+ is_readable: bool,
+ current_addr: Option<SocketAddr>,
+}
+
+const INITIAL_RD_CAPACITY: usize = 64 * 1024;
+const INITIAL_WR_CAPACITY: usize = 8 * 1024;
+
+impl<C, T> Unpin for UdpFramed<C, T> {}
+
+impl<C, T> Stream for UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+ C: Decoder,
+{
+ type Item = Result<(C::Item, SocketAddr), C::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let pin = self.get_mut();
+
+ pin.rd.reserve(INITIAL_RD_CAPACITY);
+
+ loop {
+ // Are there still bytes left in the read buffer to decode?
+ if pin.is_readable {
+ if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
+ let current_addr = pin
+ .current_addr
+ .expect("will always be set before this line is called");
+
+ return Poll::Ready(Some(Ok((frame, current_addr))));
+ }
+
+ // if this line has been reached then decode has returned `None`.
+ pin.is_readable = false;
+ pin.rd.clear();
+ }
+
+ // We're out of data. Try and fetch more data to decode
+ let addr = {
+ // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
+ // transparent wrapper around `[MaybeUninit<u8>]`.
+ let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut read = ReadBuf::uninit(buf);
+ let ptr = read.filled().as_ptr();
+ let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
+
+ assert_eq!(ptr, read.filled().as_ptr());
+ let addr = res?;
+
+ // Safety: This is guaranteed to be the number of initialized (and read) bytes due
+ // to the invariants provided by `ReadBuf::filled`.
+ unsafe { pin.rd.advance_mut(read.filled().len()) };
+
+ addr
+ };
+
+ pin.current_addr = Some(addr);
+ pin.is_readable = true;
+ }
+ }
+}
+
+impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+ C: Encoder<I>,
+{
+ type Error = C::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if !self.flushed {
+ match self.poll_flush(cx)? {
+ Poll::Ready(()) => {}
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
+ let (frame, out_addr) = item;
+
+ let pin = self.get_mut();
+
+ pin.codec.encode(frame, &mut pin.wr)?;
+ pin.out_addr = out_addr;
+ pin.flushed = false;
+
+ Ok(())
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if self.flushed {
+ return Poll::Ready(Ok(()));
+ }
+
+ let Self {
+ ref socket,
+ ref mut out_addr,
+ ref mut wr,
+ ..
+ } = *self;
+
+ let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
+
+ let wrote_all = n == self.wr.len();
+ self.wr.clear();
+ self.flushed = true;
+
+ let res = if wrote_all {
+ Ok(())
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to write entire datagram to socket",
+ )
+ .into())
+ };
+
+ Poll::Ready(res)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.poll_flush(cx))?;
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<C, T> UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+{
+ /// Create a new `UdpFramed` backed by the given socket and codec.
+ ///
+ /// See struct level documentation for more details.
+ pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
+ Self {
+ socket,
+ codec,
+ out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
+ rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
+ wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
+ flushed: true,
+ is_readable: false,
+ current_addr: None,
+ }
+ }
+
+    /// Returns a reference to the underlying I/O stream wrapped by `UdpFramed`.
+ ///
+ /// # Note
+ ///
+ /// Care should be taken to not tamper with the underlying stream of data
+ /// coming in as it may corrupt the stream of frames otherwise being worked
+ /// with.
+ pub fn get_ref(&self) -> &T {
+ &self.socket
+ }
+
+    /// Returns a mutable reference to the underlying I/O stream wrapped by `UdpFramed`.
+ ///
+ /// # Note
+ ///
+ /// Care should be taken to not tamper with the underlying stream of data
+ /// coming in as it may corrupt the stream of frames otherwise being worked
+ /// with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.socket
+ }
+
+    /// Returns a reference to the underlying codec wrapped by
+    /// `UdpFramed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec(&self) -> &C {
+ &self.codec
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `UdpFramed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_mut(&mut self) -> &mut C {
+ &mut self.codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.rd
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.rd
+ }
+
+    /// Consumes the `UdpFramed`, returning its underlying I/O stream.
+ pub fn into_inner(self) -> T {
+ self.socket
+ }
+}
diff --git a/src/udp/mod.rs b/src/udp/mod.rs
new file mode 100644
index 0000000..f88ea03
--- /dev/null
+++ b/src/udp/mod.rs
@@ -0,0 +1,4 @@
+//! UDP framing
+
+mod frame;
+pub use frame::UdpFramed;
diff --git a/tests/_require_full.rs b/tests/_require_full.rs
new file mode 100644
index 0000000..045934d
--- /dev/null
+++ b/tests/_require_full.rs
@@ -0,0 +1,2 @@
+#![cfg(not(feature = "full"))]
+compile_error!("run tokio-util tests with `--features full`");
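+// Editor's note: with `--features full` the `cfg` above compiles this file
+// away; without it, the `compile_error!` fires as a reminder to enable the
+// feature when running the test suite.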
diff --git a/tests/codecs.rs b/tests/codecs.rs
new file mode 100644
index 0000000..f9a7801
--- /dev/null
+++ b/tests/codecs.rs
@@ -0,0 +1,464 @@
+#![warn(rust_2018_idioms)]
+
+use tokio_util::codec::{AnyDelimiterCodec, BytesCodec, Decoder, Encoder, LinesCodec};
+
+use bytes::{BufMut, Bytes, BytesMut};
+
+#[test]
+fn bytes_decoder() {
+ let mut codec = BytesCodec::new();
+ let buf = &mut BytesMut::new();
+ buf.put_slice(b"abc");
+ assert_eq!("abc", codec.decode(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"a");
+ assert_eq!("a", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn bytes_encoder() {
+ let mut codec = BytesCodec::new();
+
+ // Default capacity of BytesMut
+ #[cfg(target_pointer_width = "64")]
+ const INLINE_CAP: usize = 4 * 8 - 1;
+ #[cfg(target_pointer_width = "32")]
+ const INLINE_CAP: usize = 4 * 4 - 1;
+
+ let mut buf = BytesMut::new();
+ codec
+ .encode(Bytes::from_static(&[0; INLINE_CAP + 1]), &mut buf)
+ .unwrap();
+
+    // Default capacity of FramedRead
+ const INITIAL_CAPACITY: usize = 8 * 1024;
+
+ let mut buf = BytesMut::with_capacity(INITIAL_CAPACITY);
+ codec
+ .encode(Bytes::from_static(&[0; INITIAL_CAPACITY + 1]), &mut buf)
+ .unwrap();
+ codec
+ .encode(BytesMut::from(&b"hello"[..]), &mut buf)
+ .unwrap();
+}
+
+#[test]
+fn lines_decoder() {
+ let mut codec = LinesCodec::new();
+ let buf = &mut BytesMut::new();
+ buf.reserve(200);
+ buf.put_slice(b"line 1\nline 2\r\nline 3\n\r\n\r");
+ assert_eq!("line 1", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("line 2", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("line 3", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("", codec.decode(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+ buf.put_slice(b"k");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+}
+
+#[test]
+fn lines_decoder_max_length() {
+ const MAX_LENGTH: usize = 6;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r");
+
+ assert!(codec.decode(buf).is_err());
+
+ let line = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ line.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ line,
+ MAX_LENGTH
+ );
+ assert_eq!("line 2", line);
+
+ assert!(codec.decode(buf).is_err());
+
+ let line = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ line.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ line,
+ MAX_LENGTH
+ );
+ assert_eq!("line 4", line);
+
+ let line = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ line.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ line,
+ MAX_LENGTH
+ );
+ assert_eq!("", line);
+
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+ buf.put_slice(b"k");
+ assert_eq!(None, codec.decode(buf).unwrap());
+
+ let line = codec.decode_eof(buf).unwrap().unwrap();
+ assert!(
+ line.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ line,
+ MAX_LENGTH
+ );
+ assert_eq!("\rk", line);
+
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+
+ // Line that's one character too long. This could cause an out of bounds
+ // error if we peek at the next characters using slice indexing.
+ buf.put_slice(b"aaabbbc");
+ assert!(codec.decode(buf).is_err());
+}
+
+#[test]
+fn lines_decoder_max_length_underrun() {
+ const MAX_LENGTH: usize = 6;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"line ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too l");
+ assert!(codec.decode(buf).is_err());
+ buf.put_slice(b"ong\n");
+ assert_eq!(None, codec.decode(buf).unwrap());
+
+ buf.put_slice(b"line 2");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"\n");
+ assert_eq!("line 2", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn lines_decoder_max_length_bursts() {
+ const MAX_LENGTH: usize = 10;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"line ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too l");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"ong\n");
+ assert!(codec.decode(buf).is_err());
+}
+
+#[test]
+fn lines_decoder_max_length_big_burst() {
+ const MAX_LENGTH: usize = 10;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"line ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too long!\n");
+ assert!(codec.decode(buf).is_err());
+}
+
+#[test]
+fn lines_decoder_max_length_newline_between_decodes() {
+ const MAX_LENGTH: usize = 5;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"hello");
+ assert_eq!(None, codec.decode(buf).unwrap());
+
+ buf.put_slice(b"\nworld");
+ assert_eq!("hello", codec.decode(buf).unwrap().unwrap());
+}
+
+// Regression test for [infinite loop bug](https://github.com/tokio-rs/tokio/issues/1483)
+#[test]
+fn lines_decoder_discard_repeat() {
+ const MAX_LENGTH: usize = 1;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"aa");
+ assert!(codec.decode(buf).is_err());
+ buf.put_slice(b"a");
+ assert_eq!(None, codec.decode(buf).unwrap());
+}
+
+// Regression test for [subsequent calls to LinesCodec decode does not return the desired results bug](https://github.com/tokio-rs/tokio/issues/3555)
+#[test]
+fn lines_decoder_max_length_underrun_twice() {
+ const MAX_LENGTH: usize = 11;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"line ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too very l");
+ assert!(codec.decode(buf).is_err());
+ buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"ong\nshort\n");
+ assert_eq!("short", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn lines_encoder() {
+ let mut codec = LinesCodec::new();
+ let mut buf = BytesMut::new();
+
+ codec.encode("line 1", &mut buf).unwrap();
+ assert_eq!("line 1\n", buf);
+
+ codec.encode("line 2", &mut buf).unwrap();
+ assert_eq!("line 1\nline 2\n", buf);
+}
+
+#[test]
+fn any_delimiters_decoder_any_character() {
+ let mut codec = AnyDelimiterCodec::new(b",;\n\r".to_vec(), b",".to_vec());
+ let buf = &mut BytesMut::new();
+ buf.reserve(200);
+ buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r");
+ assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("", codec.decode(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+ buf.put_slice(b"k");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!("k", codec.decode_eof(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+}
+
+#[test]
+fn any_delimiters_decoder_max_length() {
+ const MAX_LENGTH: usize = 7;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"chunk 1 is too long\nchunk 2\nchunk 3\r\nchunk 4\n\r\n");
+
+ assert!(codec.decode(buf).is_err());
+
+ let chunk = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("chunk 2", chunk);
+
+ let chunk = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("chunk 3", chunk);
+
+    // The consecutive `\r` and `\n` delimiters produce an empty chunk
+ let chunk = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("", chunk);
+
+ let chunk = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("chunk 4", chunk);
+
+ let chunk = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("", chunk);
+
+ let chunk = codec.decode(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("", chunk);
+
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+ buf.put_slice(b"k");
+ assert_eq!(None, codec.decode(buf).unwrap());
+
+ let chunk = codec.decode_eof(buf).unwrap().unwrap();
+ assert!(
+ chunk.len() <= MAX_LENGTH,
+ "{:?}.len() <= {:?}",
+ chunk,
+ MAX_LENGTH
+ );
+ assert_eq!("k", chunk);
+
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+
+ // Delimiter that's one character too long. This could cause an out of bounds
+ // error if we peek at the next characters using slice indexing.
+ buf.put_slice(b"aaabbbcc");
+ assert!(codec.decode(buf).is_err());
+}
+
+#[test]
+fn any_delimiter_decoder_max_length_underrun() {
+ const MAX_LENGTH: usize = 7;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"chunk ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too l");
+ assert!(codec.decode(buf).is_err());
+ buf.put_slice(b"ong\n");
+ assert_eq!(None, codec.decode(buf).unwrap());
+
+ buf.put_slice(b"chunk 2");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b",");
+ assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn any_delimiter_decoder_max_length_underrun_twice() {
+ const MAX_LENGTH: usize = 11;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"chunk ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too very l");
+ assert!(codec.decode(buf).is_err());
+ buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"ong\nshort\n");
+ assert_eq!("short", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn any_delimiter_decoder_max_length_bursts() {
+ const MAX_LENGTH: usize = 11;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"chunk ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too l");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"ong\n");
+ assert!(codec.decode(buf).is_err());
+}
+
+#[test]
+fn any_delimiter_decoder_max_length_big_burst() {
+ const MAX_LENGTH: usize = 11;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"chunk ");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"too long!\n");
+ assert!(codec.decode(buf).is_err());
+}
+
+#[test]
+fn any_delimiter_decoder_max_length_delimiter_between_decodes() {
+ const MAX_LENGTH: usize = 5;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"hello");
+ assert_eq!(None, codec.decode(buf).unwrap());
+
+ buf.put_slice(b",world");
+ assert_eq!("hello", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn any_delimiter_decoder_discard_repeat() {
+ const MAX_LENGTH: usize = 1;
+
+ let mut codec =
+ AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"aa");
+ assert!(codec.decode(buf).is_err());
+ buf.put_slice(b"a");
+ assert_eq!(None, codec.decode(buf).unwrap());
+}
+
+#[test]
+fn any_delimiter_encoder() {
+ let mut codec = AnyDelimiterCodec::new(b",".to_vec(), b";--;".to_vec());
+ let mut buf = BytesMut::new();
+
+ codec.encode("chunk 1", &mut buf).unwrap();
+ assert_eq!("chunk 1;--;", buf);
+
+ codec.encode("chunk 2", &mut buf).unwrap();
+ assert_eq!("chunk 1;--;chunk 2;--;", buf);
+}
diff --git a/tests/context.rs b/tests/context.rs
new file mode 100644
index 0000000..2ec8258
--- /dev/null
+++ b/tests/context.rs
@@ -0,0 +1,25 @@
+#![cfg(feature = "rt")]
+#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
+#![warn(rust_2018_idioms)]
+
+use tokio::runtime::Builder;
+use tokio::time::*;
+use tokio_util::context::RuntimeExt;
+
+#[test]
+fn tokio_context_with_another_runtime() {
+ let rt1 = Builder::new_multi_thread()
+ .worker_threads(1)
+ // no timer!
+ .build()
+ .unwrap();
+ let rt2 = Builder::new_multi_thread()
+ .worker_threads(1)
+ .enable_all()
+ .build()
+ .unwrap();
+
+    // Without the `RuntimeExt::wrap()` call there would be a panic, because
+    // the sleep would reference runtime `rt1`, which has no timer running.
+ rt1.block_on(rt2.wrap(async move { sleep(Duration::from_millis(2)).await }));
+}
diff --git a/tests/framed.rs b/tests/framed.rs
new file mode 100644
index 0000000..ec8cdf0
--- /dev/null
+++ b/tests/framed.rs
@@ -0,0 +1,152 @@
+#![warn(rust_2018_idioms)]
+
+use tokio_stream::StreamExt;
+use tokio_test::assert_ok;
+use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
+
+use bytes::{Buf, BufMut, BytesMut};
+use std::io::{self, Read};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+const INITIAL_CAPACITY: usize = 8 * 1024;
+
+/// Encode and decode u32 values.
+#[derive(Default)]
+struct U32Codec {
+ read_bytes: usize,
+}
+
+impl Decoder for U32Codec {
+ type Item = u32;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> {
+ if buf.len() < 4 {
+ return Ok(None);
+ }
+
+ let n = buf.split_to(4).get_u32();
+ self.read_bytes += 4;
+ Ok(Some(n))
+ }
+}
+
+impl Encoder<u32> for U32Codec {
+ type Error = io::Error;
+
+ fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
+ // Reserve space
+ dst.reserve(4);
+ dst.put_u32(item);
+ Ok(())
+ }
+}
+
+/// Encode and decode u64 values.
+#[derive(Default)]
+struct U64Codec {
+ read_bytes: usize,
+}
+
+impl Decoder for U64Codec {
+ type Item = u64;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
+ if buf.len() < 8 {
+ return Ok(None);
+ }
+
+ let n = buf.split_to(8).get_u64();
+ self.read_bytes += 8;
+ Ok(Some(n))
+ }
+}
+
+impl Encoder<u64> for U64Codec {
+ type Error = io::Error;
+
+ fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
+ // Reserve space
+ dst.reserve(8);
+ dst.put_u64(item);
+ Ok(())
+ }
+}
+
+/// This value should never be used
+struct DontReadIntoThis;
+
+impl Read for DontReadIntoThis {
+ fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ "Read into something you weren't supposed to.",
+ ))
+ }
+}
+
+impl tokio::io::AsyncRead for DontReadIntoThis {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ _buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ unreachable!()
+ }
+}
+
+#[tokio::test]
+async fn can_read_from_existing_buf() {
+ let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
+ parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);
+
+ let mut framed = Framed::from_parts(parts);
+ let num = assert_ok!(framed.next().await.unwrap());
+
+ assert_eq!(num, 42);
+ assert_eq!(framed.codec().read_bytes, 4);
+}
+
+#[tokio::test]
+async fn can_read_from_existing_buf_after_codec_changed() {
+ let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
+ parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]);
+
+ let mut framed = Framed::from_parts(parts);
+ let num = assert_ok!(framed.next().await.unwrap());
+
+ assert_eq!(num, 42);
+ assert_eq!(framed.codec().read_bytes, 4);
+
+ let mut framed = framed.map_codec(|codec| U64Codec {
+ read_bytes: codec.read_bytes,
+ });
+ let num = assert_ok!(framed.next().await.unwrap());
+
+ assert_eq!(num, 84);
+ assert_eq!(framed.codec().read_bytes, 12);
+}
+
+#[test]
+fn external_buf_grows_to_init() {
+ let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
+ parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]);
+
+ let framed = Framed::from_parts(parts);
+ let FramedParts { read_buf, .. } = framed.into_parts();
+
+ assert_eq!(read_buf.capacity(), INITIAL_CAPACITY);
+}
+
+#[test]
+fn external_buf_does_not_shrink() {
+ let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default());
+ parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]);
+
+ let framed = Framed::from_parts(parts);
+ let FramedParts { read_buf, .. } = framed.into_parts();
+
+ assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2);
+}
diff --git a/tests/framed_read.rs b/tests/framed_read.rs
new file mode 100644
index 0000000..2a9e27e
--- /dev/null
+++ b/tests/framed_read.rs
@@ -0,0 +1,339 @@
+#![warn(rust_2018_idioms)]
+
+use tokio::io::{AsyncRead, ReadBuf};
+use tokio_test::assert_ready;
+use tokio_test::task;
+use tokio_util::codec::{Decoder, FramedRead};
+
+use bytes::{Buf, BytesMut};
+use futures::Stream;
+use std::collections::VecDeque;
+use std::io;
+use std::pin::Pin;
+use std::task::Poll::{Pending, Ready};
+use std::task::{Context, Poll};
+
+macro_rules! mock {
+ ($($x:expr,)*) => {{
+ let mut v = VecDeque::new();
+ v.extend(vec![$($x),*]);
+ Mock { calls: v }
+ }};
+}
+
+macro_rules! assert_read {
+ ($e:expr, $n:expr) => {{
+ let val = assert_ready!($e);
+ assert_eq!(val.unwrap().unwrap(), $n);
+ }};
+}
+
+macro_rules! pin {
+ ($id:ident) => {
+ Pin::new(&mut $id)
+ };
+}
+
+struct U32Decoder;
+
+impl Decoder for U32Decoder {
+ type Item = u32;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> {
+ if buf.len() < 4 {
+ return Ok(None);
+ }
+
+ let n = buf.split_to(4).get_u32();
+ Ok(Some(n))
+ }
+}
+
+struct U64Decoder;
+
+impl Decoder for U64Decoder {
+ type Item = u64;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u64>> {
+ if buf.len() < 8 {
+ return Ok(None);
+ }
+
+ let n = buf.split_to(8).get_u64();
+ Ok(Some(n))
+ }
+}
+
+#[test]
+fn read_multi_frame_in_packet() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert_read!(pin!(framed).poll_next(cx), 1);
+ assert_read!(pin!(framed).poll_next(cx), 2);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+#[test]
+fn read_multi_frame_across_packets() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00\x00\x00".to_vec()),
+ Ok(b"\x00\x00\x00\x01".to_vec()),
+ Ok(b"\x00\x00\x00\x02".to_vec()),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert_read!(pin!(framed).poll_next(cx), 1);
+ assert_read!(pin!(framed).poll_next(cx), 2);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+#[test]
+fn read_multi_frame_in_packet_after_codec_changed() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 0x04);
+
+ let mut framed = framed.map_decoder(|_| U64Decoder);
+ assert_read!(pin!(framed).poll_next(cx), 0x08);
+
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+#[test]
+fn read_not_ready() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Err(io::Error::new(io::ErrorKind::WouldBlock, "")),
+ Ok(b"\x00\x00\x00\x00".to_vec()),
+ Ok(b"\x00\x00\x00\x01".to_vec()),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert!(pin!(framed).poll_next(cx).is_pending());
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert_read!(pin!(framed).poll_next(cx), 1);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+#[test]
+fn read_partial_then_not_ready() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00".to_vec()),
+ Err(io::Error::new(io::ErrorKind::WouldBlock, "")),
+ Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert!(pin!(framed).poll_next(cx).is_pending());
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert_read!(pin!(framed).poll_next(cx), 1);
+ assert_read!(pin!(framed).poll_next(cx), 2);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+#[test]
+fn read_err() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Err(io::Error::new(io::ErrorKind::Other, "")),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_eq!(
+ io::ErrorKind::Other,
+ assert_ready!(pin!(framed).poll_next(cx))
+ .unwrap()
+ .unwrap_err()
+ .kind()
+ )
+ });
+}
+
+#[test]
+fn read_partial_then_err() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00".to_vec()),
+ Err(io::Error::new(io::ErrorKind::Other, "")),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_eq!(
+ io::ErrorKind::Other,
+ assert_ready!(pin!(framed).poll_next(cx))
+ .unwrap()
+ .unwrap_err()
+ .kind()
+ )
+ });
+}
+
+#[test]
+fn read_partial_would_block_then_err() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00".to_vec()),
+ Err(io::Error::new(io::ErrorKind::WouldBlock, "")),
+ Err(io::Error::new(io::ErrorKind::Other, "")),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert!(pin!(framed).poll_next(cx).is_pending());
+ assert_eq!(
+ io::ErrorKind::Other,
+ assert_ready!(pin!(framed).poll_next(cx))
+ .unwrap()
+ .unwrap_err()
+ .kind()
+ )
+ });
+}
+
+#[test]
+fn huge_size() {
+ let mut task = task::spawn(());
+ let data = &[0; 32 * 1024][..];
+ let mut framed = FramedRead::new(data, BigDecoder);
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+
+ struct BigDecoder;
+
+ impl Decoder for BigDecoder {
+ type Item = u32;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u32>> {
+ if buf.len() < 32 * 1024 {
+ return Ok(None);
+ }
+ buf.advance(32 * 1024);
+ Ok(Some(0))
+ }
+ }
+}
+
+#[test]
+fn data_remaining_is_error() {
+ let mut task = task::spawn(());
+ let slice = &[0; 5][..];
+ let mut framed = FramedRead::new(slice, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err());
+ });
+}
+
+#[test]
+fn multi_frames_on_eof() {
+ let mut task = task::spawn(());
+ struct MyDecoder(Vec<u32>);
+
+ impl Decoder for MyDecoder {
+ type Item = u32;
+ type Error = io::Error;
+
+ fn decode(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> {
+ unreachable!();
+ }
+
+ fn decode_eof(&mut self, _buf: &mut BytesMut) -> io::Result<Option<u32>> {
+ if self.0.is_empty() {
+ return Ok(None);
+ }
+
+ Ok(Some(self.0.remove(0)))
+ }
+ }
+
+ let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3]));
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 0);
+ assert_read!(pin!(framed).poll_next(cx), 1);
+ assert_read!(pin!(framed).poll_next(cx), 2);
+ assert_read!(pin!(framed).poll_next(cx), 3);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+#[test]
+fn read_eof_then_resume() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00\x00\x01".to_vec()),
+ Ok(b"".to_vec()),
+ Ok(b"\x00\x00\x00\x02".to_vec()),
+ Ok(b"".to_vec()),
+ Ok(b"\x00\x00\x00\x03".to_vec()),
+ };
+ let mut framed = FramedRead::new(mock, U32Decoder);
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(framed).poll_next(cx), 1);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ assert_read!(pin!(framed).poll_next(cx), 2);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ assert_read!(pin!(framed).poll_next(cx), 3);
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+ });
+}
+
+// ===== Mock ======
+
+struct Mock {
+ calls: VecDeque<io::Result<Vec<u8>>>,
+}
+
+impl AsyncRead for Mock {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ use io::ErrorKind::WouldBlock;
+
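+ // Each queued `Ok` yields its bytes immediately; a queued `WouldBlock`
+ // error is surfaced as `Pending`, emulating an I/O source that is not
+ // ready yet; an exhausted queue reads as EOF.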
+ match self.calls.pop_front() {
+ Some(Ok(data)) => {
+ debug_assert!(buf.remaining() >= data.len());
+ buf.put_slice(&data);
+ Ready(Ok(()))
+ }
+ Some(Err(ref e)) if e.kind() == WouldBlock => Pending,
+ Some(Err(e)) => Ready(Err(e)),
+ None => Ready(Ok(())),
+ }
+ }
+}
diff --git a/tests/framed_stream.rs b/tests/framed_stream.rs
new file mode 100644
index 0000000..76d8af7
--- /dev/null
+++ b/tests/framed_stream.rs
@@ -0,0 +1,38 @@
+use futures_core::stream::Stream;
+use std::{io, pin::Pin};
+use tokio_test::{assert_ready, io::Builder, task};
+use tokio_util::codec::{BytesCodec, FramedRead};
+
+macro_rules! pin {
+ ($id:ident) => {
+ Pin::new(&mut $id)
+ };
+}
+
+macro_rules! assert_read {
+ ($e:expr, $n:expr) => {{
+ let val = assert_ready!($e);
+ assert_eq!(val.unwrap().unwrap(), $n);
+ }};
+}
+
+#[tokio::test]
+async fn return_none_after_error() {
+ let mut io = FramedRead::new(
+ Builder::new()
+ .read(b"abcdef")
+ .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out"))
+ .read(b"more data")
+ .build(),
+ BytesCodec::new(),
+ );
+
+ let mut task = task::spawn(());
+
+ task.enter(|cx, _| {
+ assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec());
+ assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err());
+ assert!(assert_ready!(pin!(io).poll_next(cx)).is_none());
+ assert_read!(pin!(io).poll_next(cx), b"more data".to_vec());
+ })
+}
diff --git a/tests/framed_write.rs b/tests/framed_write.rs
new file mode 100644
index 0000000..39091c0
--- /dev/null
+++ b/tests/framed_write.rs
@@ -0,0 +1,212 @@
+#![warn(rust_2018_idioms)]
+
+use tokio::io::AsyncWrite;
+use tokio_test::{assert_ready, task};
+use tokio_util::codec::{Encoder, FramedWrite};
+
+use bytes::{BufMut, BytesMut};
+use futures_sink::Sink;
+use std::collections::VecDeque;
+use std::io::{self, Write};
+use std::pin::Pin;
+use std::task::Poll::{Pending, Ready};
+use std::task::{Context, Poll};
+
+macro_rules! mock {
+ ($($x:expr,)*) => {{
+ let mut v = VecDeque::new();
+ v.extend(vec![$($x),*]);
+ Mock { calls: v }
+ }};
+}
+
+macro_rules! pin {
+ ($id:ident) => {
+ Pin::new(&mut $id)
+ };
+}
+
+struct U32Encoder;
+
+impl Encoder<u32> for U32Encoder {
+ type Error = io::Error;
+
+ fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
+ // Reserve space
+ dst.reserve(4);
+ dst.put_u32(item);
+ Ok(())
+ }
+}
+
+struct U64Encoder;
+
+impl Encoder<u64> for U64Encoder {
+ type Error = io::Error;
+
+ fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
+ // Reserve space
+ dst.reserve(8);
+ dst.put_u64(item);
+ Ok(())
+ }
+}
+
+#[test]
+fn write_multi_frame_in_packet() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()),
+ };
+ let mut framed = FramedWrite::new(mock, U32Encoder);
+
+ task.enter(|cx, _| {
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+ assert!(pin!(framed).start_send(0).is_ok());
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+ assert!(pin!(framed).start_send(1).is_ok());
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+ assert!(pin!(framed).start_send(2).is_ok());
+
+ // Nothing written yet
+ assert_eq!(1, framed.get_ref().calls.len());
+
+ // Flush the writes
+ assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());
+
+ assert_eq!(0, framed.get_ref().calls.len());
+ });
+}
+
+#[test]
+fn write_multi_frame_after_codec_changed() {
+ let mut task = task::spawn(());
+ let mock = mock! {
+ Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
+ };
+ let mut framed = FramedWrite::new(mock, U32Encoder);
+
+ task.enter(|cx, _| {
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+ assert!(pin!(framed).start_send(0x04).is_ok());
+
+ let mut framed = framed.map_encoder(|_| U64Encoder);
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+ assert!(pin!(framed).start_send(0x08).is_ok());
+
+ // Nothing written yet
+ assert_eq!(1, framed.get_ref().calls.len());
+
+ // Flush the writes
+ assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());
+
+ assert_eq!(0, framed.get_ref().calls.len());
+ });
+}
+
+#[test]
+fn write_hits_backpressure() {
+ const ITER: usize = 2 * 1024;
+
+ let mut mock = mock! {
+ // Block the first flush attempt; everything buffered before it must wait
+ Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")),
+ Ok(b"".to_vec()),
+ };
+
+ for i in 0..=ITER * 2 {
+ let mut b = BytesMut::with_capacity(4);
+ b.put_u32(i as u32);
+
+ // Append to the end
+ match mock.calls.back_mut().unwrap() {
+ Ok(ref mut data) => {
+ // Write in 2kb chunks
+ if data.len() < ITER {
+ data.extend_from_slice(&b[..]);
+ continue;
+ } // else fall through and create a new buffer
+ }
+ _ => unreachable!(),
+ }
+
+ // Push a new chunk
+ mock.calls.push_back(Ok(b[..].to_vec()));
+ }
+ // 1 WouldBlock, 8 full 2 KiB chunks, and 1 final 4-byte chunk
+ assert_eq!(mock.calls.len(), 10);
+
+ let mut task = task::spawn(());
+ let mut framed = FramedWrite::new(mock, U32Encoder);
+ framed.set_backpressure_boundary(ITER * 8);
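+ // The boundary equals ITER * 8 = 16 KiB: exactly the number of bytes
+ // buffered by the 2 * ITER sends of 4-byte frames below, so the boundary is
+ // reached on the final send and the next poll_ready is forced to flush.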
+ task.enter(|cx, _| {
+ // Send 16 KiB (2 * ITER frames of 4 bytes each). This fills the FramedWrite buffer
+ for i in 0..ITER * 2 {
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+ assert!(pin!(framed).start_send(i as u32).is_ok());
+ }
+
+ // Now we poll_ready which forces a flush. The mock pops the front message
+ // and decides to block.
+ assert!(pin!(framed).poll_ready(cx).is_pending());
+
+ // We poll again, forcing another flush, which this time succeeds
+ // The whole 16KB buffer is flushed
+ assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
+
+ // Send more data. This matches the final message expected by the mock
+ assert!(pin!(framed).start_send((ITER * 2) as u32).is_ok());
+
+ // Flush the rest of the buffer
+ assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());
+
+ // Ensure the mock is empty
+ assert_eq!(0, framed.get_ref().calls.len());
+ })
+}
+
+// ===== Mock ======
+
+struct Mock {
+ calls: VecDeque<io::Result<Vec<u8>>>,
+}
+
+impl Write for Mock {
+ fn write(&mut self, src: &[u8]) -> io::Result<usize> {
+ match self.calls.pop_front() {
+ Some(Ok(data)) => {
+ assert!(src.len() >= data.len());
+ assert_eq!(&data[..], &src[..data.len()]);
+ Ok(data.len())
+ }
+ Some(Err(e)) => Err(e),
+ None => panic!("unexpected write; {:?}", src),
+ }
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ Ok(())
+ }
+}
+
+impl AsyncWrite for Mock {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ match Pin::get_mut(self).write(buf) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
+ other => Ready(other),
+ }
+ }
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ match Pin::get_mut(self).flush() {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
+ other => Ready(other),
+ }
+ }
+ fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ unimplemented!()
+ }
+}
diff --git a/tests/io_inspect.rs b/tests/io_inspect.rs
new file mode 100644
index 0000000..e6319af
--- /dev/null
+++ b/tests/io_inspect.rs
@@ -0,0 +1,194 @@
+use futures::future::poll_fn;
+use std::{
+ io::IoSlice,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
+use tokio_util::io::{InspectReader, InspectWriter};
+
+/// An AsyncRead implementation that works byte-by-byte, to catch callers
+/// that don't allow for `buf` already being partially filled before the call.
+struct SmallReader {
+ contents: Vec<u8>,
+}
+
+impl Unpin for SmallReader {}
+
+impl AsyncRead for SmallReader {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ if let Some(byte) = self.contents.pop() {
+ buf.put_slice(&[byte])
+ }
+ Poll::Ready(Ok(()))
+ }
+}
+
+#[tokio::test]
+async fn read_tee() {
+ let contents = b"This could be really long, you know".to_vec();
+ let reader = SmallReader {
+ contents: contents.clone(),
+ };
+ let mut altout: Vec<u8> = Vec::new();
+ let mut teeout = Vec::new();
+ {
+ let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
+ tee.read_to_end(&mut teeout).await.unwrap();
+ }
+ assert_eq!(teeout, altout);
+ assert_eq!(altout.len(), contents.len());
+}
+
+/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
+/// that writes the whole of the first buffer plus one byte from the second in
+/// poll_write_vectored.
+///
+/// This is designed to catch bugs in handling partially written buffers
+#[derive(Debug)]
+struct SmallWriter {
+ contents: Vec<u8>,
+}
+
+impl Unpin for SmallWriter {}
+
+impl AsyncWrite for SmallWriter {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ // Just write one byte at a time
+ if buf.is_empty() {
+ return Poll::Ready(Ok(0));
+ }
+ self.contents.push(buf[0]);
+ Poll::Ready(Ok(1))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_write_vectored(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ bufs: &[IoSlice<'_>],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ // Write all of the first buffer, then one byte from the second buffer
+ // This should trip up anything that doesn't correctly handle multiple
+ // buffers.
+ if bufs.is_empty() {
+ return Poll::Ready(Ok(0));
+ }
+ let mut written_len = bufs[0].len();
+ self.contents.extend_from_slice(&bufs[0]);
+
+ if bufs.len() > 1 {
+ let buf = bufs[1];
+ if !buf.is_empty() {
+ written_len += 1;
+ self.contents.push(buf[0]);
+ }
+ }
+ Poll::Ready(Ok(written_len))
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ true
+ }
+}
+
+#[tokio::test]
+async fn write_tee() {
+ let mut altout: Vec<u8> = Vec::new();
+ let mut writeout = SmallWriter {
+ contents: Vec::new(),
+ };
+ {
+ let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
+ tee.write_all(b"A testing string, very testing")
+ .await
+ .unwrap();
+ }
+ assert_eq!(altout, writeout.contents);
+}
+
+// This is inefficient, but works well enough for test use.
+// If you want something similar for real code, you'll want to avoid all the
+// fun of manipulating `bufs` - ideally, by the time you read this,
+// IoSlice::advance_slices will be stable, and you can use that.
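+// For example: if a single vectored write reports 7 bytes written against
+// buffers of lengths [5, 4, 3], the loop below removes the first buffer
+// entirely and drains 2 bytes from the front of the second.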
+async fn write_all_vectored<W: AsyncWrite + Unpin>(
+ mut writer: W,
+ mut bufs: Vec<Vec<u8>>,
+) -> Result<usize, std::io::Error> {
+ let mut res = 0;
+ while !bufs.is_empty() {
+ let mut written = poll_fn(|cx| {
+ let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
+ Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
+ })
+ .await?;
+ res += written;
+ while written > 0 {
+ let buf_len = bufs[0].len();
+ if buf_len <= written {
+ bufs.remove(0);
+ written -= buf_len;
+ } else {
+ let buf = &mut bufs[0];
+ let drain_len = written.min(buf.len());
+ buf.drain(..drain_len);
+ written -= drain_len;
+ }
+ }
+ }
+ Ok(res)
+}
+
+#[tokio::test]
+async fn write_tee_vectored() {
+ let mut altout: Vec<u8> = Vec::new();
+ let mut writeout = SmallWriter {
+ contents: Vec::new(),
+ };
+ let original = b"A very long string split up";
+ let bufs: Vec<Vec<u8>> = original
+ .split(|b| b.is_ascii_whitespace())
+ .map(Vec::from)
+ .collect();
+ assert!(bufs.len() > 1);
+ let expected: Vec<u8> = {
+ let mut out = Vec::new();
+ for item in &bufs {
+ out.extend_from_slice(item)
+ }
+ out
+ };
+ {
+ let mut bufcount = 0;
+ let tee = InspectWriter::new(&mut writeout, |bytes| {
+ bufcount += 1;
+ altout.extend(bytes)
+ });
+
+ assert!(tee.is_write_vectored());
+
+ write_all_vectored(tee, bufs.clone()).await.unwrap();
+
+ assert!(bufcount >= bufs.len());
+ }
+ assert_eq!(altout, writeout.contents);
+ assert_eq!(writeout.contents, expected);
+}
diff --git a/tests/io_reader_stream.rs b/tests/io_reader_stream.rs
new file mode 100644
index 0000000..e30cd85
--- /dev/null
+++ b/tests/io_reader_stream.rs
@@ -0,0 +1,65 @@
+#![warn(rust_2018_idioms)]
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, ReadBuf};
+use tokio_stream::StreamExt;
+
+/// Produces at most `remaining` zeros, then returns an error.
+/// Each read fills at most 31 bytes.
+struct Reader {
+ remaining: usize,
+}
+
+impl AsyncRead for Reader {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ let this = Pin::into_inner(self);
+ assert_ne!(buf.remaining(), 0);
+ if this.remaining > 0 {
+ let n = std::cmp::min(this.remaining, buf.remaining());
+ let n = std::cmp::min(n, 31);
+ for x in &mut buf.initialize_unfilled_to(n)[..n] {
+ *x = 0;
+ }
+ buf.advance(n);
+ this.remaining -= n;
+ Poll::Ready(Ok(()))
+ } else {
+ Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
+ }
+ }
+}
+
+#[tokio::test]
+async fn correct_behavior_on_errors() {
+ let reader = Reader { remaining: 8000 };
+ let mut stream = tokio_util::io::ReaderStream::new(reader);
+ let mut zeros_received = 0;
+ let mut had_error = false;
+ loop {
+ let item = stream.next().await.unwrap();
+ println!("{:?}", item);
+ match item {
+ Ok(bytes) => {
+ let bytes = &*bytes;
+ for byte in bytes {
+ assert_eq!(*byte, 0);
+ zeros_received += 1;
+ }
+ }
+ Err(_) => {
+ assert!(!had_error);
+ had_error = true;
+ break;
+ }
+ }
+ }
+
+ assert!(had_error);
+ assert_eq!(zeros_received, 8000);
+ assert!(stream.next().await.is_none());
+}
diff --git a/tests/io_sink_writer.rs b/tests/io_sink_writer.rs
new file mode 100644
index 0000000..e76870b
--- /dev/null
+++ b/tests/io_sink_writer.rs
@@ -0,0 +1,72 @@
+#![warn(rust_2018_idioms)]
+
+use bytes::Bytes;
+use futures_util::SinkExt;
+use std::io::{self, Error, ErrorKind};
+use tokio::io::AsyncWriteExt;
+use tokio_util::codec::{Encoder, FramedWrite};
+use tokio_util::io::{CopyToBytes, SinkWriter};
+use tokio_util::sync::PollSender;
+
+#[tokio::test]
+async fn test_copied_sink_writer() -> Result<(), Error> {
+ // Construct a channel pair to send data across and wrap a pollable sink.
+ // Note that the sink must mimic a writable object, e.g. have `std::io::Error`
+ // as its error type.
+ // Because `PollSender` requires an owned copy of the buffer, we additionally
+ // wrap it in a `CopyToBytes` helper.
+ let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
+ let mut writer = SinkWriter::new(CopyToBytes::new(
+ PollSender::new(tx).sink_map_err(|_| io::Error::from(ErrorKind::BrokenPipe)),
+ ));
+
+ // Write data to our interface...
+ let data: [u8; 4] = [1, 2, 3, 4];
+ let _ = writer.write(&data).await;
+
+ // ... and receive it.
+ assert_eq!(data.to_vec(), rx.recv().await.unwrap().to_vec());
+
+ Ok(())
+}
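+
+// The layering above: `PollSender` is a `Sink<Bytes>`, `CopyToBytes` adapts
+// it to a sink of byte slices by copying each slice into an owned `Bytes`,
+// and `SinkWriter` exposes that sink as an `AsyncWrite`.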
+
+/// A trivial encoder.
+struct SliceEncoder;
+
+impl SliceEncoder {
+ fn new() -> Self {
+ Self {}
+ }
+}
+
+impl<'a> Encoder<&'a [u8]> for SliceEncoder {
+ type Error = Error;
+
+ fn encode(&mut self, item: &'a [u8], dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
+ // This is where we'd write packet headers, lengths, etc. in a real encoder.
+ // For simplicity and demonstration purposes, we just append a copy of
+ // the slice to the end of the buffer.
+ dst.extend_from_slice(item);
+ Ok(())
+ }
+}
+
+#[tokio::test]
+async fn test_direct_sink_writer() -> Result<(), Error> {
+ // We define a framed writer which accepts byte slices, then immediately
+ // 'reverse' that construction by wrapping the sink back into a writer.
+ let framed_byte_lc = FramedWrite::new(Vec::new(), SliceEncoder::new());
+ let mut writer = SinkWriter::new(framed_byte_lc);
+
+ // Write multiple slices to the sink...
+ let _ = writer.write(&[1, 2, 3]).await;
+ let _ = writer.write(&[4, 5, 6]).await;
+
+ // ... and compare it with the buffer.
+ assert_eq!(
+ writer.into_inner().write_buffer().to_vec().as_slice(),
+ &[1, 2, 3, 4, 5, 6]
+ );
+
+ Ok(())
+}
diff --git a/tests/io_stream_reader.rs b/tests/io_stream_reader.rs
new file mode 100644
index 0000000..5975994
--- /dev/null
+++ b/tests/io_stream_reader.rs
@@ -0,0 +1,35 @@
+#![warn(rust_2018_idioms)]
+
+use bytes::Bytes;
+use tokio::io::AsyncReadExt;
+use tokio_stream::iter;
+use tokio_util::io::StreamReader;
+
+#[tokio::test]
+async fn test_stream_reader() -> std::io::Result<()> {
+ let stream = iter(vec![
+ std::io::Result::Ok(Bytes::from_static(&[])),
+ Ok(Bytes::from_static(&[0, 1, 2, 3])),
+ Ok(Bytes::from_static(&[])),
+ Ok(Bytes::from_static(&[4, 5, 6, 7])),
+ Ok(Bytes::from_static(&[])),
+ Ok(Bytes::from_static(&[8, 9, 10, 11])),
+ Ok(Bytes::from_static(&[])),
+ ]);
+
+ let mut read = StreamReader::new(stream);
+
+ let mut buf = [0; 5];
+ read.read_exact(&mut buf).await?;
+ assert_eq!(buf, [0, 1, 2, 3, 4]);
+
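+ // A single `read` returns data from at most one chunk of the underlying
+ // stream, so this yields only the 3 bytes left over from the second chunk,
+ // and the following read yields the whole third chunk.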
+ assert_eq!(read.read(&mut buf).await?, 3);
+ assert_eq!(&buf[..3], [5, 6, 7]);
+
+ assert_eq!(read.read(&mut buf).await?, 4);
+ assert_eq!(&buf[..4], [8, 9, 10, 11]);
+
+ assert_eq!(read.read(&mut buf).await?, 0);
+
+ Ok(())
+}
diff --git a/tests/io_sync_bridge.rs b/tests/io_sync_bridge.rs
new file mode 100644
index 0000000..76bbd0b
--- /dev/null
+++ b/tests/io_sync_bridge.rs
@@ -0,0 +1,62 @@
+#![cfg(feature = "io-util")]
+#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
+
+use std::error::Error;
+use std::io::{Cursor, Read, Result as IoResult, Write};
+use tokio::io::{AsyncRead, AsyncReadExt};
+use tokio_util::io::SyncIoBridge;
+
+async fn test_reader_len(
+ r: impl AsyncRead + Unpin + Send + 'static,
+ expected_len: usize,
+) -> IoResult<()> {
+ let mut r = SyncIoBridge::new(r);
+ let res = tokio::task::spawn_blocking(move || {
+ let mut buf = Vec::new();
+ r.read_to_end(&mut buf)?;
+ Ok::<_, std::io::Error>(buf)
+ })
+ .await?;
+ assert_eq!(res?.len(), expected_len);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_async_read_to_sync() -> Result<(), Box<dyn Error>> {
+ test_reader_len(tokio::io::empty(), 0).await?;
+ let buf = b"hello world";
+ test_reader_len(Cursor::new(buf), buf.len()).await?;
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_async_write_to_sync() -> Result<(), Box<dyn Error>> {
+ let mut dest = Vec::new();
+ let src = b"hello world";
+ let dest = tokio::task::spawn_blocking(move || -> Result<_, String> {
+ let mut w = SyncIoBridge::new(Cursor::new(&mut dest));
+ std::io::copy(&mut Cursor::new(src), &mut w).map_err(|e| e.to_string())?;
+ Ok(dest)
+ })
+ .await??;
+ assert_eq!(dest.as_slice(), src);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_shutdown() -> Result<(), Box<dyn Error>> {
+ let (s1, mut s2) = tokio::io::duplex(1024);
+ let (_rh, wh) = tokio::io::split(s1);
+ tokio::task::spawn_blocking(move || -> std::io::Result<_> {
+ let mut wh = SyncIoBridge::new(wh);
+ wh.write_all(b"hello")?;
+ wh.shutdown()?;
+ assert!(wh.write_all(b" world").is_err());
+ Ok(())
+ })
+ .await??;
+ let mut buf = vec![];
+ s2.read_to_end(&mut buf).await?;
+ assert_eq!(buf, b"hello");
+ Ok(())
+}
diff --git a/tests/length_delimited.rs b/tests/length_delimited.rs
new file mode 100644
index 0000000..126e41b
--- /dev/null
+++ b/tests/length_delimited.rs
@@ -0,0 +1,779 @@
+#![warn(rust_2018_idioms)]
+
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_test::task;
+use tokio_test::{
+ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
+};
+use tokio_util::codec::*;
+
+use bytes::{BufMut, Bytes, BytesMut};
+use futures::{pin_mut, Sink, Stream};
+use std::collections::VecDeque;
+use std::io;
+use std::pin::Pin;
+use std::task::Poll::*;
+use std::task::{Context, Poll};
+
+macro_rules! mock {
+ ($($x:expr,)*) => {{
+ let mut v = VecDeque::new();
+ v.extend(vec![$($x),*]);
+ Mock { calls: v }
+ }};
+}
+
+macro_rules! assert_next_eq {
+ ($io:ident, $expect:expr) => {{
+ task::spawn(()).enter(|cx, _| {
+ let res = assert_ready!($io.as_mut().poll_next(cx));
+ match res {
+ Some(Ok(v)) => assert_eq!(v, $expect.as_ref()),
+ Some(Err(e)) => panic!("error = {:?}", e),
+ None => panic!("none"),
+ }
+ });
+ }};
+}
+
+macro_rules! assert_next_pending {
+ ($io:ident) => {{
+ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) {
+ Ready(Some(Ok(v))) => panic!("value = {:?}", v),
+ Ready(Some(Err(e))) => panic!("error = {:?}", e),
+ Ready(None) => panic!("done"),
+ Pending => {}
+ });
+ }};
+}
+
+macro_rules! assert_next_err {
+ ($io:ident) => {{
+ task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) {
+ Ready(Some(Ok(v))) => panic!("value = {:?}", v),
+ Ready(Some(Err(_))) => {}
+ Ready(None) => panic!("done"),
+ Pending => panic!("pending"),
+ });
+ }};
+}
+
+macro_rules! assert_done {
+ ($io:ident) => {{
+ task::spawn(()).enter(|cx, _| {
+ let res = assert_ready!($io.as_mut().poll_next(cx));
+ match res {
+ Some(Ok(v)) => panic!("value = {:?}", v),
+ Some(Err(e)) => panic!("error = {:?}", e),
+ None => {}
+ }
+ });
+ }};
+}
+
+#[test]
+fn read_empty_io_yields_nothing() {
+ let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new()));
+ pin_mut!(io);
+
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_one_packet() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00\x00\x09abcdefghi"),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_one_packet_little_endian() {
+ let io = length_delimited::Builder::new()
+ .little_endian()
+ .new_read(mock! {
+ data(b"\x09\x00\x00\x00abcdefghi"),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_one_packet_native_endian() {
+ let d = if cfg!(target_endian = "big") {
+ b"\x00\x00\x00\x09abcdefghi"
+ } else {
+ b"\x09\x00\x00\x00abcdefghi"
+ };
+ let io = length_delimited::Builder::new()
+ .native_endian()
+ .new_read(mock! {
+ data(d),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_multi_frame_one_packet() {
+ let mut d: Vec<u8> = vec![];
+ d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi");
+ d.extend_from_slice(b"\x00\x00\x00\x03123");
+ d.extend_from_slice(b"\x00\x00\x00\x0bhello world");
+
+ let io = FramedRead::new(
+ mock! {
+ data(&d),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_next_eq!(io, b"123");
+ assert_next_eq!(io, b"hello world");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_multi_packet() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00"),
+ data(b"\x00\x09abc"),
+ data(b"defghi"),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_done!(io);
+}
+
+#[test]
+fn read_multi_frame_multi_packet() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00"),
+ data(b"\x00\x09abc"),
+ data(b"defghi"),
+ data(b"\x00\x00\x00\x0312"),
+ data(b"3\x00\x00\x00\x0bhello world"),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_next_eq!(io, b"123");
+ assert_next_eq!(io, b"hello world");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_multi_packet_wait() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00"),
+ Pending,
+ data(b"\x00\x09abc"),
+ Pending,
+ data(b"defghi"),
+ Pending,
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_pending!(io);
+ assert_next_pending!(io);
+ assert_next_eq!(io, b"abcdefghi");
+ assert_next_pending!(io);
+ assert_done!(io);
+}
+
+#[test]
+fn read_multi_frame_multi_packet_wait() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00"),
+ Pending,
+ data(b"\x00\x09abc"),
+ Pending,
+ data(b"defghi"),
+ Pending,
+ data(b"\x00\x00\x00\x0312"),
+ Pending,
+ data(b"3\x00\x00\x00\x0bhello world"),
+ Pending,
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_pending!(io);
+ assert_next_pending!(io);
+ assert_next_eq!(io, b"abcdefghi");
+ assert_next_pending!(io);
+ assert_next_pending!(io);
+ assert_next_eq!(io, b"123");
+ assert_next_eq!(io, b"hello world");
+ assert_next_pending!(io);
+ assert_done!(io);
+}
+
+#[test]
+fn read_incomplete_head() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00"),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_err!(io);
+}
+
+#[test]
+fn read_incomplete_head_multi() {
+ let io = FramedRead::new(
+ mock! {
+ Pending,
+ data(b"\x00"),
+ Pending,
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_pending!(io);
+ assert_next_pending!(io);
+ assert_next_err!(io);
+}
+
+#[test]
+fn read_incomplete_payload() {
+ let io = FramedRead::new(
+ mock! {
+ data(b"\x00\x00\x00\x09ab"),
+ Pending,
+ data(b"cd"),
+ Pending,
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ assert_next_pending!(io);
+ assert_next_pending!(io);
+ assert_next_err!(io);
+}
+
+#[test]
+fn read_max_frame_len() {
+ let io = length_delimited::Builder::new()
+ .max_frame_length(5)
+ .new_read(mock! {
+ data(b"\x00\x00\x00\x09abcdefghi"),
+ });
+ pin_mut!(io);
+
+ assert_next_err!(io);
+}
+
+#[test]
+fn read_update_max_frame_len_at_rest() {
+ let io = length_delimited::Builder::new().new_read(mock! {
+ data(b"\x00\x00\x00\x09abcdefghi"),
+ data(b"\x00\x00\x00\x09abcdefghi"),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ io.decoder_mut().set_max_frame_length(5);
+ assert_next_err!(io);
+}
+
+#[test]
+fn read_update_max_frame_len_in_flight() {
+ let io = length_delimited::Builder::new().new_read(mock! {
+ data(b"\x00\x00\x00\x09abcd"),
+ Pending,
+ data(b"efghi"),
+ data(b"\x00\x00\x00\x09abcdefghi"),
+ });
+ pin_mut!(io);
+
+ assert_next_pending!(io);
+ io.decoder_mut().set_max_frame_length(5);
+ assert_next_eq!(io, b"abcdefghi");
+ assert_next_err!(io);
+}
+
+#[test]
+fn read_one_byte_length_field() {
+ let io = length_delimited::Builder::new()
+ .length_field_length(1)
+ .new_read(mock! {
+ data(b"\x09abcdefghi"),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_done!(io);
+}
+
+#[test]
+fn read_header_offset() {
+ let io = length_delimited::Builder::new()
+ .length_field_length(2)
+ .length_field_offset(4)
+ .new_read(mock! {
+ data(b"zzzz\x00\x09abcdefghi"),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_multi_frame_one_packet_skip_none_adjusted() {
+ let mut d: Vec<u8> = vec![];
+ d.extend_from_slice(b"xx\x00\x09abcdefghi");
+ d.extend_from_slice(b"yy\x00\x03123");
+ d.extend_from_slice(b"zz\x00\x0bhello world");
+
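+ // With num_skip(0) nothing is stripped, so each yielded frame starts at the
+ // 2-byte prefix; length_adjustment(4) adds the 4 header bytes (prefix plus
+ // length field) that the length field's value of 9 does not cover.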
+ let io = length_delimited::Builder::new()
+ .length_field_length(2)
+ .length_field_offset(2)
+ .num_skip(0)
+ .length_adjustment(4)
+ .new_read(mock! {
+ data(&d),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"xx\x00\x09abcdefghi");
+ assert_next_eq!(io, b"yy\x00\x03123");
+ assert_next_eq!(io, b"zz\x00\x0bhello world");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_length_adjusted() {
+ let mut d: Vec<u8> = vec![];
+ d.extend_from_slice(b"\x00\x00\x0b\x0cHello world");
+
+ let io = length_delimited::Builder::new()
+ .length_field_offset(0)
+ .length_field_length(3)
+ .length_adjustment(0)
+ .num_skip(4)
+ .new_read(mock! {
+ data(&d),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"Hello world");
+ assert_done!(io);
+}
+
+#[test]
+fn read_single_multi_frame_one_packet_length_includes_head() {
+ let mut d: Vec<u8> = vec![];
+ d.extend_from_slice(b"\x00\x0babcdefghi");
+ d.extend_from_slice(b"\x00\x05123");
+ d.extend_from_slice(b"\x00\x0dhello world");
+
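+ // Here the 2-byte length field counts itself: 0x000b = 11 = 2 header bytes
+ // + 9 payload bytes, so length_adjustment(-2) recovers the payload length.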
+ let io = length_delimited::Builder::new()
+ .length_field_length(2)
+ .length_adjustment(-2)
+ .new_read(mock! {
+ data(&d),
+ });
+ pin_mut!(io);
+
+ assert_next_eq!(io, b"abcdefghi");
+ assert_next_eq!(io, b"123");
+ assert_next_eq!(io, b"hello world");
+ assert_done!(io);
+}
+
+#[test]
+fn write_single_frame_length_adjusted() {
+ let io = length_delimited::Builder::new()
+ .length_adjustment(-2)
+ .new_write(mock! {
+ data(b"\x00\x00\x00\x0b"),
+ data(b"abcdefghi"),
+ flush(),
+ });
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_nothing_yields_nothing() {
+ let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new());
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.poll_flush(cx));
+ });
+}
+
+#[test]
+fn write_single_frame_one_packet() {
+ let io = FramedWrite::new(
+ mock! {
+ data(b"\x00\x00\x00\x09"),
+ data(b"abcdefghi"),
+ flush(),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_single_multi_frame_one_packet() {
+ let io = FramedWrite::new(
+ mock! {
+ data(b"\x00\x00\x00\x09"),
+ data(b"abcdefghi"),
+ data(b"\x00\x00\x00\x03"),
+ data(b"123"),
+ data(b"\x00\x00\x00\x0b"),
+ data(b"hello world"),
+ flush(),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("123")));
+
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("hello world")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_single_multi_frame_multi_packet() {
+ let io = FramedWrite::new(
+ mock! {
+ data(b"\x00\x00\x00\x09"),
+ data(b"abcdefghi"),
+ flush(),
+ data(b"\x00\x00\x00\x03"),
+ data(b"123"),
+ flush(),
+ data(b"\x00\x00\x00\x0b"),
+ data(b"hello world"),
+ flush(),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("123")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("hello world")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_single_frame_would_block() {
+ let io = FramedWrite::new(
+ mock! {
+ Pending,
+ data(b"\x00\x00"),
+ Pending,
+ data(b"\x00\x09"),
+ data(b"abcdefghi"),
+ flush(),
+ },
+ LengthDelimitedCodec::new(),
+ );
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+
+ assert_pending!(io.as_mut().poll_flush(cx));
+ assert_pending!(io.as_mut().poll_flush(cx));
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_single_frame_little_endian() {
+ let io = length_delimited::Builder::new()
+ .little_endian()
+ .new_write(mock! {
+ data(b"\x09\x00\x00\x00"),
+ data(b"abcdefghi"),
+ flush(),
+ });
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_single_frame_with_short_length_field() {
+ let io = length_delimited::Builder::new()
+ .length_field_length(1)
+ .new_write(mock! {
+ data(b"\x09"),
+ data(b"abcdefghi"),
+ flush(),
+ });
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_max_frame_len() {
+ let io = length_delimited::Builder::new()
+ .max_frame_length(5)
+ .new_write(mock! {});
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_update_max_frame_len_at_rest() {
+ let io = length_delimited::Builder::new().new_write(mock! {
+ data(b"\x00\x00\x00\x06"),
+ data(b"abcdef"),
+ flush(),
+ });
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+ io.encoder_mut().set_max_frame_length(5);
+
+ assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_update_max_frame_len_in_flight() {
+ let io = length_delimited::Builder::new().new_write(mock! {
+ data(b"\x00\x00\x00\x06"),
+ data(b"ab"),
+ Pending,
+ data(b"cdef"),
+ flush(),
+ });
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+ assert_pending!(io.as_mut().poll_flush(cx));
+
+ io.encoder_mut().set_max_frame_length(5);
+
+ assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+ assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn write_zero() {
+ let io = length_delimited::Builder::new().new_write(mock! {});
+ pin_mut!(io);
+
+ task::spawn(()).enter(|cx, _| {
+ assert_ready_ok!(io.as_mut().poll_ready(cx));
+ assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+ assert_ready_err!(io.as_mut().poll_flush(cx));
+
+ assert!(io.get_ref().calls.is_empty());
+ });
+}
+
+#[test]
+fn encode_overflow() {
+ // Test reproducing tokio-rs/tokio#681.
+ let mut codec = length_delimited::Builder::new().new_codec();
+ let mut buf = BytesMut::with_capacity(1024);
+
+ // Put some data into the buffer without resizing it to hold more.
+ let some_as = std::iter::repeat(b'a').take(1024).collect::<Vec<_>>();
+ buf.put_slice(&some_as[..]);
+
+ // Trying to encode the length header should resize the buffer if it won't fit.
+ codec.encode(Bytes::from("hello"), &mut buf).unwrap();
+}
+
+// ===== Test utils =====
+
+struct Mock {
+ calls: VecDeque<Poll<io::Result<Op>>>,
+}
+
+enum Op {
+ Data(Vec<u8>),
+ Flush,
+}
+
+use self::Op::*;
+
+impl AsyncRead for Mock {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ match self.calls.pop_front() {
+ Some(Ready(Ok(Op::Data(data)))) => {
+ debug_assert!(dst.remaining() >= data.len());
+ dst.put_slice(&data);
+ Ready(Ok(()))
+ }
+ Some(Ready(Ok(_))) => panic!(),
+ Some(Ready(Err(e))) => Ready(Err(e)),
+ Some(Pending) => Pending,
+ None => Ready(Ok(())),
+ }
+ }
+}
+
+impl AsyncWrite for Mock {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ src: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ match self.calls.pop_front() {
+ Some(Ready(Ok(Op::Data(data)))) => {
+ let len = data.len();
+ assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src);
+ assert_eq!(&data[..], &src[..len]);
+ Ready(Ok(len))
+ }
+ Some(Ready(Ok(_))) => panic!(),
+ Some(Ready(Err(e))) => Ready(Err(e)),
+ Some(Pending) => Pending,
+ None => Ready(Ok(0)),
+ }
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ match self.calls.pop_front() {
+ Some(Ready(Ok(Op::Flush))) => Ready(Ok(())),
+ Some(Ready(Ok(_))) => panic!(),
+ Some(Ready(Err(e))) => Ready(Err(e)),
+ Some(Pending) => Pending,
+ None => Ready(Ok(())),
+ }
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Ready(Ok(()))
+ }
+}
+
+impl<'a> From<&'a [u8]> for Op {
+ fn from(src: &'a [u8]) -> Op {
+ Op::Data(src.into())
+ }
+}
+
+impl From<Vec<u8>> for Op {
+ fn from(src: Vec<u8>) -> Op {
+ Op::Data(src)
+ }
+}
+
+fn data(bytes: &[u8]) -> Poll<io::Result<Op>> {
+ Ready(Ok(bytes.into()))
+}
+
+fn flush() -> Poll<io::Result<Op>> {
+ Ready(Ok(Flush))
+}
diff --git a/tests/mpsc.rs b/tests/mpsc.rs
new file mode 100644
index 0000000..a3c164d
--- /dev/null
+++ b/tests/mpsc.rs
@@ -0,0 +1,239 @@
+use futures::future::poll_fn;
+use tokio::sync::mpsc::channel;
+use tokio_test::task::spawn;
+use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
+use tokio_util::sync::PollSender;
+
+#[tokio::test]
+async fn simple() {
+ let (send, mut recv) = channel(3);
+ let mut send = PollSender::new(send);
+
+ for i in 1..=3i32 {
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ send.send_item(i).unwrap();
+ }
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_pending!(reserve.poll());
+
+ assert_eq!(recv.recv().await.unwrap(), 1);
+ assert!(reserve.is_woken());
+ assert_ready_ok!(reserve.poll());
+
+ drop(recv);
+
+ send.send_item(42).unwrap();
+}
+
+#[tokio::test]
+async fn repeated_poll_reserve() {
+ let (send, mut recv) = channel::<i32>(1);
+ let mut send = PollSender::new(send);
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ assert_ready_ok!(reserve.poll());
+ send.send_item(1).unwrap();
+
+ assert_eq!(recv.recv().await.unwrap(), 1);
+}
+
+#[tokio::test]
+async fn abort_send() {
+ let (send, mut recv) = channel(3);
+ let mut send = PollSender::new(send);
+ let send2 = send.get_ref().cloned().unwrap();
+
+ for i in 1..=3i32 {
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ send.send_item(i).unwrap();
+ }
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_pending!(reserve.poll());
+ assert_eq!(recv.recv().await.unwrap(), 1);
+ assert!(reserve.is_woken());
+ assert_ready_ok!(reserve.poll());
+
+ let mut send2_send = spawn(send2.send(5));
+ assert_pending!(send2_send.poll());
+ assert!(send.abort_send());
+ assert!(send2_send.is_woken());
+ assert_ready_ok!(send2_send.poll());
+
+ assert_eq!(recv.recv().await.unwrap(), 2);
+ assert_eq!(recv.recv().await.unwrap(), 3);
+ assert_eq!(recv.recv().await.unwrap(), 5);
+}
+
+#[tokio::test]
+async fn close_sender_last() {
+ let (send, mut recv) = channel::<i32>(3);
+ let mut send = PollSender::new(send);
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ send.close();
+
+ assert!(recv_task.is_woken());
+ assert!(assert_ready!(recv_task.poll()).is_none());
+}
+
+#[tokio::test]
+async fn close_sender_not_last() {
+ let (send, mut recv) = channel::<i32>(3);
+ let mut send = PollSender::new(send);
+ let send2 = send.get_ref().cloned().unwrap();
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ send.close();
+
+ assert!(!recv_task.is_woken());
+ assert_pending!(recv_task.poll());
+
+ drop(send2);
+
+ assert!(recv_task.is_woken());
+ assert!(assert_ready!(recv_task.poll()).is_none());
+}
+
+#[tokio::test]
+async fn close_sender_before_reserve() {
+ let (send, mut recv) = channel::<i32>(3);
+ let mut send = PollSender::new(send);
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ send.close();
+
+ assert!(recv_task.is_woken());
+ assert!(assert_ready!(recv_task.poll()).is_none());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_err!(reserve.poll());
+}
+
+#[tokio::test]
+async fn close_sender_after_pending_reserve() {
+ let (send, mut recv) = channel::<i32>(1);
+ let mut send = PollSender::new(send);
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ send.send_item(1).unwrap();
+
+ assert!(recv_task.is_woken());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_pending!(reserve.poll());
+ drop(reserve);
+
+ send.close();
+
+ assert!(send.is_closed());
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_err!(reserve.poll());
+}
+
+#[tokio::test]
+async fn close_sender_after_successful_reserve() {
+ let (send, mut recv) = channel::<i32>(3);
+ let mut send = PollSender::new(send);
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ drop(reserve);
+
+ send.close();
+ assert!(send.is_closed());
+ assert!(!recv_task.is_woken());
+ assert_pending!(recv_task.poll());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+}
+
+#[tokio::test]
+async fn abort_send_after_pending_reserve() {
+ let (send, mut recv) = channel::<i32>(1);
+ let mut send = PollSender::new(send);
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ send.send_item(1).unwrap();
+
+ assert_eq!(send.get_ref().unwrap().capacity(), 0);
+ assert!(!send.abort_send());
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_pending!(reserve.poll());
+
+ assert!(send.abort_send());
+ assert_eq!(send.get_ref().unwrap().capacity(), 0);
+}
+
+#[tokio::test]
+async fn abort_send_after_successful_reserve() {
+ let (send, mut recv) = channel::<i32>(1);
+ let mut send = PollSender::new(send);
+
+ let mut recv_task = spawn(recv.recv());
+ assert_pending!(recv_task.poll());
+
+ assert_eq!(send.get_ref().unwrap().capacity(), 1);
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ assert_eq!(send.get_ref().unwrap().capacity(), 0);
+
+ assert!(send.abort_send());
+ assert_eq!(send.get_ref().unwrap().capacity(), 1);
+}
+
+#[tokio::test]
+async fn closed_when_receiver_drops() {
+ let (send, _) = channel::<i32>(1);
+ let mut send = PollSender::new(send);
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_err!(reserve.poll());
+}
+
+#[should_panic]
+#[test]
+fn start_send_panics_when_idle() {
+ let (send, _) = channel::<i32>(3);
+ let mut send = PollSender::new(send);
+
+ send.send_item(1).unwrap();
+}
+
+#[should_panic]
+#[test]
+fn start_send_panics_when_acquiring() {
+ let (send, _) = channel::<i32>(1);
+ let mut send = PollSender::new(send);
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_ready_ok!(reserve.poll());
+ send.send_item(1).unwrap();
+
+ let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+ assert_pending!(reserve.poll());
+ send.send_item(2).unwrap();
+}
diff --git a/tests/panic.rs b/tests/panic.rs
new file mode 100644
index 0000000..e4fcb47
--- /dev/null
+++ b/tests/panic.rs
@@ -0,0 +1,226 @@
+#![warn(rust_2018_idioms)]
+#![cfg(all(feature = "full", not(target_os = "wasi")))] // Wasi doesn't support panic recovery
+
+use parking_lot::{const_mutex, Mutex};
+use std::error::Error;
+use std::panic;
+use std::sync::Arc;
+use tokio::runtime::Runtime;
+use tokio::sync::mpsc::channel;
+use tokio::time::{Duration, Instant};
+use tokio_test::task;
+use tokio_util::io::SyncIoBridge;
+use tokio_util::sync::PollSender;
+use tokio_util::task::LocalPoolHandle;
+use tokio_util::time::DelayQueue;
+
+// Taken from tokio-util::time::wheel; if that changes then this must change too.
+const MAX_DURATION_MS: u64 = (1 << (36)) - 1;
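+// 2^36 ms is roughly 795 days, the horizon of the timer wheel; the tests
+// below use MAX_DURATION_MS + 1 to provoke the out-of-range panic.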
+
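+// Runs `func` under a temporarily installed panic hook, serialized by a mutex
+// so concurrent tests don't race on the global hook, and returns the source
+// file reported for the panic (if any) so callers can assert on the panic's
+// caller location.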
+fn test_panic<Func: FnOnce() + panic::UnwindSafe>(func: Func) -> Option<String> {
+ static PANIC_MUTEX: Mutex<()> = const_mutex(());
+
+ {
+ let _guard = PANIC_MUTEX.lock();
+ let panic_file: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
+
+ let prev_hook = panic::take_hook();
+ {
+ let panic_file = panic_file.clone();
+ panic::set_hook(Box::new(move |panic_info| {
+ let panic_location = panic_info.location().unwrap();
+ panic_file
+ .lock()
+ .clone_from(&Some(panic_location.file().to_string()));
+ }));
+ }
+
+ let result = panic::catch_unwind(func);
+ // Return to the previously set panic hook (maybe default) so that we get nice error
+ // messages in the tests.
+ panic::set_hook(prev_hook);
+
+ if result.is_err() {
+ panic_file.lock().clone()
+ } else {
+ None
+ }
+ }
+}
+
+#[test]
+fn sync_bridge_new_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let _ = SyncIoBridge::new(tokio::io::empty());
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn poll_sender_send_item_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let (send, _) = channel::<u32>(3);
+ let mut send = PollSender::new(send);
+
+ let _ = send.send_item(42);
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn local_pool_handle_new_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let _ = LocalPoolHandle::new(0);
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn local_pool_handle_spawn_pinned_by_idx_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+
+ rt.block_on(async {
+ let handle = LocalPoolHandle::new(2);
+ handle.spawn_pinned_by_idx(|| async { "test" }, 3);
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn delay_queue_insert_at_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+ rt.block_on(async {
+ let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+ let _k = queue.insert_at(
+ "1",
+ Instant::now() + Duration::from_millis(MAX_DURATION_MS + 1),
+ );
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn delay_queue_insert_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+ rt.block_on(async {
+ let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+ let _k = queue.insert("1", Duration::from_millis(MAX_DURATION_MS + 1));
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn delay_queue_remove_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+ rt.block_on(async {
+ let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+ let key = queue.insert_at("1", Instant::now());
+ queue.remove(&key);
+ queue.remove(&key);
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn delay_queue_reset_at_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+ rt.block_on(async {
+ let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+ let key = queue.insert_at("1", Instant::now());
+ queue.reset_at(
+ &key,
+ Instant::now() + Duration::from_millis(MAX_DURATION_MS + 1),
+ );
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn delay_queue_reset_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+ rt.block_on(async {
+ let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+ let key = queue.insert_at("1", Instant::now());
+ queue.reset(&key, Duration::from_millis(MAX_DURATION_MS + 1));
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
+#[test]
+fn delay_queue_reserve_panic_caller() -> Result<(), Box<dyn Error>> {
+ let panic_location_file = test_panic(|| {
+ let rt = basic();
+ rt.block_on(async {
+ let mut queue = task::spawn(DelayQueue::<u32>::with_capacity(3));
+
+ queue.reserve((1 << 30) as usize);
+ });
+ });
+
+ // The panic location should be in this file
+ assert_eq!(&panic_location_file.unwrap(), file!());
+
+ Ok(())
+}
+
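+// A minimal current-thread runtime for the tests above that need `block_on`.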
+fn basic() -> Runtime {
+ tokio::runtime::Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .unwrap()
+}
diff --git a/tests/poll_semaphore.rs b/tests/poll_semaphore.rs
new file mode 100644
index 0000000..28beca1
--- /dev/null
+++ b/tests/poll_semaphore.rs
@@ -0,0 +1,84 @@
+use std::future::Future;
+use std::sync::Arc;
+use std::task::Poll;
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio_util::sync::PollSemaphore;
+
+type SemRet = Option<OwnedSemaphorePermit>;
+
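+// Wrap `poll_acquire` in a `tokio_test` task so that individual polls can be
+// driven manually and wakeups can be observed.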
+fn semaphore_poll(
+ sem: &mut PollSemaphore,
+) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
+ let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx));
+ tokio_test::task::spawn(fut)
+}
+
+fn semaphore_poll_many(
+ sem: &mut PollSemaphore,
+ permits: u32,
+) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
+ let fut = futures::future::poll_fn(move |cx| sem.poll_acquire_many(cx, permits));
+ tokio_test::task::spawn(fut)
+}
+
+#[tokio::test]
+async fn it_works() {
+ let sem = Arc::new(Semaphore::new(1));
+ let mut poll_sem = PollSemaphore::new(sem.clone());
+
+ let permit = sem.acquire().await.unwrap();
+ let mut poll = semaphore_poll(&mut poll_sem);
+ assert!(poll.poll().is_pending());
+ drop(permit);
+
+ assert!(matches!(poll.poll(), Poll::Ready(Some(_))));
+ drop(poll);
+
+ sem.close();
+
+ assert!(semaphore_poll(&mut poll_sem).await.is_none());
+
+ // Check that it is fused.
+ assert!(semaphore_poll(&mut poll_sem).await.is_none());
+ assert!(semaphore_poll(&mut poll_sem).await.is_none());
+}
+
+#[tokio::test]
+async fn can_acquire_many_permits() {
+ let sem = Arc::new(Semaphore::new(4));
+ let mut poll_sem = PollSemaphore::new(sem.clone());
+
+ let permit1 = semaphore_poll(&mut poll_sem).poll();
+ assert!(matches!(permit1, Poll::Ready(Some(_))));
+
+ let permit2 = semaphore_poll_many(&mut poll_sem, 2).poll();
+ assert!(matches!(permit2, Poll::Ready(Some(_))));
+
+ assert_eq!(sem.available_permits(), 1);
+
+ drop(permit2);
+
+ let mut permit4 = semaphore_poll_many(&mut poll_sem, 4);
+ assert!(permit4.poll().is_pending());
+
+ drop(permit1);
+
+ let permit4 = permit4.poll();
+ assert!(matches!(permit4, Poll::Ready(Some(_))));
+ assert_eq!(sem.available_permits(), 0);
+}
+
+#[tokio::test]
+async fn can_poll_different_amounts_of_permits() {
+ let sem = Arc::new(Semaphore::new(4));
+ let mut poll_sem = PollSemaphore::new(sem.clone());
+ assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
+ assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
+
+ let permit = sem.acquire_many(4).await.unwrap();
+ assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
+ assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_pending());
+ drop(permit);
+ assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
+ assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
+}
diff --git a/tests/reusable_box.rs b/tests/reusable_box.rs
new file mode 100644
index 0000000..e9c8a25
--- /dev/null
+++ b/tests/reusable_box.rs
@@ -0,0 +1,85 @@
+use futures::future::FutureExt;
+use std::alloc::Layout;
+use std::future::Future;
+use std::marker::PhantomPinned;
+use std::pin::Pin;
+use std::rc::Rc;
+use std::task::{Context, Poll};
+use tokio_util::sync::ReusableBoxFuture;
+
+#[test]
+// Clippy false positive; it's useful to be able to test the trait impls for any lifetime
+#[allow(clippy::extra_unused_lifetimes)]
+fn traits<'a>() {
+ fn assert_traits<T: Send + Sync + Unpin>() {}
+ // Use a type that is !Unpin
+ assert_traits::<ReusableBoxFuture<'a, PhantomPinned>>();
+ // Use a type that is !Send + !Sync
+ assert_traits::<ReusableBoxFuture<'a, Rc<()>>>();
+}
+
+#[test]
+fn test_different_futures() {
+ let fut = async move { 10 };
+ // Not zero sized!
+ assert_eq!(Layout::for_value(&fut).size(), 1);
+
+ let mut b = ReusableBoxFuture::new(fut);
+
+ assert_eq!(b.get_pin().now_or_never(), Some(10));
+
+ b.try_set(async move { 20 })
+ .unwrap_or_else(|_| panic!("incorrect size"));
+
+ assert_eq!(b.get_pin().now_or_never(), Some(20));
+
+ b.try_set(async move { 30 })
+ .unwrap_or_else(|_| panic!("incorrect size"));
+
+ assert_eq!(b.get_pin().now_or_never(), Some(30));
+}
+
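+// `set` always succeeds: if the new future's layout does not match the boxed
+// allocation, the allocation is replaced instead of reused.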
+#[test]
+fn test_different_sizes() {
+ let fut1 = async move { 10 };
+ let val = [0u32; 1000];
+ let fut2 = async move { val[0] };
+ let fut3 = ZeroSizedFuture {};
+
+ assert_eq!(Layout::for_value(&fut1).size(), 1);
+ assert_eq!(Layout::for_value(&fut2).size(), 4004);
+ assert_eq!(Layout::for_value(&fut3).size(), 0);
+
+ let mut b = ReusableBoxFuture::new(fut1);
+ assert_eq!(b.get_pin().now_or_never(), Some(10));
+ b.set(fut2);
+ assert_eq!(b.get_pin().now_or_never(), Some(0));
+ b.set(fut3);
+ assert_eq!(b.get_pin().now_or_never(), Some(5));
+}
+
+struct ZeroSizedFuture {}
+impl Future for ZeroSizedFuture {
+ type Output = u32;
+ fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> {
+ Poll::Ready(5)
+ }
+}
+
+#[test]
+fn test_zero_sized() {
+ let fut = ZeroSizedFuture {};
+ // Zero sized!
+ assert_eq!(Layout::for_value(&fut).size(), 0);
+
+ let mut b = ReusableBoxFuture::new(fut);
+
+ assert_eq!(b.get_pin().now_or_never(), Some(5));
+ assert_eq!(b.get_pin().now_or_never(), Some(5));
+
+ b.try_set(ZeroSizedFuture {})
+ .unwrap_or_else(|_| panic!("incorrect size"));
+
+ assert_eq!(b.get_pin().now_or_never(), Some(5));
+ assert_eq!(b.get_pin().now_or_never(), Some(5));
+}
diff --git a/tests/spawn_pinned.rs b/tests/spawn_pinned.rs
new file mode 100644
index 0000000..9ea8cd2
--- /dev/null
+++ b/tests/spawn_pinned.rs
@@ -0,0 +1,237 @@
+#![warn(rust_2018_idioms)]
+#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
+
+use std::rc::Rc;
+use std::sync::Arc;
+use tokio::sync::Barrier;
+use tokio_util::task;
+
+/// Simple test of running a !Send future via spawn_pinned
+#[tokio::test]
+async fn can_spawn_not_send_future() {
+ let pool = task::LocalPoolHandle::new(1);
+
+ let output = pool
+ .spawn_pinned(|| {
+ // Rc is !Send + !Sync
+ let local_data = Rc::new("test");
+
+ // This future holds an Rc, so it is !Send
+ async move { local_data.to_string() }
+ })
+ .await
+ .unwrap();
+
+ assert_eq!(output, "test");
+}
+
+/// Dropping the join handle still lets the task execute
+#[test]
+fn can_drop_future_and_still_get_output() {
+ let pool = task::LocalPoolHandle::new(1);
+ let (sender, receiver) = std::sync::mpsc::channel();
+
+ let _ = pool.spawn_pinned(move || {
+ // Rc is !Send + !Sync
+ let local_data = Rc::new("test");
+
+ // This future holds an Rc, so it is !Send
+ async move {
+ let _ = sender.send(local_data.to_string());
+ }
+ });
+
+ assert_eq!(receiver.recv(), Ok("test".to_string()));
+}
+
+#[test]
+#[should_panic(expected = "assertion failed: pool_size > 0")]
+fn cannot_create_zero_sized_pool() {
+ let _pool = task::LocalPoolHandle::new(0);
+}
+
+/// We should be able to spawn multiple futures onto the pool at the same time.
+#[tokio::test]
+async fn can_spawn_multiple_futures() {
+ let pool = task::LocalPoolHandle::new(2);
+
+ let join_handle1 = pool.spawn_pinned(|| {
+ let local_data = Rc::new("test1");
+ async move { local_data.to_string() }
+ });
+ let join_handle2 = pool.spawn_pinned(|| {
+ let local_data = Rc::new("test2");
+ async move { local_data.to_string() }
+ });
+
+ assert_eq!(join_handle1.await.unwrap(), "test1");
+ assert_eq!(join_handle2.await.unwrap(), "test2");
+}
+
+/// A panic in the spawned task causes the join handle to return an error.
+/// But, you can continue to spawn tasks.
+#[tokio::test]
+async fn task_panic_propagates() {
+ let pool = task::LocalPoolHandle::new(1);
+
+ let join_handle = pool.spawn_pinned(|| async {
+ panic!("Test panic");
+ });
+
+ let result = join_handle.await;
+ assert!(result.is_err());
+ let error = result.unwrap_err();
+ assert!(error.is_panic());
+ let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
+ assert_eq!(*panic_str, "Test panic");
+
+ // Trying again with a "safe" task still works
+ let join_handle = pool.spawn_pinned(|| async { "test" });
+ let result = join_handle.await;
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "test");
+}
+
+/// A panic during task creation causes the join handle to return an error.
+/// But, you can continue to spawn tasks.
+#[tokio::test]
+async fn callback_panic_does_not_kill_worker() {
+ let pool = task::LocalPoolHandle::new(1);
+
+ let join_handle = pool.spawn_pinned(|| {
+ panic!("Test panic");
+ #[allow(unreachable_code)]
+ async {}
+ });
+
+ let result = join_handle.await;
+ assert!(result.is_err());
+ let error = result.unwrap_err();
+ assert!(error.is_panic());
+ let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
+ assert_eq!(*panic_str, "Test panic");
+
+ // Trying again with a "safe" callback works
+ let join_handle = pool.spawn_pinned(|| async { "test" });
+ let result = join_handle.await;
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "test");
+}
+
+/// Canceling the task via the returned join handle cancels the spawned task
+/// (which has a different, internal join handle).
+#[tokio::test]
+async fn task_cancellation_propagates() {
+ let pool = task::LocalPoolHandle::new(1);
+ let notify_dropped = Arc::new(());
+ let weak_notify_dropped = Arc::downgrade(&notify_dropped);
+
+ let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
+ let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
+ let join_handle = pool.spawn_pinned(|| async move {
+ let _drop_sender = drop_sender;
+ // Move the Arc into the task
+ let _notify_dropped = notify_dropped;
+ let _ = start_sender.send(());
+
+ // Keep the task running until it gets aborted
+ futures::future::pending::<()>().await;
+ });
+
+ // Wait for the task to start
+ let _ = start_receiver.await;
+
+ join_handle.abort();
+
+ // Wait for the inner task to abort, dropping the sender.
+ // The top level join handle aborts quicker than the inner task (the abort
+ // needs to propagate and get processed on the worker thread), so we can't
+ // just await the top level join handle.
+ let _ = drop_receiver.await;
+
+ // Check that the Arc has been dropped. This verifies that the inner task
+ // was canceled as well.
+ assert!(weak_notify_dropped.upgrade().is_none());
+}
+
+/// Tasks should be given to the least burdened worker. When spawning two tasks
+/// on a pool with two empty workers the tasks should be spawned on separate
+/// workers.
+#[tokio::test]
+async fn tasks_are_balanced() {
+ let pool = task::LocalPoolHandle::new(2);
+
+ // Spawn a task so one thread has a task count of 1
+ let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel();
+ let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel();
+ let join_handle1 = pool.spawn_pinned(|| async move {
+ let _ = start_sender1.send(());
+ let _ = end_receiver1.await;
+ std::thread::current().id()
+ });
+
+ // Wait for the first task to start up
+ let _ = start_receiver1.await;
+
+ // This task should be spawned on the other thread
+ let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel();
+ let join_handle2 = pool.spawn_pinned(|| async move {
+ let _ = start_sender2.send(());
+ std::thread::current().id()
+ });
+
+ // Wait for the second task to start up
+ let _ = start_receiver2.await;
+
+ // Allow the first task to end
+ let _ = end_sender1.send(());
+
+ let thread_id1 = join_handle1.await.unwrap();
+ let thread_id2 = join_handle2.await.unwrap();
+
+ // Since the first task was active when the second task spawned, they should
+ // be on separate workers/threads.
+ assert_ne!(thread_id1, thread_id2);
+}
+
+#[tokio::test]
+async fn spawn_by_idx() {
+ let pool = task::LocalPoolHandle::new(3);
+ let barrier = Arc::new(Barrier::new(4));
+ let barrier1 = barrier.clone();
+ let barrier2 = barrier.clone();
+ let barrier3 = barrier.clone();
+
+ let handle1 = pool.spawn_pinned_by_idx(
+ || async move {
+ barrier1.wait().await;
+ std::thread::current().id()
+ },
+ 0,
+ );
+ let _ = pool.spawn_pinned_by_idx(
+ || async move {
+ barrier2.wait().await;
+ std::thread::current().id()
+ },
+ 0,
+ );
+ let handle2 = pool.spawn_pinned_by_idx(
+ || async move {
+ barrier3.wait().await;
+ std::thread::current().id()
+ },
+ 1,
+ );
+
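+ // The tasks are parked on the barrier, so every spawned task is still
+ // counted when the per-worker loads are sampled below.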
+ let loads = pool.get_task_loads_for_each_worker();
+ barrier.wait().await;
+ assert_eq!(loads[0], 2);
+ assert_eq!(loads[1], 1);
+ assert_eq!(loads[2], 0);
+
+ let thread_id1 = handle1.await.unwrap();
+ let thread_id2 = handle2.await.unwrap();
+
+ assert_ne!(thread_id1, thread_id2);
+}
diff --git a/tests/sync_cancellation_token.rs b/tests/sync_cancellation_token.rs
new file mode 100644
index 0000000..279d74f
--- /dev/null
+++ b/tests/sync_cancellation_token.rs
@@ -0,0 +1,447 @@
+#![warn(rust_2018_idioms)]
+
+use tokio::pin;
+use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
+
+use core::future::Future;
+use core::task::{Context, Poll};
+use futures_test::task::new_count_waker;
+
+#[test]
+fn cancel_token() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+ assert!(!token.is_cancelled());
+
+ let wait_fut = token.cancelled();
+ pin!(wait_fut);
+
+ assert_eq!(
+ Poll::Pending,
+ wait_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ let wait_fut_2 = token.cancelled();
+ pin!(wait_fut_2);
+
+ token.cancel();
+ assert_eq!(wake_counter, 1);
+ assert!(token.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ wait_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+}
+
+#[test]
+fn cancel_token_owned() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+ assert!(!token.is_cancelled());
+
+ let wait_fut = token.clone().cancelled_owned();
+ pin!(wait_fut);
+
+ assert_eq!(
+ Poll::Pending,
+ wait_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ let wait_fut_2 = token.clone().cancelled_owned();
+ pin!(wait_fut_2);
+
+ token.cancel();
+ assert_eq!(wake_counter, 1);
+ assert!(token.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ wait_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+}
+
+#[test]
+fn cancel_token_owned_drop_test() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+
+ let future = token.cancelled_owned();
+ pin!(future);
+
+ assert_eq!(
+ Poll::Pending,
+ future.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ // Let the future be dropped while it is pinned and still pending, to
+ // surface potential memory-safety bugs.
+}
+
+#[test]
+fn cancel_child_token_through_parent() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+
+ let child_token = token.child_token();
+ assert!(!child_token.is_cancelled());
+
+ let child_fut = child_token.cancelled();
+ pin!(child_fut);
+ let parent_fut = token.cancelled();
+ pin!(parent_fut);
+
+ assert_eq!(
+ Poll::Pending,
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ token.cancel();
+ assert_eq!(wake_counter, 2);
+ assert!(token.is_cancelled());
+ assert!(child_token.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+}
+
+#[test]
+fn cancel_grandchild_token_through_parent_if_child_was_dropped() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+
+ let intermediate_token = token.child_token();
+ let child_token = intermediate_token.child_token();
+ drop(intermediate_token);
+ assert!(!child_token.is_cancelled());
+
+ let child_fut = child_token.cancelled();
+ pin!(child_fut);
+ let parent_fut = token.cancelled();
+ pin!(parent_fut);
+
+ assert_eq!(
+ Poll::Pending,
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ token.cancel();
+ assert_eq!(wake_counter, 2);
+ assert!(token.is_cancelled());
+ assert!(child_token.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+}
+
+#[test]
+fn cancel_child_token_without_parent() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+
+ let child_token_1 = token.child_token();
+
+ let child_fut = child_token_1.cancelled();
+ pin!(child_fut);
+ let parent_fut = token.cancelled();
+ pin!(parent_fut);
+
+ assert_eq!(
+ Poll::Pending,
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ child_token_1.cancel();
+ assert_eq!(wake_counter, 1);
+ assert!(!token.is_cancelled());
+ assert!(child_token_1.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+
+ let child_token_2 = token.child_token();
+ let child_fut_2 = child_token_2.cancelled();
+ pin!(child_fut_2);
+
+ assert_eq!(
+ Poll::Pending,
+ child_fut_2.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+
+ token.cancel();
+ assert_eq!(wake_counter, 3);
+ assert!(token.is_cancelled());
+ assert!(child_token_2.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ child_fut_2.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+}
+
+#[test]
+fn create_child_token_after_parent_was_cancelled() {
+ for drop_child_first in [true, false].iter().cloned() {
+ let (waker, wake_counter) = new_count_waker();
+ let token = CancellationToken::new();
+ token.cancel();
+
+ let child_token = token.child_token();
+ assert!(child_token.is_cancelled());
+
+ {
+ let child_fut = child_token.cancelled();
+ pin!(child_fut);
+ let parent_fut = token.cancelled();
+ pin!(parent_fut);
+
+ assert_eq!(
+ Poll::Ready(()),
+ child_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+ }
+
+ if drop_child_first {
+ drop(child_token);
+ drop(token);
+ } else {
+ drop(token);
+ drop(child_token);
+ }
+ }
+}
+
+#[test]
+fn drop_multiple_child_tokens() {
+ for drop_first_child_first in &[true, false] {
+ let token = CancellationToken::new();
+ let mut child_tokens = [None, None, None];
+ for child in &mut child_tokens {
+ *child = Some(token.child_token());
+ }
+
+ assert!(!token.is_cancelled());
+ assert!(!child_tokens[0].as_ref().unwrap().is_cancelled());
+
+ for i in 0..child_tokens.len() {
+ if *drop_first_child_first {
+ child_tokens[i] = None;
+ } else {
+ child_tokens[child_tokens.len() - 1 - i] = None;
+ }
+ assert!(!token.is_cancelled());
+ }
+
+ drop(token);
+ }
+}
+
+#[test]
+fn cancel_only_all_descendants() {
+ // ARRANGE
+ let (waker, wake_counter) = new_count_waker();
+
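+ // Build a token tree: parent -> token -> {child1, child2}, with
+ // child1 -> {grandchild, grandchild2} and grandchild -> great_grandchild;
+ // `sibling` hangs off `parent` and must stay uncancelled.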
+ let parent_token = CancellationToken::new();
+ let token = parent_token.child_token();
+ let sibling_token = parent_token.child_token();
+ let child1_token = token.child_token();
+ let child2_token = token.child_token();
+ let grandchild_token = child1_token.child_token();
+ let grandchild2_token = child1_token.child_token();
+ let great_grandchild_token = grandchild_token.child_token();
+
+ assert!(!parent_token.is_cancelled());
+ assert!(!token.is_cancelled());
+ assert!(!sibling_token.is_cancelled());
+ assert!(!child1_token.is_cancelled());
+ assert!(!child2_token.is_cancelled());
+ assert!(!grandchild_token.is_cancelled());
+ assert!(!grandchild2_token.is_cancelled());
+ assert!(!great_grandchild_token.is_cancelled());
+
+ let parent_fut = parent_token.cancelled();
+ let fut = token.cancelled();
+ let sibling_fut = sibling_token.cancelled();
+ let child1_fut = child1_token.cancelled();
+ let child2_fut = child2_token.cancelled();
+ let grandchild_fut = grandchild_token.cancelled();
+ let grandchild2_fut = grandchild2_token.cancelled();
+ let great_grandchild_fut = great_grandchild_token.cancelled();
+
+ pin!(parent_fut);
+ pin!(fut);
+ pin!(sibling_fut);
+ pin!(child1_fut);
+ pin!(child2_fut);
+ pin!(grandchild_fut);
+ pin!(grandchild2_fut);
+ pin!(great_grandchild_fut);
+
+ assert_eq!(
+ Poll::Pending,
+ parent_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ sibling_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ child1_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ child2_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ grandchild_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ grandchild2_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Pending,
+ great_grandchild_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 0);
+
+ // ACT
+ token.cancel();
+
+ // ASSERT
+ assert_eq!(wake_counter, 6);
+ assert!(!parent_token.is_cancelled());
+ assert!(token.is_cancelled());
+ assert!(!sibling_token.is_cancelled());
+ assert!(child1_token.is_cancelled());
+ assert!(child2_token.is_cancelled());
+ assert!(grandchild_token.is_cancelled());
+ assert!(grandchild2_token.is_cancelled());
+ assert!(great_grandchild_token.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ child1_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ child2_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ grandchild_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ grandchild2_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ great_grandchild_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 6);
+}
+
+#[test]
+fn drop_parent_before_child_tokens() {
+ let token = CancellationToken::new();
+ let child1 = token.child_token();
+ let child2 = token.child_token();
+
+ drop(token);
+ assert!(!child1.is_cancelled());
+
+ drop(child1);
+ drop(child2);
+}
+
+#[test]
+fn derives_send_sync() {
+ fn assert_send<T: Send>() {}
+ fn assert_sync<T: Sync>() {}
+
+ assert_send::<CancellationToken>();
+ assert_sync::<CancellationToken>();
+
+ assert_send::<WaitForCancellationFuture<'static>>();
+ assert_sync::<WaitForCancellationFuture<'static>>();
+}
diff --git a/tests/task_join_map.rs b/tests/task_join_map.rs
new file mode 100644
index 0000000..cef08b2
--- /dev/null
+++ b/tests/task_join_map.rs
@@ -0,0 +1,275 @@
+#![warn(rust_2018_idioms)]
+#![cfg(all(feature = "rt", tokio_unstable))]
+
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio_util::task::JoinMap;
+
+use futures::future::FutureExt;
+
+fn rt() -> tokio::runtime::Runtime {
+ tokio::runtime::Builder::new_current_thread()
+ .build()
+ .unwrap()
+}
+
+#[tokio::test(start_paused = true)]
+async fn test_with_sleep() {
+ let mut map = JoinMap::new();
+
+ for i in 0..10 {
+ map.spawn(i, async move { i });
+ assert_eq!(map.len(), 1 + i);
+ }
+ map.detach_all();
+ assert_eq!(map.len(), 0);
+
+ assert!(matches!(map.join_next().await, None));
+
+ for i in 0..10 {
+ map.spawn(i, async move {
+ tokio::time::sleep(Duration::from_secs(i as u64)).await;
+ i
+ });
+ assert_eq!(map.len(), 1 + i);
+ }
+
+ let mut seen = [false; 10];
+ while let Some((k, res)) = map.join_next().await {
+ seen[k] = true;
+ assert_eq!(res.expect("task should have completed successfully"), k);
+ }
+
+ for was_seen in &seen {
+ assert!(was_seen);
+ }
+ assert!(matches!(map.join_next().await, None));
+
+ // Do it again.
+ for i in 0..10 {
+ map.spawn(i, async move {
+ tokio::time::sleep(Duration::from_secs(i as u64)).await;
+ i
+ });
+ }
+
+ let mut seen = [false; 10];
+ while let Some((k, res)) = map.join_next().await {
+ seen[k] = true;
+ assert_eq!(res.expect("task should have completed successfully"), k);
+ }
+
+ for was_seen in &seen {
+ assert!(was_seen);
+ }
+ assert!(matches!(map.join_next().await, None));
+}
+
+#[tokio::test]
+async fn test_abort_on_drop() {
+ let mut map = JoinMap::new();
+
+ let mut recvs = Vec::new();
+
+ for i in 0..16 {
+ let (send, recv) = oneshot::channel::<()>();
+ recvs.push(recv);
+
+ map.spawn(i, async {
+ // This task will never complete on its own.
+ futures::future::pending::<()>().await;
+ drop(send);
+ });
+ }
+
+ drop(map);
+
+ for recv in recvs {
+ // The task is aborted soon and we will receive an error.
+ assert!(recv.await.is_err());
+ }
+}
+
+#[tokio::test]
+async fn alternating() {
+ let mut map = JoinMap::new();
+
+ assert_eq!(map.len(), 0);
+ map.spawn(1, async {});
+ assert_eq!(map.len(), 1);
+ map.spawn(2, async {});
+ assert_eq!(map.len(), 2);
+
+ for i in 0..16 {
+ let (_, res) = map.join_next().await.unwrap();
+ assert!(res.is_ok());
+ assert_eq!(map.len(), 1);
+ map.spawn(i, async {});
+ assert_eq!(map.len(), 2);
+ }
+}
+
+#[tokio::test(start_paused = true)]
+async fn abort_by_key() {
+ let mut map = JoinMap::new();
+ let mut num_canceled = 0;
+ let mut num_completed = 0;
+ for i in 0..16 {
+ map.spawn(i, async move {
+ tokio::time::sleep(Duration::from_secs(i as u64)).await;
+ });
+ }
+
+ for i in 0..16 {
+ if i % 2 != 0 {
+ // abort odd-numbered tasks.
+ map.abort(&i);
+ }
+ }
+
+ while let Some((key, res)) = map.join_next().await {
+ match res {
+ Ok(()) => {
+ num_completed += 1;
+ assert_eq!(key % 2, 0);
+ assert!(!map.contains_key(&key));
+ }
+ Err(e) => {
+ num_canceled += 1;
+ assert!(e.is_cancelled());
+ assert_ne!(key % 2, 0);
+ assert!(!map.contains_key(&key));
+ }
+ }
+ }
+
+ assert_eq!(num_canceled, 8);
+ assert_eq!(num_completed, 8);
+}
+
+#[tokio::test(start_paused = true)]
+async fn abort_by_predicate() {
+ let mut map = JoinMap::new();
+ let mut num_canceled = 0;
+ let mut num_completed = 0;
+ for i in 0..16 {
+ map.spawn(i, async move {
+ tokio::time::sleep(Duration::from_secs(i as u64)).await;
+ });
+ }
+
+ // abort odd-numbered tasks.
+ map.abort_matching(|key| key % 2 != 0);
+
+ while let Some((key, res)) = map.join_next().await {
+ match res {
+ Ok(()) => {
+ num_completed += 1;
+ assert_eq!(key % 2, 0);
+ assert!(!map.contains_key(&key));
+ }
+ Err(e) => {
+ num_canceled += 1;
+ assert!(e.is_cancelled());
+ assert_ne!(key % 2, 0);
+ assert!(!map.contains_key(&key));
+ }
+ }
+ }
+
+ assert_eq!(num_canceled, 8);
+ assert_eq!(num_completed, 8);
+}
+
+#[test]
+fn runtime_gone() {
+ let mut map = JoinMap::new();
+ {
+ let rt = rt();
+ map.spawn_on("key", async { 1 }, rt.handle());
+ drop(rt);
+ }
+
+ let (key, res) = rt().block_on(map.join_next()).unwrap();
+ assert_eq!(key, "key");
+ assert!(res.unwrap_err().is_cancelled());
+}
+
+// This ensures that `join_next` works correctly when the coop budget is
+// exhausted.
+#[tokio::test(flavor = "current_thread")]
+async fn join_map_coop() {
+ // Large enough to trigger coop.
+ const TASK_NUM: u32 = 1000;
+
+ static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0);
+
+ let mut map = JoinMap::new();
+
+ for i in 0..TASK_NUM {
+ map.spawn(i, async move {
+ SEM.add_permits(1);
+ i
+ });
+ }
+
+ // Wait for all tasks to complete.
+ //
+ // Since this is a `current_thread` runtime, there's no race condition
+ // between the last permit being added and the task completing.
+ let _ = SEM.acquire_many(TASK_NUM).await.unwrap();
+
+ let mut count = 0;
+ let mut coop_count = 0;
+ loop {
+ match map.join_next().now_or_never() {
+ Some(Some((key, Ok(i)))) => assert_eq!(key, i),
+ Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err),
+ None => {
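+ // `join_next` was not ready: the coop budget is exhausted even
+ // though every task has finished. Yield to refill it and retry.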
+ coop_count += 1;
+ tokio::task::yield_now().await;
+ continue;
+ }
+ Some(None) => break,
+ }
+
+ count += 1;
+ }
+ assert!(coop_count >= 1);
+ assert_eq!(count, TASK_NUM);
+}
+
+#[tokio::test(start_paused = true)]
+async fn abort_all() {
+ let mut map: JoinMap<usize, ()> = JoinMap::new();
+
+ for i in 0..5 {
+ map.spawn(i, futures::future::pending());
+ }
+ for i in 5..10 {
+ map.spawn(i, async {
+ tokio::time::sleep(Duration::from_secs(1)).await;
+ });
+ }
+
+ // After this sleep, the join map will have 5 pending tasks and 5 ready tasks.
+ tokio::time::sleep(Duration::from_secs(2)).await;
+
+ map.abort_all();
+ assert_eq!(map.len(), 10);
+
+ let mut count = 0;
+ let mut seen = [false; 10];
+ while let Some((k, res)) = map.join_next().await {
+ seen[k] = true;
+ if let Err(err) = res {
+ assert!(err.is_cancelled());
+ }
+ count += 1;
+ }
+ assert_eq!(count, 10);
+ assert_eq!(map.len(), 0);
+ for was_seen in &seen {
+ assert!(was_seen);
+ }
+}
diff --git a/tests/time_delay_queue.rs b/tests/time_delay_queue.rs
new file mode 100644
index 0000000..9ceae34
--- /dev/null
+++ b/tests/time_delay_queue.rs
@@ -0,0 +1,820 @@
+#![allow(clippy::disallowed_names)]
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use tokio::time::{self, sleep, sleep_until, Duration, Instant};
+use tokio_test::{assert_pending, assert_ready, task};
+use tokio_util::time::DelayQueue;
+
+macro_rules! poll {
+ ($queue:ident) => {
+ $queue.enter(|cx, mut queue| queue.poll_expired(cx))
+ };
+}
+
+macro_rules! assert_ready_some {
+ ($e:expr) => {{
+ match assert_ready!($e) {
+ Some(v) => v,
+ None => panic!("None"),
+ }
+ }};
+}
+
+#[tokio::test]
+async fn single_immediate_delay() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let _key = queue.insert_at("foo", Instant::now());
+
+ // Advance time by 1ms to handle the rounding
+ sleep(ms(1)).await;
+
+ assert_ready_some!(poll!(queue));
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn multi_immediate_delays() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let _k = queue.insert_at("1", Instant::now());
+ let _k = queue.insert_at("2", Instant::now());
+ let _k = queue.insert_at("3", Instant::now());
+
+ sleep(ms(1)).await;
+
+ let mut res = vec![];
+
+ while res.len() < 3 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+
+ res.sort_unstable();
+
+ assert_eq!("1", res[0]);
+ assert_eq!("2", res[1]);
+ assert_eq!("3", res[2]);
+}
+
+#[tokio::test]
+async fn single_short_delay() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let _key = queue.insert_at("foo", Instant::now() + ms(5));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(1)).await;
+
+ assert!(!queue.is_woken());
+
+ sleep(ms(5)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(*entry.get_ref(), "foo");
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn multi_delay_at_start() {
+ time::pause();
+
+ let long = 262_144 + 9 * 4096;
+ let delays = &[1000, 2, 234, long, 60, 10];
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ // Setup the delays
+ for &i in delays {
+ let _key = queue.insert_at(i, Instant::now() + ms(i));
+ }
+
+ assert_pending!(poll!(queue));
+ assert!(!queue.is_woken());
+
+ let start = Instant::now();
+ for elapsed in 0..1200 {
+ println!("elapsed: {:?}", elapsed);
+ let elapsed = elapsed + 1;
+ tokio::time::sleep_until(start + ms(elapsed)).await;
+
+ if delays.contains(&elapsed) {
+ assert!(queue.is_woken());
+ assert_ready!(poll!(queue));
+ assert_pending!(poll!(queue));
+ } else if queue.is_woken() {
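+ // A wakeup at one of these instants comes from the timer wheel
+ // cascading entries down a level, not from an entry expiring.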
+ let cascade = &[192, 960];
+ assert!(
+ cascade.contains(&elapsed),
+ "elapsed={} dt={:?}",
+ elapsed,
+ Instant::now() - start
+ );
+
+ assert_pending!(poll!(queue));
+ }
+ }
+ println!("finished multi_delay_start");
+}
+
+#[tokio::test]
+async fn insert_in_past_fires_immediately() {
+ println!("running insert_in_past_fires_immediately");
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let now = Instant::now();
+
+ sleep(ms(10)).await;
+
+ queue.insert_at("foo", now);
+
+ assert_ready!(poll!(queue));
+ println!("finished insert_in_past_fires_immediately");
+}
+
+#[tokio::test]
+async fn remove_entry() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let key = queue.insert_at("foo", Instant::now() + ms(5));
+
+ assert_pending!(poll!(queue));
+
+ let entry = queue.remove(&key);
+ assert_eq!(entry.into_inner(), "foo");
+
+ sleep(ms(10)).await;
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn reset_entry() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+ let key = queue.insert_at("foo", now + ms(5));
+
+ assert_pending!(poll!(queue));
+ sleep(ms(1)).await;
+
+ queue.reset_at(&key, now + ms(10));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(7)).await;
+
+ assert!(!queue.is_woken());
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(3)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(*entry.get_ref(), "foo");
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+// Reproduces tokio-rs/tokio#849.
+#[tokio::test]
+async fn reset_much_later() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+ sleep(ms(1)).await;
+
+ let key = queue.insert_at("foo", now + ms(200));
+ assert_pending!(poll!(queue));
+
+ sleep(ms(3)).await;
+
+ queue.reset_at(&key, now + ms(10));
+
+ sleep(ms(20)).await;
+
+ assert!(queue.is_woken());
+}
+
+// Reproduces tokio-rs/tokio#849.
+#[tokio::test]
+async fn reset_twice() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let now = Instant::now();
+
+ sleep(ms(1)).await;
+
+ let key = queue.insert_at("foo", now + ms(200));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(3)).await;
+
+ queue.reset_at(&key, now + ms(50));
+
+ sleep(ms(20)).await;
+
+ queue.reset_at(&key, now + ms(40));
+
+ sleep(ms(20)).await;
+
+ assert!(queue.is_woken());
+}
+
+/// Regression test: Given an entry inserted with a deadline in the past, so
+/// that it is placed directly on the expired queue, reset the entry to a
+/// deadline in the future. Validate that this leaves the entry and queue in an
+/// internally consistent state by running an additional reset on the entry
+/// before polling it to completion.
+#[tokio::test]
+async fn repeatedly_reset_entry_inserted_as_expired() {
+ time::pause();
+ let mut queue = task::spawn(DelayQueue::new());
+ let now = Instant::now();
+
+ let key = queue.insert_at("foo", now - ms(100));
+
+ queue.reset_at(&key, now + ms(100));
+ queue.reset_at(&key, now + ms(50));
+
+ assert_pending!(poll!(queue));
+
+ time::sleep_until(now + ms(60)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "foo");
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn remove_expired_item() {
+ time::pause();
+
+ let mut queue = DelayQueue::new();
+
+ let now = Instant::now();
+
+ sleep(ms(10)).await;
+
+ let key = queue.insert_at("foo", now);
+
+ let entry = queue.remove(&key);
+ assert_eq!(entry.into_inner(), "foo");
+}
+
+/// Regression test: it should be possible to remove entries which fall in the
+/// 0th slot of the internal timer wheel — that is, entries whose expiration
+/// (a) falls at the beginning of one of the wheel's hierarchical levels and (b)
+/// is equal to the wheel's current elapsed time.
+#[tokio::test]
+async fn remove_at_timer_wheel_threshold() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let key1 = queue.insert_at("foo", now + ms(64));
+ let key2 = queue.insert_at("bar", now + ms(64));
+
+ sleep(ms(80)).await;
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+
+ match entry {
+ "foo" => {
+ let entry = queue.remove(&key2).into_inner();
+ assert_eq!(entry, "bar");
+ }
+ "bar" => {
+ let entry = queue.remove(&key1).into_inner();
+ assert_eq!(entry, "foo");
+ }
+ other => panic!("other: {:?}", other),
+ }
+}
+
+#[tokio::test]
+async fn expires_before_last_insert() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ queue.insert_at("foo", now + ms(10_000));
+
+ // Delay should be set to 8.192s here.
+ assert_pending!(poll!(queue));
+
+ // Delay should be set to the delay of the new item here
+ queue.insert_at("bar", now + ms(600));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(600)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "bar");
+}
+
+#[tokio::test]
+async fn multi_reset() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let one = queue.insert_at("one", now + ms(200));
+ let two = queue.insert_at("two", now + ms(250));
+
+ assert_pending!(poll!(queue));
+
+ queue.reset_at(&one, now + ms(300));
+ queue.reset_at(&two, now + ms(350));
+ queue.reset_at(&one, now + ms(400));
+
+ sleep(ms(310)).await;
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(50)).await;
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(*entry.get_ref(), "two");
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(50)).await;
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(*entry.get_ref(), "one");
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn expire_first_key_when_reset_to_expire_earlier() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let one = queue.insert_at("one", now + ms(200));
+ queue.insert_at("two", now + ms(250));
+
+ assert_pending!(poll!(queue));
+
+ queue.reset_at(&one, now + ms(100));
+
+ sleep(ms(100)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "one");
+}
+
+#[tokio::test]
+async fn expire_second_key_when_reset_to_expire_earlier() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ queue.insert_at("one", now + ms(200));
+ let two = queue.insert_at("two", now + ms(250));
+
+ assert_pending!(poll!(queue));
+
+ queue.reset_at(&two, now + ms(100));
+
+ sleep(ms(100)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "two");
+}
+
+#[tokio::test]
+async fn reset_first_expiring_item_to_expire_later() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let one = queue.insert_at("one", now + ms(200));
+ let _two = queue.insert_at("two", now + ms(250));
+
+ assert_pending!(poll!(queue));
+
+ queue.reset_at(&one, now + ms(300));
+ sleep(ms(250)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "two");
+}
+
+#[tokio::test]
+async fn insert_before_first_after_poll() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let _one = queue.insert_at("one", now + ms(200));
+
+ assert_pending!(poll!(queue));
+
+ let _two = queue.insert_at("two", now + ms(100));
+
+ sleep(ms(99)).await;
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(1)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "two");
+}
+
+#[tokio::test]
+async fn insert_after_ready_poll() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ queue.insert_at("1", now + ms(100));
+ queue.insert_at("2", now + ms(100));
+ queue.insert_at("3", now + ms(100));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(100)).await;
+
+ assert!(queue.is_woken());
+
+ let mut res = vec![];
+
+ while res.len() < 3 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ queue.insert_at("foo", now + ms(500));
+ }
+
+ res.sort_unstable();
+
+ assert_eq!("1", res[0]);
+ assert_eq!("2", res[1]);
+ assert_eq!("3", res[2]);
+}
+
+#[tokio::test]
+async fn reset_later_after_slot_starts() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let foo = queue.insert_at("foo", now + ms(100));
+
+ assert_pending!(poll!(queue));
+
+ sleep_until(now + Duration::from_millis(80)).await;
+
+ assert!(!queue.is_woken());
+
+ // At this point the queue hasn't been polled, so `elapsed` on the wheel
+ // for the queue is still at 0 and hence the 1ms resolution slots cover
+ // [0-64). Resetting the time on the entry to 120 causes it to get put in
+ // the [64-128) slot. As the queue knows that the first entry is within
+ // that slot, but doesn't know when, it must wake immediately to advance
+ // the wheel.
+ queue.reset_at(&foo, now + ms(120));
+ assert!(queue.is_woken());
+
+ assert_pending!(poll!(queue));
+
+ sleep_until(now + Duration::from_millis(119)).await;
+ assert!(!queue.is_woken());
+
+ sleep(ms(1)).await;
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "foo");
+}
+
+#[tokio::test]
+async fn reset_inserted_expired() {
+ time::pause();
+ let mut queue = task::spawn(DelayQueue::new());
+ let now = Instant::now();
+
+ let key = queue.insert_at("foo", now - ms(100));
+
+ // This previously triggered the panic described in tokio-rs/tokio#2473.
+ queue.reset_at(&key, now + ms(100));
+
+ assert_eq!(1, queue.len());
+
+ sleep(ms(200)).await;
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "foo");
+
+ assert_eq!(queue.len(), 0);
+}
+
+#[tokio::test]
+async fn reset_earlier_after_slot_starts() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let foo = queue.insert_at("foo", now + ms(200));
+
+ assert_pending!(poll!(queue));
+
+ sleep_until(now + Duration::from_millis(80)).await;
+
+ assert!(!queue.is_woken());
+
+ // At this point the queue hasn't been polled, so `elapsed` on the wheel
+ // for the queue is still at 0 and hence the 1ms resolution slots cover
+ // [0-64). Resetting the time on the entry to 120 causes it to get put in
+ // the [64-128) slot. As the queue knows that the first entry is within
+ // that slot, but doesn't know when, it must wake immediately to advance
+ // the wheel.
+ queue.reset_at(&foo, now + ms(120));
+ assert!(queue.is_woken());
+
+ assert_pending!(poll!(queue));
+
+ sleep_until(now + Duration::from_millis(119)).await;
+ assert!(!queue.is_woken());
+
+ sleep(ms(1)).await;
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "foo");
+}
+
+#[tokio::test]
+async fn insert_in_past_after_poll_fires_immediately() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ queue.insert_at("foo", now + ms(200));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(80)).await;
+
+ assert!(!queue.is_woken());
+ queue.insert_at("bar", now + ms(40));
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue)).into_inner();
+ assert_eq!(entry, "bar");
+}
+
+#[tokio::test]
+async fn delay_queue_poll_expired_when_empty() {
+ let mut delay_queue = task::spawn(DelayQueue::new());
+ let key = delay_queue.insert(0, std::time::Duration::from_secs(10));
+ assert_pending!(poll!(delay_queue));
+
+ delay_queue.remove(&key);
+ assert!(assert_ready!(poll!(delay_queue)).is_none());
+}
+
+#[tokio::test(start_paused = true)]
+async fn compact_expire_empty() {
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ queue.insert_at("foo1", now + ms(10));
+ queue.insert_at("foo2", now + ms(10));
+
+ sleep(ms(10)).await;
+
+ let mut res = vec![];
+ while res.len() < 2 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ queue.compact();
+
+ assert_eq!(queue.len(), 0);
+ assert_eq!(queue.capacity(), 0);
+}
+
+#[tokio::test(start_paused = true)]
+async fn compact_remove_empty() {
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ let key1 = queue.insert_at("foo1", now + ms(10));
+ let key2 = queue.insert_at("foo2", now + ms(10));
+
+ queue.remove(&key1);
+ queue.remove(&key2);
+
+ queue.compact();
+
+ assert_eq!(queue.len(), 0);
+ assert_eq!(queue.capacity(), 0);
+}
+
+#[tokio::test(start_paused = true)]
+// Trigger a re-mapping of keys in the slab due to a `compact` call and
+// test removal of re-mapped keys
+async fn compact_remove_remapped_keys() {
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+
+ queue.insert_at("foo1", now + ms(10));
+ queue.insert_at("foo2", now + ms(10));
+
+ // foo3 and foo4 take the next two slab indices; `compact` below re-maps them
+ let key3 = queue.insert_at("foo3", now + ms(20));
+ let key4 = queue.insert_at("foo4", now + ms(20));
+
+ sleep(ms(10)).await;
+
+ let mut res = vec![];
+ while res.len() < 2 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ // items corresponding to `foo3` and `foo4` will be assigned
+ // new indices here
+ queue.compact();
+
+ queue.insert_at("foo5", now + ms(10));
+
+ // test removal of re-mapped keys
+ let expired3 = queue.remove(&key3);
+ let expired4 = queue.remove(&key4);
+
+ assert_eq!(expired3.into_inner(), "foo3");
+ assert_eq!(expired4.into_inner(), "foo4");
+
+ queue.compact();
+ assert_eq!(queue.len(), 1);
+ assert_eq!(queue.capacity(), 1);
+}
+
+#[tokio::test(start_paused = true)]
+async fn compact_change_deadline() {
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let mut now = Instant::now();
+
+ queue.insert_at("foo1", now + ms(10));
+ queue.insert_at("foo2", now + ms(10));
+
+ // foo3 and foo4 take the next two slab indices; `compact` below re-maps them
+ queue.insert_at("foo3", now + ms(20));
+ let key4 = queue.insert_at("foo4", now + ms(20));
+
+ sleep(ms(10)).await;
+
+ let mut res = vec![];
+ while res.len() < 2 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ // items corresponding to `foo3` and `foo4` should be assigned
+ // new indices
+ queue.compact();
+
+ now = Instant::now();
+
+ queue.insert_at("foo5", now + ms(10));
+ let key6 = queue.insert_at("foo6", now + ms(10));
+
+ queue.reset_at(&key4, now + ms(20));
+ queue.reset_at(&key6, now + ms(20));
+
+ // foo3 and foo5 will expire
+ sleep(ms(10)).await;
+
+ while res.len() < 4 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ sleep(ms(10)).await;
+
+ while res.len() < 6 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")]
+#[tokio::test(start_paused = true)]
+async fn remove_after_compact() {
+ let now = Instant::now();
+ let mut queue = DelayQueue::new();
+
+ let foo_key = queue.insert_at("foo", now + ms(10));
+ queue.insert_at("bar", now + ms(20));
+ queue.remove(&foo_key);
+ queue.compact();
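+ // After `compact` re-maps keys, the already-removed key is stale and
+ // must panic rather than silently removing a different entry.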
+
+ let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+ queue.remove(&foo_key);
+ }));
+ assert!(panic.is_err());
+}
+
+#[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")]
+#[tokio::test(start_paused = true)]
+async fn remove_after_compact_poll() {
+ let now = Instant::now();
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let foo_key = queue.insert_at("foo", now + ms(10));
+ queue.insert_at("bar", now + ms(20));
+
+ sleep(ms(10)).await;
+ assert_eq!(assert_ready_some!(poll!(queue)).key(), foo_key);
+
+ queue.compact();
+
+ let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
+ queue.remove(&foo_key);
+ }));
+ assert!(panic.is_err());
+}
+
+fn ms(n: u64) -> Duration {
+ Duration::from_millis(n)
+}
diff --git a/tests/udp.rs b/tests/udp.rs
new file mode 100644
index 0000000..1b99806
--- /dev/null
+++ b/tests/udp.rs
@@ -0,0 +1,133 @@
+#![warn(rust_2018_idioms)]
+#![cfg(not(target_os = "wasi"))] // Wasi doesn't support UDP
+
+use tokio::net::UdpSocket;
+use tokio_stream::StreamExt;
+use tokio_util::codec::{Decoder, Encoder, LinesCodec};
+use tokio_util::udp::UdpFramed;
+
+use bytes::{BufMut, BytesMut};
+use futures::future::try_join;
+use futures::future::FutureExt;
+use futures::sink::SinkExt;
+use std::io;
+use std::sync::Arc;
+
+#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))]
+#[tokio::test]
+async fn send_framed_byte_codec() -> std::io::Result<()> {
+ let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?;
+ let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?;
+
+ let a_addr = a_soc.local_addr()?;
+ let b_addr = b_soc.local_addr()?;
+
+ // test sending & receiving bytes
+ {
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, ByteCodec);
+
+ let msg = b"4567";
+
+ let send = a.send((msg, b_addr));
+ let recv = b.next().map(|e| e.unwrap());
+ let (_, received) = try_join(send, recv).await.unwrap();
+
+ let (data, addr) = received;
+ assert_eq!(msg, &*data);
+ assert_eq!(a_addr, addr);
+
+ a_soc = a.into_inner();
+ b_soc = b.into_inner();
+ }
+
+ #[cfg(not(any(target_os = "macos", target_os = "ios")))]
+ // test sending & receiving an empty message
+ {
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, ByteCodec);
+
+ let msg = b"";
+
+ let send = a.send((msg, b_addr));
+ let recv = b.next().map(|e| e.unwrap());
+ let (_, received) = try_join(send, recv).await.unwrap();
+
+ let (data, addr) = received;
+ assert_eq!(msg, &*data);
+ assert_eq!(a_addr, addr);
+ }
+
+ Ok(())
+}
+
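+/// A trivial codec that passes each datagram through as raw bytes.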
+pub struct ByteCodec;
+
+impl Decoder for ByteCodec {
+ type Item = Vec<u8>;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Vec<u8>>, io::Error> {
+ let len = buf.len();
+ Ok(Some(buf.split_to(len).to_vec()))
+ }
+}
+
+impl Encoder<&[u8]> for ByteCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
+ buf.reserve(data.len());
+ buf.put_slice(data);
+ Ok(())
+ }
+}
+
+#[tokio::test]
+async fn send_framed_lines_codec() -> std::io::Result<()> {
+ let a_soc = UdpSocket::bind("127.0.0.1:0").await?;
+ let b_soc = UdpSocket::bind("127.0.0.1:0").await?;
+
+ let a_addr = a_soc.local_addr()?;
+ let b_addr = b_soc.local_addr()?;
+
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, LinesCodec::new());
+
+ let msg = b"1\r\n2\r\n3\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn framed_half() -> std::io::Result<()> {
+ let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
+ let b_soc = a_soc.clone();
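+ // Both framed halves wrap the same socket through the `Arc`, so the
+ // "send" and "receive" ends share one local address.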
+
+ let a_addr = a_soc.local_addr()?;
+ let b_addr = b_soc.local_addr()?;
+
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, LinesCodec::new());
+
+ let msg = b"1\r\n2\r\n3\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ let msg = b"4\r\n5\r\n6\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr));
+
+ Ok(())
+}