From b3b4e5b0f0aa14ed7c7eade58516bc1e0d860aeb Mon Sep 17 00:00:00 2001 From: Fedor Tsarev Date: Fri, 24 Mar 2023 09:53:40 +0100 Subject: Import tokio-util crate Bug: b/274759160 Change-Id: I89b5abfb25626af03eccdbf9ef8eca76fd4b22f1 --- CHANGELOG.md | 334 ++++++++ Cargo.toml | 131 +++ Cargo.toml.orig | 66 ++ LICENSE | 25 + METADATA | 19 + MODULE_LICENSE_MIT | 0 OWNERS | 1 + README.md | 13 + src/cfg.rs | 71 ++ src/codec/any_delimiter_codec.rs | 263 ++++++ src/codec/bytes_codec.rs | 86 ++ src/codec/decoder.rs | 184 +++++ src/codec/encoder.rs | 25 + src/codec/framed.rs | 383 +++++++++ src/codec/framed_impl.rs | 312 +++++++ src/codec/framed_read.rs | 199 +++++ src/codec/framed_write.rs | 188 +++++ src/codec/length_delimited.rs | 1043 +++++++++++++++++++++++ src/codec/lines_codec.rs | 230 ++++++ src/codec/mod.rs | 290 +++++++ src/compat.rs | 274 +++++++ src/context.rs | 190 +++++ src/either.rs | 188 +++++ src/io/copy_to_bytes.rs | 68 ++ src/io/inspect.rs | 134 +++ src/io/mod.rs | 31 + src/io/read_buf.rs | 65 ++ src/io/reader_stream.rs | 118 +++ src/io/sink_writer.rs | 117 +++ src/io/stream_reader.rs | 326 ++++++++ src/io/sync_bridge.rs | 143 ++++ src/lib.rs | 205 +++++ src/loom.rs | 1 + src/net/mod.rs | 97 +++ src/net/unix/mod.rs | 18 + src/sync/cancellation_token.rs | 324 ++++++++ src/sync/cancellation_token/guard.rs | 27 + src/sync/cancellation_token/tree_node.rs | 373 +++++++++ src/sync/mod.rs | 15 + src/sync/mpsc.rs | 284 +++++++ src/sync/poll_semaphore.rs | 171 ++++ src/sync/reusable_box.rs | 171 ++++ src/sync/tests/loom_cancellation_token.rs | 176 ++++ src/sync/tests/mod.rs | 1 + src/task/join_map.rs | 807 ++++++++++++++++++ src/task/mod.rs | 12 + src/task/spawn_pinned.rs | 436 ++++++++++ src/time/delay_queue.rs | 1271 +++++++++++++++++++++++++++++ src/time/mod.rs | 47 ++ src/time/wheel/level.rs | 253 ++++++ src/time/wheel/mod.rs | 315 +++++++ src/time/wheel/stack.rs | 28 + src/udp/frame.rs | 249 ++++++ src/udp/mod.rs | 4 + tests/_require_full.rs | 2 + tests/codecs.rs | 464 +++++++++++ tests/context.rs | 25 + tests/framed.rs | 152 ++++ tests/framed_read.rs | 339 ++++++++ tests/framed_stream.rs | 38 + tests/framed_write.rs | 212 +++++ tests/io_inspect.rs | 194 +++++ tests/io_reader_stream.rs | 65 ++ tests/io_sink_writer.rs | 72 ++ tests/io_stream_reader.rs | 35 + tests/io_sync_bridge.rs | 62 ++ tests/length_delimited.rs | 779 ++++++++++++++++++ tests/mpsc.rs | 239 ++++++ tests/panic.rs | 226 +++++ tests/poll_semaphore.rs | 84 ++ tests/reusable_box.rs | 85 ++ tests/spawn_pinned.rs | 237 ++++++ tests/sync_cancellation_token.rs | 447 ++++++++++ tests/task_join_map.rs | 275 +++++++ tests/time_delay_queue.rs | 820 +++++++++++++++++++ tests/udp.rs | 133 +++ 76 files changed, 15787 insertions(+) create mode 100644 CHANGELOG.md create mode 100644 Cargo.toml create mode 100644 Cargo.toml.orig create mode 100644 LICENSE create mode 100644 METADATA create mode 100644 MODULE_LICENSE_MIT create mode 100644 OWNERS create mode 100644 README.md create mode 100644 src/cfg.rs create mode 100644 src/codec/any_delimiter_codec.rs create mode 100644 src/codec/bytes_codec.rs create mode 100644 src/codec/decoder.rs create mode 100644 src/codec/encoder.rs create mode 100644 src/codec/framed.rs create mode 100644 src/codec/framed_impl.rs create mode 100644 src/codec/framed_read.rs create mode 100644 src/codec/framed_write.rs create mode 100644 src/codec/length_delimited.rs create mode 100644 src/codec/lines_codec.rs create mode 100644 src/codec/mod.rs create mode 100644 src/compat.rs create mode 100644 src/context.rs 
create mode 100644 src/either.rs create mode 100644 src/io/copy_to_bytes.rs create mode 100644 src/io/inspect.rs create mode 100644 src/io/mod.rs create mode 100644 src/io/read_buf.rs create mode 100644 src/io/reader_stream.rs create mode 100644 src/io/sink_writer.rs create mode 100644 src/io/stream_reader.rs create mode 100644 src/io/sync_bridge.rs create mode 100644 src/lib.rs create mode 100644 src/loom.rs create mode 100644 src/net/mod.rs create mode 100644 src/net/unix/mod.rs create mode 100644 src/sync/cancellation_token.rs create mode 100644 src/sync/cancellation_token/guard.rs create mode 100644 src/sync/cancellation_token/tree_node.rs create mode 100644 src/sync/mod.rs create mode 100644 src/sync/mpsc.rs create mode 100644 src/sync/poll_semaphore.rs create mode 100644 src/sync/reusable_box.rs create mode 100644 src/sync/tests/loom_cancellation_token.rs create mode 100644 src/sync/tests/mod.rs create mode 100644 src/task/join_map.rs create mode 100644 src/task/mod.rs create mode 100644 src/task/spawn_pinned.rs create mode 100644 src/time/delay_queue.rs create mode 100644 src/time/mod.rs create mode 100644 src/time/wheel/level.rs create mode 100644 src/time/wheel/mod.rs create mode 100644 src/time/wheel/stack.rs create mode 100644 src/udp/frame.rs create mode 100644 src/udp/mod.rs create mode 100644 tests/_require_full.rs create mode 100644 tests/codecs.rs create mode 100644 tests/context.rs create mode 100644 tests/framed.rs create mode 100644 tests/framed_read.rs create mode 100644 tests/framed_stream.rs create mode 100644 tests/framed_write.rs create mode 100644 tests/io_inspect.rs create mode 100644 tests/io_reader_stream.rs create mode 100644 tests/io_sink_writer.rs create mode 100644 tests/io_stream_reader.rs create mode 100644 tests/io_sync_bridge.rs create mode 100644 tests/length_delimited.rs create mode 100644 tests/mpsc.rs create mode 100644 tests/panic.rs create mode 100644 tests/poll_semaphore.rs create mode 100644 tests/reusable_box.rs create mode 100644 tests/spawn_pinned.rs create mode 100644 tests/sync_cancellation_token.rs create mode 100644 tests/task_join_map.rs create mode 100644 tests/time_delay_queue.rs create mode 100644 tests/udp.rs diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0c11b21 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,334 @@ +# 0.7.7 (February 12, 2023) + +This release reverts the removal of the `Encoder` bound on the `FramedParts` +constructor from [#5280] since it turned out to be a breaking change. ([#5450]) + +[#5450]: https://github.com/tokio-rs/tokio/pull/5450 + +# 0.7.6 (February 10, 2023) + +This release fixes a compilation failure in 0.7.5 when it is used together with +Tokio version 1.21 and unstable features are enabled. ([#5445]) + +[#5445]: https://github.com/tokio-rs/tokio/pull/5445 + +# 0.7.5 (February 9, 2023) + +This release fixes an accidental breaking change where `UnwindSafe` was +accidentally removed from `CancellationToken`. 
+ +### Added +- codec: add `Framed::backpressure_boundary` ([#5124]) +- io: add `InspectReader` and `InspectWriter` ([#5033]) +- io: add `tokio_util::io::{CopyToBytes, SinkWriter}` ([#5070], [#5436]) +- io: impl `std::io::BufRead` on `SyncIoBridge` ([#5265]) +- sync: add `PollSemaphore::poll_acquire_many` ([#5137]) +- sync: add owned future for `CancellationToken` ([#5153]) +- time: add `DelayQueue::try_remove` ([#5052]) + +### Fixed +- codec: fix `LengthDelimitedCodec` buffer over-reservation ([#4997]) +- sync: impl `UnwindSafe` on `CancellationToken` ([#5438]) +- util: remove `Encoder` bound on `FramedParts` constructor ([#5280]) + +### Documented +- io: add lines example for `StreamReader` ([#5145]) + +[#4997]: https://github.com/tokio-rs/tokio/pull/4997 +[#5033]: https://github.com/tokio-rs/tokio/pull/5033 +[#5052]: https://github.com/tokio-rs/tokio/pull/5052 +[#5070]: https://github.com/tokio-rs/tokio/pull/5070 +[#5124]: https://github.com/tokio-rs/tokio/pull/5124 +[#5137]: https://github.com/tokio-rs/tokio/pull/5137 +[#5145]: https://github.com/tokio-rs/tokio/pull/5145 +[#5153]: https://github.com/tokio-rs/tokio/pull/5153 +[#5265]: https://github.com/tokio-rs/tokio/pull/5265 +[#5280]: https://github.com/tokio-rs/tokio/pull/5280 +[#5436]: https://github.com/tokio-rs/tokio/pull/5436 +[#5438]: https://github.com/tokio-rs/tokio/pull/5438 + +# 0.7.4 (September 8, 2022) + +### Added + +- io: add `SyncIoBridge::shutdown()` ([#4938]) +- task: improve `LocalPoolHandle` ([#4680]) + +### Fixed + +- util: add `track_caller` to public APIs ([#4785]) + +### Unstable + +- task: fix compilation errors in `JoinMap` with Tokio v1.21.0 ([#4755]) +- task: remove the unstable, deprecated `JoinMap::join_one` ([#4920]) + +[#4680]: https://github.com/tokio-rs/tokio/pull/4680 +[#4755]: https://github.com/tokio-rs/tokio/pull/4755 +[#4785]: https://github.com/tokio-rs/tokio/pull/4785 +[#4920]: https://github.com/tokio-rs/tokio/pull/4920 +[#4938]: https://github.com/tokio-rs/tokio/pull/4938 + +# 0.7.3 (June 4, 2022) + +### Changed + +- tracing: don't require default tracing features ([#4592]) +- util: simplify implementation of `ReusableBoxFuture` ([#4675]) + +### Added (unstable) + +- task: add `JoinMap` ([#4640], [#4697]) + +[#4592]: https://github.com/tokio-rs/tokio/pull/4592 +[#4640]: https://github.com/tokio-rs/tokio/pull/4640 +[#4675]: https://github.com/tokio-rs/tokio/pull/4675 +[#4697]: https://github.com/tokio-rs/tokio/pull/4697 + +# 0.7.2 (May 14, 2022) + +This release contains a rewrite of `CancellationToken` that fixes a memory leak. 
([#4652]) + +[#4652]: https://github.com/tokio-rs/tokio/pull/4652 + +# 0.7.1 (February 21, 2022) + +### Added + +- codec: add `length_field_type` to `LengthDelimitedCodec` builder ([#4508]) +- io: add `StreamReader::into_inner_with_chunk()` ([#4559]) + +### Changed + +- switch from log to tracing ([#4539]) + +### Fixed + +- sync: fix waker update condition in `CancellationToken` ([#4497]) +- bumped tokio dependency to 1.6 to satisfy minimum requirements ([#4490]) + +[#4490]: https://github.com/tokio-rs/tokio/pull/4490 +[#4497]: https://github.com/tokio-rs/tokio/pull/4497 +[#4508]: https://github.com/tokio-rs/tokio/pull/4508 +[#4539]: https://github.com/tokio-rs/tokio/pull/4539 +[#4559]: https://github.com/tokio-rs/tokio/pull/4559 + +# 0.7.0 (February 9, 2022) + +### Added + +- task: add `spawn_pinned` ([#3370]) +- time: add `shrink_to_fit` and `compact` methods to `DelayQueue` ([#4170]) +- codec: improve `Builder::max_frame_length` docs ([#4352]) +- codec: add mutable reference getters for codecs to pinned `Framed` ([#4372]) +- net: add generic trait to combine `UnixListener` and `TcpListener` ([#4385]) +- codec: implement `Framed::map_codec` ([#4427]) +- codec: implement `Encoder` for `BytesCodec` ([#4465]) + +### Changed + +- sync: add lifetime parameter to `ReusableBoxFuture` ([#3762]) +- sync: refactored `PollSender` to fix a subtly broken `Sink` implementation ([#4214]) +- time: remove error case from the infallible `DelayQueue::poll_elapsed` ([#4241]) + +[#3370]: https://github.com/tokio-rs/tokio/pull/3370 +[#4170]: https://github.com/tokio-rs/tokio/pull/4170 +[#4352]: https://github.com/tokio-rs/tokio/pull/4352 +[#4372]: https://github.com/tokio-rs/tokio/pull/4372 +[#4385]: https://github.com/tokio-rs/tokio/pull/4385 +[#4427]: https://github.com/tokio-rs/tokio/pull/4427 +[#4465]: https://github.com/tokio-rs/tokio/pull/4465 +[#3762]: https://github.com/tokio-rs/tokio/pull/3762 +[#4214]: https://github.com/tokio-rs/tokio/pull/4214 +[#4241]: https://github.com/tokio-rs/tokio/pull/4241 + +# 0.6.10 (May 14, 2021) + +This is a backport for the memory leak in `CancellationToken` that was originally fixed in 0.7.2. 
([#4652]) + +[#4652]: https://github.com/tokio-rs/tokio/pull/4652 + +# 0.6.9 (October 29, 2021) + +### Added + +- codec: implement `Clone` for `LengthDelimitedCodec` ([#4089]) +- io: add `SyncIoBridge` ([#4146]) + +### Fixed + +- time: update deadline on removal in `DelayQueue` ([#4178]) +- codec: Update stream impl for Framed to return None after Err ([#4166]) + +[#4089]: https://github.com/tokio-rs/tokio/pull/4089 +[#4146]: https://github.com/tokio-rs/tokio/pull/4146 +[#4166]: https://github.com/tokio-rs/tokio/pull/4166 +[#4178]: https://github.com/tokio-rs/tokio/pull/4178 + +# 0.6.8 (September 3, 2021) + +### Added + +- sync: add drop guard for `CancellationToken` ([#3839]) +- compact: added `AsyncSeek` compat ([#4078]) +- time: expose `Key` used in `DelayQueue`'s `Expired` ([#4081]) +- io: add `with_capacity` to `ReaderStream` ([#4086]) + +### Fixed + +- codec: remove unnecessary `doc(cfg(...))` ([#3989]) + +[#3839]: https://github.com/tokio-rs/tokio/pull/3839 +[#4078]: https://github.com/tokio-rs/tokio/pull/4078 +[#4081]: https://github.com/tokio-rs/tokio/pull/4081 +[#4086]: https://github.com/tokio-rs/tokio/pull/4086 +[#3989]: https://github.com/tokio-rs/tokio/pull/3989 + +# 0.6.7 (May 14, 2021) + +### Added + +- udp: make `UdpFramed` take `Borrow` ([#3451]) +- compat: implement `AsRawFd`/`AsRawHandle` for `Compat` ([#3765]) + +[#3451]: https://github.com/tokio-rs/tokio/pull/3451 +[#3765]: https://github.com/tokio-rs/tokio/pull/3765 + +# 0.6.6 (April 12, 2021) + +### Added + +- util: makes `Framed` and `FramedStream` resumable after eof ([#3272]) +- util: add `PollSemaphore::{add_permits, available_permits}` ([#3683]) + +### Fixed + +- chore: avoid allocation if `PollSemaphore` is unused ([#3634]) + +[#3272]: https://github.com/tokio-rs/tokio/pull/3272 +[#3634]: https://github.com/tokio-rs/tokio/pull/3634 +[#3683]: https://github.com/tokio-rs/tokio/pull/3683 + +# 0.6.5 (March 20, 2021) + +### Fixed + +- util: annotate time module as requiring `time` feature ([#3606]) + +[#3606]: https://github.com/tokio-rs/tokio/pull/3606 + +# 0.6.4 (March 9, 2021) + +### Added + +- codec: `AnyDelimiter` codec ([#3406]) +- sync: add pollable `mpsc::Sender` ([#3490]) + +### Fixed + +- codec: `LinesCodec` should only return `MaxLineLengthExceeded` once per line ([#3556]) +- sync: fuse PollSemaphore ([#3578]) + +[#3406]: https://github.com/tokio-rs/tokio/pull/3406 +[#3490]: https://github.com/tokio-rs/tokio/pull/3490 +[#3556]: https://github.com/tokio-rs/tokio/pull/3556 +[#3578]: https://github.com/tokio-rs/tokio/pull/3578 + +# 0.6.3 (January 31, 2021) + +### Added + +- sync: add `ReusableBoxFuture` utility ([#3464]) + +### Changed + +- sync: use `ReusableBoxFuture` for `PollSemaphore` ([#3463]) +- deps: remove `async-stream` dependency ([#3463]) +- deps: remove `tokio-stream` dependency ([#3487]) + +# 0.6.2 (January 21, 2021) + +### Added + +- sync: add pollable `Semaphore` ([#3444]) + +### Fixed + +- time: fix panics on updating `DelayQueue` entries ([#3270]) + +# 0.6.1 (January 12, 2021) + +### Added + +- codec: `get_ref()`, `get_mut()`, `get_pin_mut()` and `into_inner()` for + `Framed`, `FramedRead`, `FramedWrite` and `StreamReader` ([#3364]). +- codec: `write_buffer()` and `write_buffer_mut()` for `Framed` and + `FramedWrite` ([#3387]). + +# 0.6.0 (December 23, 2020) + +### Changed +- depend on `tokio` 1.0. + +### Added +- rt: add constructors to `TokioContext` (#3221). + +# 0.5.1 (December 3, 2020) + +### Added +- io: `poll_read_buf` util fn (#2972). 
+- io: `poll_write_buf` util fn with vectored write support (#3156). + +# 0.5.0 (October 30, 2020) + +### Changed +- io: update `bytes` to 0.6 (#3071). + +# 0.4.0 (October 15, 2020) + +### Added +- sync: `CancellationToken` for coordinating task cancellation (#2747). +- rt: `TokioContext` sets the Tokio runtime for the duration of a future (#2791) +- io: `StreamReader`/`ReaderStream` map between `AsyncRead` values and `Stream` + of bytes (#2788). +- time: `DelayQueue` to manage many delays (#2897). + +# 0.3.1 (March 18, 2020) + +### Fixed + +- Adjust minimum-supported Tokio version to v0.2.5 to account for an internal + dependency on features in that version of Tokio. ([#2326]) + +# 0.3.0 (March 4, 2020) + +### Changed + +- **Breaking Change**: Change `Encoder` trait to take a generic `Item` parameter, which allows + codec writers to pass references into `Framed` and `FramedWrite` types. ([#1746]) + +### Added + +- Add futures-io/tokio::io compatibility layer. ([#2117]) +- Add `Framed::with_capacity`. ([#2215]) + +### Fixed + +- Use advance over split_to when data is not needed. ([#2198]) + +# 0.2.0 (November 26, 2019) + +- Initial release + +[#3487]: https://github.com/tokio-rs/tokio/pull/3487 +[#3464]: https://github.com/tokio-rs/tokio/pull/3464 +[#3463]: https://github.com/tokio-rs/tokio/pull/3463 +[#3444]: https://github.com/tokio-rs/tokio/pull/3444 +[#3387]: https://github.com/tokio-rs/tokio/pull/3387 +[#3364]: https://github.com/tokio-rs/tokio/pull/3364 +[#3270]: https://github.com/tokio-rs/tokio/pull/3270 +[#2326]: https://github.com/tokio-rs/tokio/pull/2326 +[#2215]: https://github.com/tokio-rs/tokio/pull/2215 +[#2198]: https://github.com/tokio-rs/tokio/pull/2198 +[#2117]: https://github.com/tokio-rs/tokio/pull/2117 +[#1746]: https://github.com/tokio-rs/tokio/pull/1746 diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..221f2e7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,131 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2018" +rust-version = "1.49" +name = "tokio-util" +version = "0.7.7" +authors = ["Tokio Contributors "] +description = """ +Additional utilities for working with Tokio. 
+""" +homepage = "https://tokio.rs" +readme = "README.md" +categories = ["asynchronous"] +license = "MIT" +repository = "https://github.com/tokio-rs/tokio" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = [ + "--cfg", + "docsrs", + "--cfg", + "tokio_unstable", +] +rustc-args = [ + "--cfg", + "docsrs", + "--cfg", + "tokio_unstable", +] + +[dependencies.bytes] +version = "1.0.0" + +[dependencies.futures-core] +version = "0.3.0" + +[dependencies.futures-io] +version = "0.3.0" +optional = true + +[dependencies.futures-sink] +version = "0.3.0" + +[dependencies.futures-util] +version = "0.3.0" +optional = true + +[dependencies.pin-project-lite] +version = "0.2.0" + +[dependencies.slab] +version = "0.4.4" +optional = true + +[dependencies.tokio] +version = "1.22.0" +features = ["sync"] + +[dependencies.tracing] +version = "0.1.25" +features = ["std"] +optional = true +default-features = false + +[dev-dependencies.async-stream] +version = "0.3.0" + +[dev-dependencies.futures] +version = "0.3.0" + +[dev-dependencies.futures-test] +version = "0.3.5" + +[dev-dependencies.parking_lot] +version = "0.12.0" + +[dev-dependencies.tokio] +version = "1.0.0" +features = ["full"] + +[dev-dependencies.tokio-stream] +version = "0.1" + +[dev-dependencies.tokio-test] +version = "0.4.0" + +[features] +__docs_rs = ["futures-util"] +codec = ["tracing"] +compat = ["futures-io"] +default = [] +full = [ + "codec", + "compat", + "io-util", + "time", + "net", + "rt", +] +io = [] +io-util = [ + "io", + "tokio/rt", + "tokio/io-util", +] +net = ["tokio/net"] +rt = [ + "tokio/rt", + "tokio/sync", + "futures-util", + "hashbrown", +] +time = [ + "tokio/time", + "slab", +] + +[target."cfg(tokio_unstable)".dependencies.hashbrown] +version = "0.12.0" +optional = true diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..267662b --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,66 @@ +[package] +name = "tokio-util" +# When releasing to crates.io: +# - Remove path dependencies +# - Update CHANGELOG.md. +# - Create "tokio-util-0.7.x" git tag. +version = "0.7.7" +edition = "2018" +rust-version = "1.49" +authors = ["Tokio Contributors "] +license = "MIT" +repository = "https://github.com/tokio-rs/tokio" +homepage = "https://tokio.rs" +description = """ +Additional utilities for working with Tokio. 
+""" +categories = ["asynchronous"] + +[features] +# No features on by default +default = [] + +# Shorthand for enabling everything +full = ["codec", "compat", "io-util", "time", "net", "rt"] + +net = ["tokio/net"] +compat = ["futures-io",] +codec = ["tracing"] +time = ["tokio/time","slab"] +io = [] +io-util = ["io", "tokio/rt", "tokio/io-util"] +rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"] + +__docs_rs = ["futures-util"] + +[dependencies] +tokio = { version = "1.22.0", path = "../tokio", features = ["sync"] } +bytes = "1.0.0" +futures-core = "0.3.0" +futures-sink = "0.3.0" +futures-io = { version = "0.3.0", optional = true } +futures-util = { version = "0.3.0", optional = true } +pin-project-lite = "0.2.0" +slab = { version = "0.4.4", optional = true } # Backs `DelayQueue` +tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true } + +[target.'cfg(tokio_unstable)'.dependencies] +hashbrown = { version = "0.12.0", optional = true } + +[dev-dependencies] +tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } +tokio-test = { version = "0.4.0", path = "../tokio-test" } +tokio-stream = { version = "0.1", path = "../tokio-stream" } + +async-stream = "0.3.0" +futures = "0.3.0" +futures-test = "0.3.5" +parking_lot = "0.12.0" + +[package.metadata.docs.rs] +all-features = true +# enable unstable features in the documentation +rustdoc-args = ["--cfg", "docsrs", "--cfg", "tokio_unstable"] +# it's necessary to _also_ pass `--cfg tokio_unstable` to rustc, or else +# dependencies will not be enabled, and the docs build will fail. +rustc-args = ["--cfg", "docsrs", "--cfg", "tokio_unstable"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8bdf6bd --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2023 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/METADATA b/METADATA new file mode 100644 index 0000000..28dbb73 --- /dev/null +++ b/METADATA @@ -0,0 +1,19 @@ +name: "tokio-util" +description: "Utilities for working with Tokio." 
+third_party { + url { + type: HOMEPAGE + value: "https://crates.io/crates/tokio-util" + } + url { + type: ARCHIVE + value: "https://static.crates.io/crates/tokio-util/tokio-util-0.7.7.crate" + } + version: "0.7.7" + license_type: NOTICE + last_upgrade_date { + year: 2023 + month: 3 + day: 3 + } +} diff --git a/MODULE_LICENSE_MIT b/MODULE_LICENSE_MIT new file mode 100644 index 0000000..e69de29 diff --git a/OWNERS b/OWNERS new file mode 100644 index 0000000..45dc4dd --- /dev/null +++ b/OWNERS @@ -0,0 +1 @@ +include platform/prebuilts/rust:master:/OWNERS diff --git a/README.md b/README.md new file mode 100644 index 0000000..0d74f36 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +# tokio-util + +Utilities for working with Tokio. + +## License + +This project is licensed under the [MIT license](LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in Tokio by you, shall be licensed as MIT, without any additional +terms or conditions. diff --git a/src/cfg.rs b/src/cfg.rs new file mode 100644 index 0000000..4035255 --- /dev/null +++ b/src/cfg.rs @@ -0,0 +1,71 @@ +macro_rules! cfg_codec { + ($($item:item)*) => { + $( + #[cfg(feature = "codec")] + #[cfg_attr(docsrs, doc(cfg(feature = "codec")))] + $item + )* + } +} + +macro_rules! cfg_compat { + ($($item:item)*) => { + $( + #[cfg(feature = "compat")] + #[cfg_attr(docsrs, doc(cfg(feature = "compat")))] + $item + )* + } +} + +macro_rules! cfg_net { + ($($item:item)*) => { + $( + #[cfg(all(feature = "net", feature = "codec"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))] + $item + )* + } +} + +macro_rules! cfg_io { + ($($item:item)*) => { + $( + #[cfg(feature = "io")] + #[cfg_attr(docsrs, doc(cfg(feature = "io")))] + $item + )* + } +} + +cfg_io! { + macro_rules! cfg_io_util { + ($($item:item)*) => { + $( + #[cfg(feature = "io-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + $item + )* + } + } +} + +macro_rules! cfg_rt { + ($($item:item)*) => { + $( + #[cfg(feature = "rt")] + #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] + $item + )* + } +} + +macro_rules! cfg_time { + ($($item:item)*) => { + $( + #[cfg(feature = "time")] + #[cfg_attr(docsrs, doc(cfg(feature = "time")))] + $item + )* + } +} diff --git a/src/codec/any_delimiter_codec.rs b/src/codec/any_delimiter_codec.rs new file mode 100644 index 0000000..3dbfd45 --- /dev/null +++ b/src/codec/any_delimiter_codec.rs @@ -0,0 +1,263 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r"; +const DEFAULT_SEQUENCE_WRITER: &[u8] = b","; +/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +/// +/// # Example +/// Decode string of bytes containing various different delimiters. 
+/// +/// [`BytesMut`]: bytes::BytesMut +/// [`Error`]: std::io::Error +/// +/// ``` +/// use tokio_util::codec::{AnyDelimiterCodec, Decoder}; +/// use bytes::{BufMut, BytesMut}; +/// +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> Result<(), std::io::Error> { +/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec()); +/// let buf = &mut BytesMut::new(); +/// buf.reserve(200); +/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); +/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!("", codec.decode(buf).unwrap().unwrap()); +/// assert_eq!(None, codec.decode(buf).unwrap()); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct AnyDelimiterCodec { + // Stored index of the next index to examine for the delimiter character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc` and the delimiter is '{}', it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde}`, the method will + // only look at `de}` before returning. + next_index: usize, + + /// The maximum length for a given chunk. If `usize::MAX`, chunks will be + /// read until a delimiter character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a chunk which was over + /// the length limit? + is_discarding: bool, + + /// The bytes that are using for search during decode + seek_delimiters: Vec, + + /// The bytes that are using for encoding + sequence_writer: Vec, +} + +impl AnyDelimiterCodec { + /// Returns a `AnyDelimiterCodec` for splitting up data into chunks. + /// + /// # Note + /// + /// The returned `AnyDelimiterCodec` will not have an upper bound on the length + /// of a buffered chunk. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length() + pub fn new(seek_delimiters: Vec, sequence_writer: Vec) -> AnyDelimiterCodec { + AnyDelimiterCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + seek_delimiters, + sequence_writer, + } + } + + /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit. + /// + /// If this is set, calls to `AnyDelimiterCodec::decode` will return a + /// [`AnyDelimiterCodecError`] when a chunk exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that chunk until a delimiter + /// character is reached, returning `None` until the delimiter over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `AnyDelimiterCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the chunk currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any delimiter characters, causing unbounded memory consumption. 
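As a hedged illustration of the discard-and-resume behavior described above (the 8-byte limit and buffer contents are invented for the example): a codec built with `new_with_max_length` reports the overflow once, skips to the next delimiter, and then decodes normally.

```rust
use bytes::BytesMut;
use tokio_util::codec::{AnyDelimiterCodec, AnyDelimiterCodecError, Decoder};

fn main() {
    // Chunks longer than 8 bytes with no intervening delimiter are rejected.
    let mut codec = AnyDelimiterCodec::new_with_max_length(b",".to_vec(), b",".to_vec(), 8);
    let mut buf = BytesMut::from(&b"0123456789,short,"[..]);

    // The oversized chunk is reported exactly once...
    assert!(matches!(
        codec.decode(&mut buf),
        Err(AnyDelimiterCodecError::MaxChunkLengthExceeded)
    ));
    // ...its remainder is discarded up to the delimiter, and decoding resumes.
    assert_eq!(codec.decode(&mut buf).unwrap().unwrap(), "short");
}
```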
+ /// + /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError + pub fn new_with_max_length( + seek_delimiters: Vec, + sequence_writer: Vec, + max_length: usize, + ) -> Self { + AnyDelimiterCodec { + max_length, + ..AnyDelimiterCodec::new(seek_delimiters, sequence_writer) + } + } + + /// Returns the maximum chunk length when decoding. + /// + /// ``` + /// use std::usize; + /// use tokio_util::codec::AnyDelimiterCodec; + /// + /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec()); + /// assert_eq!(codec.max_length(), usize::MAX); + /// ``` + /// ``` + /// use tokio_util::codec::AnyDelimiterCodec; + /// + /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256); + /// assert_eq!(codec.max_length(), 256); + /// ``` + pub fn max_length(&self) -> usize { + self.max_length + } +} + +impl Decoder for AnyDelimiterCodec { + type Item = Bytes; + type Error = AnyDelimiterCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, AnyDelimiterCodecError> { + loop { + // Determine how far into the buffer we'll search for a delimiter. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); + + let new_chunk_offset = buf[self.next_index..read_to].iter().position(|b| { + self.seek_delimiters + .iter() + .any(|delimiter| *b == *delimiter) + }); + + match (self.is_discarding, new_chunk_offset) { + (true, Some(offset)) => { + // If we found a new chunk, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a chunk normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a new chunk, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a new chunk. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a chunk! + let new_chunk_index = offset + self.next_index; + self.next_index = 0; + let mut chunk = buf.split_to(new_chunk_index + 1); + chunk.truncate(chunk.len() - 1); + let chunk = chunk.freeze(); + return Ok(Some(chunk)); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // new chunk, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded); + } + (false, None) => { + // We didn't find a chunk or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, AnyDelimiterCodecError> { + Ok(match self.decode(buf)? 
{ + Some(frame) => Some(frame), + None => { + // return remaining data, if any + if buf.is_empty() { + None + } else { + let chunk = buf.split_to(buf.len()); + self.next_index = 0; + Some(chunk.freeze()) + } + } + }) + } +} + +impl Encoder for AnyDelimiterCodec +where + T: AsRef, +{ + type Error = AnyDelimiterCodecError; + + fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> { + let chunk = chunk.as_ref(); + buf.reserve(chunk.len() + 1); + buf.put(chunk.as_bytes()); + buf.put(self.sequence_writer.as_ref()); + + Ok(()) + } +} + +impl Default for AnyDelimiterCodec { + fn default() -> Self { + Self::new( + DEFAULT_SEEK_DELIMITERS.to_vec(), + DEFAULT_SEQUENCE_WRITER.to_vec(), + ) + } +} + +/// An error occurred while encoding or decoding a chunk. +#[derive(Debug)] +pub enum AnyDelimiterCodecError { + /// The maximum chunk length was exceeded. + MaxChunkLengthExceeded, + /// An IO error occurred. + Io(io::Error), +} + +impl fmt::Display for AnyDelimiterCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AnyDelimiterCodecError::MaxChunkLengthExceeded => { + write!(f, "max chunk length exceeded") + } + AnyDelimiterCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From for AnyDelimiterCodecError { + fn from(e: io::Error) -> AnyDelimiterCodecError { + AnyDelimiterCodecError::Io(e) + } +} + +impl std::error::Error for AnyDelimiterCodecError {} diff --git a/src/codec/bytes_codec.rs b/src/codec/bytes_codec.rs new file mode 100644 index 0000000..ceab228 --- /dev/null +++ b/src/codec/bytes_codec.rs @@ -0,0 +1,86 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{BufMut, Bytes, BytesMut}; +use std::io; + +/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +/// +/// # Example +/// +/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`. +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +/// [`BytesMut`]: bytes::BytesMut +/// [`Error`]: std::io::Error +/// +/// ``` +/// # mod hidden { +/// # #[allow(unused_imports)] +/// use tokio::fs::File; +/// # } +/// use tokio::io::AsyncRead; +/// use tokio_util::codec::{FramedRead, BytesCodec}; +/// +/// # enum File {} +/// # impl File { +/// # async fn open(_name: &str) -> Result { +/// # use std::io::Cursor; +/// # Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5])) +/// # } +/// # } +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> Result<(), std::io::Error> { +/// let my_async_read = File::open("filename.txt").await?; +/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new()); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] +pub struct BytesCodec(()); + +impl BytesCodec { + /// Creates a new `BytesCodec` for shipping around raw bytes. 
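As a minimal sketch (the TCP transport is only an illustrative stand-in), wrapping a writer in `FramedWrite` with this codec passes bytes through with no framing at all:

```rust
use bytes::Bytes;
use futures::SinkExt;
use tokio::net::TcpStream;
use tokio_util::codec::{BytesCodec, FramedWrite};

async fn send_raw(stream: TcpStream) -> Result<(), std::io::Error> {
    // BytesCodec adds no length prefix or delimiter: what is sent is
    // exactly what gets written to the underlying stream.
    let mut sink = FramedWrite::new(stream, BytesCodec::new());
    sink.send(Bytes::from_static(b"raw bytes, verbatim")).await
}
```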
+ pub fn new() -> BytesCodec { + BytesCodec(()) + } +} + +impl Decoder for BytesCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { + if !buf.is_empty() { + let len = buf.len(); + Ok(Some(buf.split_to(len))) + } else { + Ok(None) + } + } +} + +impl Encoder for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} + +impl Encoder for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} diff --git a/src/codec/decoder.rs b/src/codec/decoder.rs new file mode 100644 index 0000000..84ecb22 --- /dev/null +++ b/src/codec/decoder.rs @@ -0,0 +1,184 @@ +use crate::codec::Framed; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use std::io; + +/// Decoding of frames via buffers. +/// +/// This trait is used when constructing an instance of [`Framed`] or +/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has +/// already been buffered in `src` and decodes the data into a stream of +/// `Self::Item` frames. +/// +/// Implementations are able to track state on `self`, which enables +/// implementing stateful streaming parsers. In many cases, though, this type +/// will simply be a unit struct (e.g. `struct HttpDecoder`). +/// +/// For some underlying data-sources, namely files and FIFOs, +/// it's possible to temporarily read 0 bytes by reaching EOF. +/// +/// In these cases `decode_eof` will be called until it signals +/// fulfillment of all closing frames by returning `Ok(None)`. +/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`] +/// will not invoke `decode` or `decode_eof` again, until data can be read +/// during a retry. +/// +/// It is up to the Decoder to keep track of a restart after an EOF, +/// and to decide how to handle such an event by, for example, +/// allowing frames to cross EOF boundaries, re-emitting opening frames, or +/// resetting the entire internal state. +/// +/// [`Framed`]: crate::codec::Framed +/// [`FramedRead`]: crate::codec::FramedRead +pub trait Decoder { + /// The type of decoded frames. + type Item; + + /// The type of unrecoverable frame decoding errors. + /// + /// If an individual message is ill-formed but can be ignored without + /// interfering with the processing of future messages, it may be more + /// useful to report the failure as an `Item`. + /// + /// `From` is required in the interest of making `Error` suitable + /// for returning directly from a [`FramedRead`], and to enable the default + /// implementation of `decode_eof` to yield an `io::Error` when the decoder + /// fails to consume all available data. + /// + /// Note that implementors of this trait can simply indicate `type Error = + /// io::Error` to use I/O errors as this type. + /// + /// [`FramedRead`]: crate::codec::FramedRead + type Error: From; + + /// Attempts to decode a frame from the provided buffer of bytes. + /// + /// This method is called by [`FramedRead`] whenever bytes are ready to be + /// parsed. The provided buffer of bytes is what's been read so far, and + /// this instance of `Decode` can determine whether an entire frame is in + /// the buffer and is ready to be returned. 
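A minimal sketch of those rules in practice (the `LengthPrefixed` type and its four-byte big-endian header are hypothetical, not part of this crate): return `Ok(None)` until a whole frame is buffered, then remove exactly that frame from `src`.

```rust
use bytes::{Buf, BytesMut};
use std::io;
use tokio_util::codec::Decoder;

struct LengthPrefixed;

impl Decoder for LengthPrefixed {
    type Item = BytesMut;
    type Error = io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
        if src.len() < 4 {
            return Ok(None); // Header incomplete; ask for more bytes.
        }
        let mut header = [0u8; 4];
        header.copy_from_slice(&src[..4]);
        let len = u32::from_be_bytes(header) as usize;
        if src.len() < 4 + len {
            // Reserve what the current frame still needs, per the
            // buffer-management advice in the docs below.
            src.reserve(4 + len - src.len());
            return Ok(None);
        }
        src.advance(4); // Consume the header.
        Ok(Some(src.split_to(len))) // Hand back the payload as one frame.
    }
}
```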
+ /// + /// If an entire frame is available, then this instance will remove those + /// bytes from the buffer provided and return them as a decoded + /// frame. Note that removing bytes from the provided buffer doesn't always + /// necessarily copy the bytes, so this should be an efficient operation in + /// most circumstances. + /// + /// If the bytes look valid, but a frame isn't fully available yet, then + /// `Ok(None)` is returned. This indicates to the [`Framed`] instance that + /// it needs to read some more bytes before calling this method again. + /// + /// Note that the bytes provided may be empty. If a previous call to + /// `decode` consumed all the bytes in the buffer then `decode` will be + /// called again until it returns `Ok(None)`, indicating that more bytes need to + /// be read. + /// + /// Finally, if the bytes in the buffer are malformed then an error is + /// returned indicating why. This informs [`Framed`] that the stream is now + /// corrupt and should be terminated. + /// + /// [`Framed`]: crate::codec::Framed + /// [`FramedRead`]: crate::codec::FramedRead + /// + /// # Buffer management + /// + /// Before returning from the function, implementations should ensure that + /// the buffer has appropriate capacity in anticipation of future calls to + /// `decode`. Failing to do so leads to inefficiency. + /// + /// For example, if frames have a fixed length, or if the length of the + /// current frame is known from a header, a possible buffer management + /// strategy is: + /// + /// ```no_run + /// # use std::io; + /// # + /// # use bytes::BytesMut; + /// # use tokio_util::codec::Decoder; + /// # + /// # struct MyCodec; + /// # + /// impl Decoder for MyCodec { + /// // ... + /// # type Item = BytesMut; + /// # type Error = io::Error; + /// + /// fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + /// // ... + /// + /// // Reserve enough to complete decoding of the current frame. + /// let current_frame_len: usize = 1000; // Example. + /// // And to start decoding the next frame. + /// let next_frame_header_len: usize = 10; // Example. + /// src.reserve(current_frame_len + next_frame_header_len); + /// + /// return Ok(None); + /// } + /// } + /// ``` + /// + /// An optimal buffer management strategy minimizes reallocations and + /// over-allocations. + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error>; + + /// A default method available to be called when there are no more bytes + /// available to be read from the underlying I/O. + /// + /// This method defaults to calling `decode` and returns an error if + /// `Ok(None)` is returned while there is unconsumed data in `buf`. + /// Typically this doesn't need to be implemented unless the framing + /// protocol differs near the end of the stream, or if you need to construct + /// frames _across_ eof boundaries on sources that can be resumed. + /// + /// Note that the `buf` argument may be empty. If a previous call to + /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be + /// called again until it returns `None`, indicating that there are no more + /// frames to yield. This behavior enables returning finalization frames + /// that may not be based on inbound data. + /// + /// Once `None` has been returned, `decode_eof` won't be called again until + /// an attempt to resume the stream has been made, where the underlying stream + /// actually returned more data. + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { + match self.decode(buf)? 
{ + Some(frame) => Ok(Some(frame)), + None => { + if buf.is_empty() { + Ok(None) + } else { + Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) + } + } + } + } + + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the [`Framed`] returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Framed`]: crate::codec::Framed + fn framed(self, io: T) -> Framed + where + Self: Sized, + { + Framed::new(io, self) + } +} diff --git a/src/codec/encoder.rs b/src/codec/encoder.rs new file mode 100644 index 0000000..770a10f --- /dev/null +++ b/src/codec/encoder.rs @@ -0,0 +1,25 @@ +use bytes::BytesMut; +use std::io; + +/// Trait of helper objects to write out messages as bytes, for use with +/// [`FramedWrite`]. +/// +/// [`FramedWrite`]: crate::codec::FramedWrite +pub trait Encoder { + /// The type of encoding errors. + /// + /// [`FramedWrite`] requires `Encoder`s errors to implement `From` + /// in the interest letting it return `Error`s directly. + /// + /// [`FramedWrite`]: crate::codec::FramedWrite + type Error: From; + + /// Encodes a frame into the buffer provided. + /// + /// This method will encode `item` into the byte buffer provided by `dst`. + /// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and + /// will be written out when possible. + /// + /// [`FramedWrite`]: crate::codec::FramedWrite + fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>; +} diff --git a/src/codec/framed.rs b/src/codec/framed.rs new file mode 100644 index 0000000..8a344f9 --- /dev/null +++ b/src/codec/framed.rs @@ -0,0 +1,383 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; +use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + +use futures_core::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::fmt; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using + /// the `Encoder` and `Decoder` traits to encode and decode frames. + /// + /// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or + /// by using the `new` function seen below. 
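As a short sketch of that combined interface (`echo_once` and the TCP transport are illustrative only), the same value is driven as a `Stream` to read a frame and as a `Sink` to write one:

```rust
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LinesCodec};

async fn echo_once(stream: TcpStream) -> Result<(), Box<dyn std::error::Error>> {
    // One value that decodes inbound lines and encodes outbound ones.
    let mut framed = Framed::new(stream, LinesCodec::new());
    if let Some(line) = framed.next().await {
        framed.send(line?).await?; // Echo the decoded frame back.
    }
    Ok(())
}
```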
+ /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Decoder::framed`]: crate::codec::Decoder::framed() + pub struct Framed { + #[pin] + inner: FramedImpl + } +} + +impl Framed +where + T: AsyncRead + AsyncWrite, +{ + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the codec + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// Note that, for some byte sources, the stream can be resumed after an EOF + /// by reading from it, even after it has returned `None`. Repeated attempts + /// to do so, without new data available, continue to return `None` without + /// creating more (closing) frames. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decode`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn new(inner: T, codec: U) -> Framed { + Framed { + inner: FramedImpl { + inner, + codec, + state: Default::default(), + }, + } + } + + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data, + /// with a specific read buffer initial capacity. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the codec + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. 
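A sketch of that pattern (the 64 KiB capacity and message contents are arbitrary): build with `with_capacity`, then `split` the combined object into halves that can live on separate tasks.

```rust
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LengthDelimitedCodec};

async fn run(stream: TcpStream) -> Result<(), std::io::Error> {
    // Start with a 64 KiB read buffer instead of the default.
    let framed = Framed::with_capacity(stream, LengthDelimitedCodec::new(), 64 * 1024);
    // Split the combined Stream + Sink into independent halves.
    let (mut writer, mut reader) = framed.split();
    writer.send(Bytes::from_static(b"hello")).await?;
    while let Some(frame) = reader.next().await {
        println!("received {} bytes", frame?.len());
    }
    Ok(())
}
```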
+ /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decode`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed { + Framed { + inner: FramedImpl { + inner, + codec, + state: RWFrames { + read: ReadFrame { + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(capacity), + has_errored: false, + }, + write: WriteFrame::default(), + }, + }, + } + } +} + +impl Framed { + /// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this + /// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both [`Stream`] and + /// [`Sink`]; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// This objects takes a stream and a readbuffer and a writebuffer. These field + /// can be obtained from an existing `Framed` with the [`into_parts`] method. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling [`split`] on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`Decoder`]: crate::codec::Decoder + /// [`Encoder`]: crate::codec::Encoder + /// [`into_parts`]: crate::codec::Framed::into_parts() + /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split + pub fn from_parts(parts: FramedParts) -> Framed { + Framed { + inner: FramedImpl { + inner: parts.io, + codec: parts.codec, + state: RWFrames { + read: parts.read_buf.into(), + write: parts.write_buf.into(), + }, + }, + } + } + + /// Returns a reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner + } + + /// Returns a pinned mutable reference to the underlying I/O stream wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner.project().inner + } + + /// Returns a reference to the underlying codec wrapped by + /// `Framed`. 
+ /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &U { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut U { + &mut self.inner.codec + } + + /// Maps the codec `U` to `C`, preserving the read and write buffers + /// wrapped by `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn map_codec(self, map: F) -> Framed + where + F: FnOnce(U) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let parts = self.into_parts(); + Framed::from_parts(FramedParts { + io: parts.io, + codec: map(parts.codec), + read_buf: parts.read_buf, + write_buf: parts.write_buf, + _priv: (), + }) + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U { + self.project().inner.project().codec + } + + /// Returns a reference to the read buffer. + pub fn read_buffer(&self) -> &BytesMut { + &self.inner.state.read.buffer + } + + /// Returns a mutable reference to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.read.buffer + } + + /// Returns a reference to the write buffer. + pub fn write_buffer(&self) -> &BytesMut { + &self.inner.state.write.buffer + } + + /// Returns a mutable reference to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.write.buffer + } + + /// Returns backpressure boundary + pub fn backpressure_boundary(&self) -> usize { + self.inner.state.write.backpressure_boundary + } + + /// Updates backpressure boundary + pub fn set_backpressure_boundary(&mut self, boundary: usize) { + self.inner.state.write.backpressure_boundary = boundary; + } + + /// Consumes the `Framed`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Consumes the `Framed`, returning its underlying I/O stream, the buffer + /// with unprocessed data, and the codec. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. 
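For example (a sketch; the 16 MiB limit is arbitrary), `map_codec` swaps in a reconfigured codec while both buffers, including any partially decoded input, survive:

```rust
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LengthDelimitedCodec};

fn raise_frame_limit(
    framed: Framed<TcpStream, LengthDelimitedCodec>,
) -> Framed<TcpStream, LengthDelimitedCodec> {
    // map_codec is an into_parts/from_parts round trip under the hood,
    // so the read and write buffers are carried over unchanged.
    framed.map_codec(|_old| {
        LengthDelimitedCodec::builder()
            .max_frame_length(16 * 1024 * 1024)
            .new_codec()
    })
}
```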
+ pub fn into_parts(self) -> FramedParts { + FramedParts { + io: self.inner.inner, + codec: self.inner.codec, + read_buf: self.inner.state.read.buffer, + write_buf: self.inner.state.write.buffer, + _priv: (), + } + } +} + +// This impl just defers to the underlying FramedImpl +impl Stream for Framed +where + T: AsyncRead, + U: Decoder, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying FramedImpl +impl Sink for Framed +where + T: AsyncWrite, + U: Encoder, + U::Error: From, +{ + type Error = U::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + +impl fmt::Debug for Framed +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Framed") + .field("io", self.get_ref()) + .field("codec", self.codec()) + .finish() + } +} + +/// `FramedParts` contains an export of the data of a Framed transport. +/// It can be used to construct a new [`Framed`] with a different codec. +/// It contains all current buffers and the inner transport. +/// +/// [`Framed`]: crate::codec::Framed +#[derive(Debug)] +#[allow(clippy::manual_non_exhaustive)] +pub struct FramedParts { + /// The inner transport used to read bytes to and write bytes to + pub io: T, + + /// The codec + pub codec: U, + + /// The buffer with read but unprocessed data. + pub read_buf: BytesMut, + + /// A buffer with unprocessed data which are not written yet. + pub write_buf: BytesMut, + + /// This private field allows us to add additional fields in the future in a + /// backwards compatible way. + _priv: (), +} + +impl FramedParts { + /// Create a new, default, `FramedParts` + pub fn new(io: T, codec: U) -> FramedParts + where + U: Encoder, + { + FramedParts { + io, + codec, + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), + _priv: (), + } + } +} diff --git a/src/codec/framed_impl.rs b/src/codec/framed_impl.rs new file mode 100644 index 0000000..8f3fa49 --- /dev/null +++ b/src/codec/framed_impl.rs @@ -0,0 +1,312 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use futures_core::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use futures_core::ready; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::borrow::{Borrow, BorrowMut}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tracing::trace; + +pin_project! 
{
+    #[derive(Debug)]
+    pub(crate) struct FramedImpl<T, U, State> {
+        #[pin]
+        pub(crate) inner: T,
+        pub(crate) state: State,
+        pub(crate) codec: U,
+    }
+}
+
+const INITIAL_CAPACITY: usize = 8 * 1024;
+
+#[derive(Debug)]
+pub(crate) struct ReadFrame {
+    pub(crate) eof: bool,
+    pub(crate) is_readable: bool,
+    pub(crate) buffer: BytesMut,
+    pub(crate) has_errored: bool,
+}
+
+pub(crate) struct WriteFrame {
+    pub(crate) buffer: BytesMut,
+    pub(crate) backpressure_boundary: usize,
+}
+
+#[derive(Default)]
+pub(crate) struct RWFrames {
+    pub(crate) read: ReadFrame,
+    pub(crate) write: WriteFrame,
+}
+
+impl Default for ReadFrame {
+    fn default() -> Self {
+        Self {
+            eof: false,
+            is_readable: false,
+            buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+            has_errored: false,
+        }
+    }
+}
+
+impl Default for WriteFrame {
+    fn default() -> Self {
+        Self {
+            buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+            backpressure_boundary: INITIAL_CAPACITY,
+        }
+    }
+}
+
+impl From<BytesMut> for ReadFrame {
+    fn from(mut buffer: BytesMut) -> Self {
+        let size = buffer.capacity();
+        if size < INITIAL_CAPACITY {
+            buffer.reserve(INITIAL_CAPACITY - size);
+        }
+
+        Self {
+            buffer,
+            is_readable: size > 0,
+            eof: false,
+            has_errored: false,
+        }
+    }
+}
+
+impl From<BytesMut> for WriteFrame {
+    fn from(mut buffer: BytesMut) -> Self {
+        let size = buffer.capacity();
+        if size < INITIAL_CAPACITY {
+            buffer.reserve(INITIAL_CAPACITY - size);
+        }
+
+        Self {
+            buffer,
+            backpressure_boundary: INITIAL_CAPACITY,
+        }
+    }
+}
+
+impl Borrow<ReadFrame> for RWFrames {
+    fn borrow(&self) -> &ReadFrame {
+        &self.read
+    }
+}
+impl BorrowMut<ReadFrame> for RWFrames {
+    fn borrow_mut(&mut self) -> &mut ReadFrame {
+        &mut self.read
+    }
+}
+impl Borrow<WriteFrame> for RWFrames {
+    fn borrow(&self) -> &WriteFrame {
+        &self.write
+    }
+}
+impl BorrowMut<WriteFrame> for RWFrames {
+    fn borrow_mut(&mut self) -> &mut WriteFrame {
+        &mut self.write
+    }
+}
+impl<T, U, R> Stream for FramedImpl<T, U, R>
+where
+    T: AsyncRead,
+    U: Decoder,
+    R: BorrowMut<ReadFrame>,
+{
+    type Item = Result<U::Item, U::Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        use crate::util::poll_read_buf;
+
+        let mut pinned = self.project();
+        let state: &mut ReadFrame = pinned.state.borrow_mut();
+        // The following loop implements a state machine with each state corresponding
+        // to a combination of the `is_readable` and `eof` flags. States persist across
+        // loop entries and most state transitions occur with a return.
+        //
+        // The initial state is `reading`.
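+        // The table below lists the flag combinations that identify each
+        // state, and the diagram after it shows the transitions between
+        // them.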
+ // + // | state | eof | is_readable | has_errored | + // |---------|-------|-------------|-------------| + // | reading | false | false | false | + // | framing | false | true | false | + // | pausing | true | true | false | + // | paused | true | false | false | + // | errored | | | true | + // `decode_eof` returns Err + // ┌────────────────────────────────────────────────────────┐ + // `decode_eof` returns │ │ + // `Ok(Some)` │ │ + // ┌─────┐ │ `decode_eof` returns After returning │ + // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ + // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ + // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ + // Pending read │ │ │ │ │ │ + // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ + // │ │ │ ┌──────┐ │ Pending │ │ + // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ + // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ + // └──┬─▲────┘ └─────┬──┬┘ │ │ + // │ │ │ │ `decode` returns Err │ │ + // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ + // │ read returns Err │ + // └────────────────────────────────────────────────────────────────────────────────────────────┘ + loop { + // Return `None` if we have encountered an error from the underlying decoder + // See: https://github.com/tokio-rs/tokio/issues/3976 + if state.has_errored { + // preparing has_errored -> paused + trace!("Returning None and setting paused"); + state.is_readable = false; + state.has_errored = false; + return Poll::Ready(None); + } + + // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", + // i.e. it _might_ contain data consumable as a frame or closing frame. + // Both signal that there is no such data by returning `None`. + // + // If `decode` couldn't read a frame and the upstream source has returned eof, + // `decode_eof` will attempt to decode the remaining bytes as closing frames. + // + // If the underlying AsyncRead is resumable, we may continue after an EOF, + // but must finish emitting all of it's associated `decode_eof` frames. + // Furthermore, we don't want to emit any `decode_eof` frames on retried + // reads after an EOF unless we've actually read more data. + if state.is_readable { + // pausing or framing + if state.eof { + // pausing + let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + err + })?; + if frame.is_none() { + state.is_readable = false; // prepare pausing -> paused + } + // implicit pausing -> pausing or pausing -> paused + return Poll::Ready(frame.map(Ok)); + } + + // framing + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { + trace!("Got an error, going to errored state"); + state.has_errored = true; + op + })? { + trace!("frame decoded from buffer"); + // implicit framing -> framing + return Poll::Ready(Some(Ok(frame))); + } + + // framing -> reading + state.is_readable = false; + } + // reading or paused + // If we can't build a frame yet, try to read more data and try again. + // Make sure we've got room for at least one byte to read to ensure + // that we don't get a spurious 0 that looks like EOF. 
+            state.buffer.reserve(1);
+            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
+                |err| {
+                    trace!("Got an error, going to errored state");
+                    state.has_errored = true;
+                    err
+                },
+            )? {
+                Poll::Ready(ct) => ct,
+                // implicit reading -> reading or implicit paused -> paused
+                Poll::Pending => return Poll::Pending,
+            };
+            if bytect == 0 {
+                if state.eof {
+                    // We're already at an EOF, and since we've reached this path
+                    // we're also not readable. This implies that we've already finished
+                    // our `decode_eof` handling, so we can simply return `None`.
+                    // implicit paused -> paused
+                    return Poll::Ready(None);
+                }
+                // prepare reading -> paused
+                state.eof = true;
+            } else {
+                // prepare paused -> framing or noop reading -> framing
+                state.eof = false;
+            }
+
+            // paused -> framing or reading -> framing or reading -> pausing
+            state.is_readable = true;
+        }
+    }
+}
+
+impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
+where
+    T: AsyncWrite,
+    U: Encoder<I>,
+    U::Error: From<io::Error>,
+    W: BorrowMut<WriteFrame>,
+{
+    type Error = U::Error;
+
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
+            self.as_mut().poll_flush(cx)
+        } else {
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+        let pinned = self.project();
+        pinned
+            .codec
+            .encode(item, &mut pinned.state.borrow_mut().buffer)?;
+        Ok(())
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        use crate::util::poll_write_buf;
+        trace!("flushing framed transport");
+        let mut pinned = self.project();
+
+        while !pinned.state.borrow_mut().buffer.is_empty() {
+            let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
+            trace!(remaining = buffer.len(), "writing;");
+
+            let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
+
+            if n == 0 {
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::WriteZero,
+                    "failed to \
+                     write frame to transport",
+                )
+                .into()));
+            }
+        }
+
+        // Try flushing the underlying IO
+        ready!(pinned.inner.poll_flush(cx))?;
+
+        trace!("framed transport flushed");
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        ready!(self.as_mut().poll_flush(cx))?;
+        ready!(self.project().inner.poll_shutdown(cx))?;
+
+        Poll::Ready(Ok(()))
+    }
+}
diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs
new file mode 100644
index 0000000..184c567
--- /dev/null
+++ b/src/codec/framed_read.rs
@@ -0,0 +1,199 @@
+use crate::codec::framed_impl::{FramedImpl, ReadFrame};
+use crate::codec::Decoder;
+
+use futures_core::Stream;
+use tokio::io::AsyncRead;
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+    /// A [`Stream`] of messages decoded from an [`AsyncRead`].
+    ///
+    /// [`Stream`]: futures_core::Stream
+    /// [`AsyncRead`]: tokio::io::AsyncRead
+    pub struct FramedRead<T, D> {
+        #[pin]
+        inner: FramedImpl<T, D, ReadFrame>,
+    }
+}
+
+// ===== impl FramedRead =====
+
+impl<T, D> FramedRead<T, D>
+where
+    T: AsyncRead,
+    D: Decoder,
+{
+    /// Creates a new `FramedRead` with the given `decoder`.
+    pub fn new(inner: T, decoder: D) -> FramedRead<T, D> {
+        FramedRead {
+            inner: FramedImpl {
+                inner,
+                codec: decoder,
+                state: Default::default(),
+            },
+        }
+    }
+
+    /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity`
+    /// initial size.
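+    ///
+    /// A construction sketch (the 4 KiB initial buffer is an arbitrary
+    /// choice here):
+    ///
+    /// ```
+    /// use tokio_util::codec::{FramedRead, LinesCodec};
+    ///
+    /// # fn bind_read<T: tokio::io::AsyncRead>(io: T) {
+    /// let framed = FramedRead::with_capacity(io, LinesCodec::new(), 4096);
+    /// # let _ = framed;
+    /// # }
+    /// # pub fn main() {}
+    /// ```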
+    pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> {
+        FramedRead {
+            inner: FramedImpl {
+                inner,
+                codec: decoder,
+                state: ReadFrame {
+                    eof: false,
+                    is_readable: false,
+                    buffer: BytesMut::with_capacity(capacity),
+                    has_errored: false,
+                },
+            },
+        }
+    }
+}
+
+impl<T, D> FramedRead<T, D> {
+    /// Returns a reference to the underlying I/O stream wrapped by
+    /// `FramedRead`.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn get_ref(&self) -> &T {
+        &self.inner.inner
+    }
+
+    /// Returns a mutable reference to the underlying I/O stream wrapped by
+    /// `FramedRead`.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn get_mut(&mut self) -> &mut T {
+        &mut self.inner.inner
+    }
+
+    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+    /// `FramedRead`.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+        self.project().inner.project().inner
+    }
+
+    /// Consumes the `FramedRead`, returning its underlying I/O stream.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn into_inner(self) -> T {
+        self.inner.inner
+    }
+
+    /// Returns a reference to the underlying decoder.
+    pub fn decoder(&self) -> &D {
+        &self.inner.codec
+    }
+
+    /// Returns a mutable reference to the underlying decoder.
+    pub fn decoder_mut(&mut self) -> &mut D {
+        &mut self.inner.codec
+    }
+
+    /// Maps the decoder `D` to `C`, preserving the read buffer
+    /// wrapped by `Framed`.
+    pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
+    where
+        F: FnOnce(D) -> C,
+    {
+        // This could be potentially simplified once rust-lang/rust#86555 hits stable
+        let FramedImpl {
+            inner,
+            state,
+            codec,
+        } = self.inner;
+        FramedRead {
+            inner: FramedImpl {
+                inner,
+                state,
+                codec: map(codec),
+            },
+        }
+    }
+
+    /// Returns a mutable reference to the underlying decoder.
+    pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
+        self.project().inner.project().codec
+    }
+
+    /// Returns a reference to the read buffer.
+    pub fn read_buffer(&self) -> &BytesMut {
+        &self.inner.state.buffer
+    }
+
+    /// Returns a mutable reference to the read buffer.
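+    ///
+    /// A sketch of inspecting buffered-but-undecoded bytes (assuming some
+    /// existing `FramedRead` named `framed`):
+    ///
+    /// ```
+    /// use tokio_util::codec::{FramedRead, LinesCodec};
+    ///
+    /// # fn doc<T>(framed: &mut FramedRead<T, LinesCodec>) {
+    /// // Bytes already read from the I/O stream but not yet decoded can be
+    /// // examined (or even modified) in place.
+    /// let buffered = framed.read_buffer_mut().len();
+    /// # let _ = buffered;
+    /// # }
+    /// # pub fn main() {}
+    /// ```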
+    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+        &mut self.inner.state.buffer
+    }
+}
+
+// This impl just defers to the underlying FramedImpl
+impl<T, D> Stream for FramedRead<T, D>
+where
+    T: AsyncRead,
+    D: Decoder,
+{
+    type Item = Result<D::Item, D::Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.project().inner.poll_next(cx)
+    }
+}
+
+// This impl just defers to the underlying T: Sink
+impl<T, I, D> Sink<I> for FramedRead<T, D>
+where
+    T: Sink<I>,
+{
+    type Error = T::Error;
+
+    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.project().inner.project().inner.poll_ready(cx)
+    }
+
+    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+        self.project().inner.project().inner.start_send(item)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.project().inner.project().inner.poll_flush(cx)
+    }
+
+    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.project().inner.project().inner.poll_close(cx)
+    }
+}
+
+impl<T, D> fmt::Debug for FramedRead<T, D>
+where
+    T: fmt::Debug,
+    D: fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("FramedRead")
+            .field("inner", &self.get_ref())
+            .field("decoder", &self.decoder())
+            .field("eof", &self.inner.state.eof)
+            .field("is_readable", &self.inner.state.is_readable)
+            .field("buffer", &self.read_buffer())
+            .finish()
+    }
+}
diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs
new file mode 100644
index 0000000..3f0a340
--- /dev/null
+++ b/src/codec/framed_write.rs
@@ -0,0 +1,188 @@
+use crate::codec::encoder::Encoder;
+use crate::codec::framed_impl::{FramedImpl, WriteFrame};
+
+use futures_core::Stream;
+use tokio::io::AsyncWrite;
+
+use bytes::BytesMut;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::fmt;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+    /// A [`Sink`] of frames encoded to an `AsyncWrite`.
+    ///
+    /// [`Sink`]: futures_sink::Sink
+    pub struct FramedWrite<T, E> {
+        #[pin]
+        inner: FramedImpl<T, E, WriteFrame>,
+    }
+}
+
+impl<T, E> FramedWrite<T, E>
+where
+    T: AsyncWrite,
+{
+    /// Creates a new `FramedWrite` with the given `encoder`.
+    pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
+        FramedWrite {
+            inner: FramedImpl {
+                inner,
+                codec: encoder,
+                state: WriteFrame::default(),
+            },
+        }
+    }
+}
+
+impl<T, E> FramedWrite<T, E> {
+    /// Returns a reference to the underlying I/O stream wrapped by
+    /// `FramedWrite`.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn get_ref(&self) -> &T {
+        &self.inner.inner
+    }
+
+    /// Returns a mutable reference to the underlying I/O stream wrapped by
+    /// `FramedWrite`.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn get_mut(&mut self) -> &mut T {
+        &mut self.inner.inner
+    }
+
+    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
+    /// `FramedWrite`.
+    ///
+    /// Note that care should be taken to not tamper with the underlying stream
+    /// of data coming in as it may corrupt the stream of frames otherwise
+    /// being worked with.
+    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
+        self.project().inner.project().inner
+    }
+
+    /// Consumes the `FramedWrite`, returning its underlying I/O stream.
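+    ///
+    /// A teardown sketch (assuming the sink was built over `LinesCodec`):
+    ///
+    /// ```
+    /// use tokio_util::codec::{FramedWrite, LinesCodec};
+    ///
+    /// # fn doc<T: tokio::io::AsyncWrite>(io: T) {
+    /// let framed = FramedWrite::new(io, LinesCodec::new());
+    /// // ... after the sink has been flushed and closed ...
+    /// let io = framed.into_inner();
+    /// # let _ = io;
+    /// # }
+    /// # pub fn main() {}
+    /// ```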
+ /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner + } + + /// Returns a reference to the underlying encoder. + pub fn encoder(&self) -> &E { + &self.inner.codec + } + + /// Returns a mutable reference to the underlying encoder. + pub fn encoder_mut(&mut self) -> &mut E { + &mut self.inner.codec + } + + /// Maps the encoder `E` to `C`, preserving the write buffer + /// wrapped by `Framed`. + pub fn map_encoder(self, map: F) -> FramedWrite + where + F: FnOnce(E) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = self.inner; + FramedWrite { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + + /// Returns a mutable reference to the underlying encoder. + pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E { + self.project().inner.project().codec + } + + /// Returns a reference to the write buffer. + pub fn write_buffer(&self) -> &BytesMut { + &self.inner.state.buffer + } + + /// Returns a mutable reference to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut BytesMut { + &mut self.inner.state.buffer + } + + /// Returns backpressure boundary + pub fn backpressure_boundary(&self) -> usize { + self.inner.state.backpressure_boundary + } + + /// Updates backpressure boundary + pub fn set_backpressure_boundary(&mut self, boundary: usize) { + self.inner.state.backpressure_boundary = boundary; + } +} + +// This impl just defers to the underlying FramedImpl +impl Sink for FramedWrite +where + T: AsyncWrite, + E: Encoder, + E::Error: From, +{ + type Error = E::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + +// This impl just defers to the underlying T: Stream +impl Stream for FramedWrite +where + T: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.project().inner.poll_next(cx) + } +} + +impl fmt::Debug for FramedWrite +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedWrite") + .field("inner", &self.get_ref()) + .field("encoder", &self.encoder()) + .field("buffer", &self.inner.state.buffer) + .finish() + } +} diff --git a/src/codec/length_delimited.rs b/src/codec/length_delimited.rs new file mode 100644 index 0000000..a182dca --- /dev/null +++ b/src/codec/length_delimited.rs @@ -0,0 +1,1043 @@ +//! Frame a stream of bytes based on a length prefix +//! +//! Many protocols delimit their frames by prefacing frame data with a +//! frame head that specifies the length of the frame. The +//! `length_delimited` module provides utilities for handling the length +//! based framing. This allows the consumer to work with entire frames +//! without having to worry about buffering or other framing logic. +//! +//! # Getting started +//! +//! If implementing a protocol from scratch, using length delimited framing +//! 
is an easy way to get started. [`LengthDelimitedCodec::new()`] will +//! return a length delimited codec using default configuration values. +//! This can then be used to construct a framer to adapt a full-duplex +//! byte stream into a stream of frames. +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! fn bind_transport(io: T) +//! -> Framed +//! { +//! Framed::new(io, LengthDelimitedCodec::new()) +//! } +//! # pub fn main() {} +//! ``` +//! +//! The returned transport implements `Sink + Stream` for `BytesMut`. It +//! encodes the frame with a big-endian `u32` header denoting the frame +//! payload length: +//! +//! ```text +//! +----------+--------------------------------+ +//! | len: u32 | frame payload | +//! +----------+--------------------------------+ +//! ``` +//! +//! Specifically, given the following: +//! +//! ``` +//! use tokio::io::{AsyncRead, AsyncWrite}; +//! use tokio_util::codec::{Framed, LengthDelimitedCodec}; +//! +//! use futures::SinkExt; +//! use bytes::Bytes; +//! +//! async fn write_frame(io: T) -> Result<(), Box> +//! where +//! T: AsyncRead + AsyncWrite + Unpin, +//! { +//! let mut transport = Framed::new(io, LengthDelimitedCodec::new()); +//! let frame = Bytes::from("hello world"); +//! +//! transport.send(frame).await?; +//! Ok(()) +//! } +//! ``` +//! +//! The encoded frame will look like this: +//! +//! ```text +//! +---- len: u32 ----+---- data ----+ +//! | \x00\x00\x00\x0b | hello world | +//! +------------------+--------------+ +//! ``` +//! +//! # Decoding +//! +//! [`FramedRead`] adapts an [`AsyncRead`] into a `Stream` of [`BytesMut`], +//! such that each yielded [`BytesMut`] value contains the contents of an +//! entire frame. There are many configuration parameters enabling +//! [`FramedRead`] to handle a wide range of protocols. Here are some +//! examples that will cover the various options at a high level. +//! +//! ## Example 1 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::() +//! .length_adjustment(0) // default value +//! .num_skip(0) // Do not strip frame header +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0B | Hello world | --> | \x00\x0B | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! The value of the length field is 11 (`\x0B`) which represents the length +//! of the payload, `hello world`. By default, [`FramedRead`] assumes that +//! the length field represents the number of bytes that **follows** the +//! length field. Thus, the entire frame has a length of 13: 2 bytes for the +//! frame head + 11 bytes for the payload. +//! +//! ## Example 2 +//! +//! The following will parse a `u16` length field at offset 0, omitting the +//! frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::() +//! 
.length_adjustment(0) // default value +//! // `num_skip` is not needed, the default is to skip +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +--- Payload ---+ +//! | \x00\x0B | Hello world | --> | Hello world | +//! +----------+---------------+ +---------------+ +//! ``` +//! +//! This is similar to the first example, the only difference is that the +//! frame head is **not** included in the yielded `BytesMut` value. +//! +//! ## Example 3 +//! +//! The following will parse a `u16` length field at offset 0, including the +//! frame head in the yielded `BytesMut`. In this case, the length field +//! **includes** the frame head length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_type::() +//! .length_adjustment(-2) // size of head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +-- len ---+--- Payload ---+ +-- len ---+--- Payload ---+ +//! | \x00\x0D | Hello world | --> | \x00\x0D | Hello world | +//! +----------+---------------+ +----------+---------------+ +//! ``` +//! +//! In most cases, the length field represents the length of the payload +//! only, as shown in the previous examples. However, in some protocols the +//! length field represents the length of the whole frame, including the +//! head. In such cases, we specify a negative `length_adjustment` to adjust +//! the value provided in the frame head to represent the payload length. +//! +//! ## Example 4 +//! +//! The following will parse a 3 byte length field at offset 0 in a 5 byte +//! frame head, including the frame head in the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(2) // remaining head +//! .num_skip(0) +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! +//! DECODED +//! +---- len -----+- head -+--- Payload ---+ +//! | \x00\x00\x0B | \xCAFE | Hello world | +//! +--------------+--------+---------------+ +//! ``` +//! +//! A more advanced example that shows a case where there is extra frame +//! head data between the length field and the payload. In such cases, it is +//! usually desirable to include the frame head as part of the yielded +//! `BytesMut`. This lets consumers of the length delimited framer to +//! process the frame head as needed. +//! +//! The positive `length_adjustment` value lets `FramedRead` factor in the +//! additional head into the frame length calculation. +//! +//! ## Example 5 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be +//! included. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! 
# use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_type::() +//! .length_adjustment(1) // length of hdr2 +//! .num_skip(3) // length of hdr1 + LEN +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0B | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! The length field is situated in the middle of the frame head. In this +//! case, the first byte in the frame head could be a version or some other +//! identifier that is not needed for processing. On the other hand, the +//! second half of the head is needed. +//! +//! `length_field_offset` indicates how many bytes to skip before starting +//! to read the length field. `length_adjustment` is the number of bytes to +//! skip starting at the end of the length field. In this case, it is the +//! second half of the head. +//! +//! ## Example 6 +//! +//! The following will parse a `u16` length field at offset 1 of a 4 byte +//! frame head. The first byte and the length field will be omitted from the +//! yielded `BytesMut`, but the trailing 2 bytes of the frame head will be +//! included. In this case, the length field **includes** the frame head +//! length. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(1) // length of hdr1 +//! .length_field_type::() +//! .length_adjustment(-3) // length of hdr1 + LEN, negative +//! .num_skip(3) +//! .new_read(io); +//! # } +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT +//! +- hdr1 -+-- len ---+- hdr2 -+--- Payload ---+ +//! | \xCA | \x00\x0F | \xFE | Hello world | +//! +--------+----------+--------+---------------+ +//! +//! DECODED +//! +- hdr2 -+--- Payload ---+ +//! | \xFE | Hello world | +//! +--------+---------------+ +//! ``` +//! +//! Similar to the example above, the difference is that the length field +//! represents the length of the entire frame instead of just the payload. +//! The length of `hdr1` and `len` must be counted in `length_adjustment`. +//! Note that the length of `hdr2` does **not** need to be explicitly set +//! anywhere because it already is factored into the total frame length that +//! is read from the byte stream. +//! +//! ## Example 7 +//! +//! The following will parse a 3 byte length field at offset 0 in a 4 byte +//! frame head, excluding the 4th byte from the yielded `BytesMut`. +//! +//! ``` +//! # use tokio::io::AsyncRead; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn bind_read(io: T) { +//! LengthDelimitedCodec::builder() +//! .length_field_offset(0) // default value +//! .length_field_length(3) +//! .length_adjustment(0) // default value +//! .num_skip(4) // skip the first 4 bytes +//! .new_read(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! The following frame will be decoded as such: +//! +//! ```text +//! INPUT DECODED +//! +------- len ------+--- Payload ---+ +--- Payload ---+ +//! | \x00\x00\x0B\xFF | Hello world | => | Hello world | +//! +------------------+---------------+ +---------------+ +//! ``` +//! +//! 
A simple example where there are unused bytes between the length field +//! and the payload. +//! +//! # Encoding +//! +//! [`FramedWrite`] adapts an [`AsyncWrite`] into a `Sink` of [`BytesMut`], +//! such that each submitted [`BytesMut`] is prefaced by a length field. +//! There are fewer configuration options than [`FramedRead`]. Given +//! protocols that have more complex frame heads, an encoder should probably +//! be written by hand using [`Encoder`]. +//! +//! Here is a simple example, given a `FramedWrite` with the following +//! configuration: +//! +//! ``` +//! # use tokio::io::AsyncWrite; +//! # use tokio_util::codec::LengthDelimitedCodec; +//! # fn write_frame(io: T) { +//! # let _ = +//! LengthDelimitedCodec::builder() +//! .length_field_type::() +//! .new_write(io); +//! # } +//! # pub fn main() {} +//! ``` +//! +//! A payload of `hello world` will be encoded as: +//! +//! ```text +//! +- len: u16 -+---- data ----+ +//! | \x00\x0b | hello world | +//! +------------+--------------+ +//! ``` +//! +//! [`LengthDelimitedCodec::new()`]: method@LengthDelimitedCodec::new +//! [`FramedRead`]: struct@FramedRead +//! [`FramedWrite`]: struct@FramedWrite +//! [`AsyncRead`]: trait@tokio::io::AsyncRead +//! [`AsyncWrite`]: trait@tokio::io::AsyncWrite +//! [`Encoder`]: trait@Encoder +//! [`BytesMut`]: bytes::BytesMut + +use crate::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::error::Error as StdError; +use std::io::{self, Cursor}; +use std::{cmp, fmt, mem}; + +/// Configure length delimited `LengthDelimitedCodec`s. +/// +/// `Builder` enables constructing configured length delimited codecs. Note +/// that not all configuration settings apply to both encoding and decoding. See +/// the documentation for specific methods for more detail. +#[derive(Debug, Clone, Copy)] +pub struct Builder { + // Maximum frame length + max_frame_len: usize, + + // Number of bytes representing the field length + length_field_len: usize, + + // Number of bytes in the header before the length field + length_field_offset: usize, + + // Adjust the length specified in the header field by this amount + length_adjustment: isize, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: Option, + + // Length field byte order (little or big endian) + length_field_is_big_endian: bool, +} + +/// An error when the number of bytes read is more than max frame length. +pub struct LengthDelimitedCodecError { + _priv: (), +} + +/// A codec for frames delimited by a frame head specifying their lengths. +/// +/// This allows the consumer to work with entire frames without having to worry +/// about buffering or other framing logic. +/// +/// See [module level] documentation for more detail. +/// +/// [module level]: index.html +#[derive(Debug, Clone)] +pub struct LengthDelimitedCodec { + // Configuration values + builder: Builder, + + // Read state + state: DecodeState, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + Head, + Data(usize), +} + +// ===== impl LengthDelimitedCodec ====== + +impl LengthDelimitedCodec { + /// Creates a new `LengthDelimitedCodec` with the default configuration values. + pub fn new() -> Self { + Self { + builder: Builder::new(), + state: DecodeState::Head, + } + } + + /// Creates a new length delimited codec builder with default configuration + /// values. 
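    ///
    /// A sketch pairing the builder with a standalone codec (no I/O object
    /// involved):
    ///
    /// ```
    /// use tokio_util::codec::LengthDelimitedCodec;
    ///
    /// let codec = LengthDelimitedCodec::builder()
    ///     .max_frame_length(1024)
    ///     .new_codec();
    /// assert_eq!(codec.max_frame_length(), 1024);
    /// ```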
+ pub fn builder() -> Builder { + Builder::new() + } + + /// Returns the current max frame setting + /// + /// This is the largest size this codec will accept from the wire. Larger + /// frames will be rejected. + pub fn max_frame_length(&self) -> usize { + self.builder.max_frame_len + } + + /// Updates the max frame setting. + /// + /// The change takes effect the next time a frame is decoded. In other + /// words, if a frame is currently in process of being decoded with a frame + /// size greater than `val` but less than the max frame length in effect + /// before calling this function, then the frame will be allowed. + pub fn set_max_frame_length(&mut self, val: usize) { + self.builder.max_frame_length(val); + } + + fn decode_head(&mut self, src: &mut BytesMut) -> io::Result> { + let head_len = self.builder.num_head_bytes(); + let field_len = self.builder.length_field_len; + + if src.len() < head_len { + // Not enough data + return Ok(None); + } + + let n = { + let mut src = Cursor::new(&mut *src); + + // Skip the required bytes + src.advance(self.builder.length_field_offset); + + // match endianness + let n = if self.builder.length_field_is_big_endian { + src.get_uint(field_len) + } else { + src.get_uint_le(field_len) + }; + + if n > self.builder.max_frame_len as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + LengthDelimitedCodecError { _priv: () }, + )); + } + + // The check above ensures there is no overflow + let n = n as usize; + + // Adjust `n` with bounds checking + let n = if self.builder.length_adjustment < 0 { + n.checked_sub(-self.builder.length_adjustment as usize) + } else { + n.checked_add(self.builder.length_adjustment as usize) + }; + + // Error handling + match n { + Some(n) => n, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "provided length would overflow after adjustment", + )); + } + } + }; + + src.advance(self.builder.get_num_skip()); + + // Ensure that the buffer has enough space to read the incoming + // payload + src.reserve(n.saturating_sub(src.len())); + + Ok(Some(n)) + } + + fn decode_data(&self, n: usize, src: &mut BytesMut) -> Option { + // At this point, the buffer has already had the required capacity + // reserved. All there is to do is read. + if src.len() < n { + return None; + } + + Some(src.split_to(n)) + } +} + +impl Decoder for LengthDelimitedCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> io::Result> { + let n = match self.state { + DecodeState::Head => match self.decode_head(src)? 
{ + Some(n) => { + self.state = DecodeState::Data(n); + n + } + None => return Ok(None), + }, + DecodeState::Data(n) => n, + }; + + match self.decode_data(n, src) { + Some(data) => { + // Update the decode state + self.state = DecodeState::Head; + + // Make sure the buffer has enough space to read the next head + src.reserve(self.builder.num_head_bytes().saturating_sub(src.len())); + + Ok(Some(data)) + } + None => Ok(None), + } + } +} + +impl Encoder for LengthDelimitedCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> { + let n = data.len(); + + if n > self.builder.max_frame_len { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + LengthDelimitedCodecError { _priv: () }, + )); + } + + // Adjust `n` with bounds checking + let n = if self.builder.length_adjustment < 0 { + n.checked_add(-self.builder.length_adjustment as usize) + } else { + n.checked_sub(self.builder.length_adjustment as usize) + }; + + let n = n.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "provided length would overflow after adjustment", + ) + })?; + + // Reserve capacity in the destination buffer to fit the frame and + // length field (plus adjustment). + dst.reserve(self.builder.length_field_len + n); + + if self.builder.length_field_is_big_endian { + dst.put_uint(n as u64, self.builder.length_field_len); + } else { + dst.put_uint_le(n as u64, self.builder.length_field_len); + } + + // Write the frame to the buffer + dst.extend_from_slice(&data[..]); + + Ok(()) + } +} + +impl Default for LengthDelimitedCodec { + fn default() -> Self { + Self::new() + } +} + +// ===== impl Builder ===== + +mod builder { + /// Types that can be used with `Builder::length_field_type`. + pub trait LengthFieldType {} + + impl LengthFieldType for u8 {} + impl LengthFieldType for u16 {} + impl LengthFieldType for u32 {} + impl LengthFieldType for u64 {} + + #[cfg(any( + target_pointer_width = "8", + target_pointer_width = "16", + target_pointer_width = "32", + target_pointer_width = "64", + ))] + impl LengthFieldType for usize {} +} + +impl Builder { + /// Creates a new length delimited codec builder with default configuration + /// values. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new() -> Builder { + Builder { + // Default max frame length of 8MB + max_frame_len: 8 * 1_024 * 1_024, + + // Default byte length of 4 + length_field_len: 4, + + // Default to the header field being at the start of the header. + length_field_offset: 0, + + length_adjustment: 0, + + // Total number of bytes to skip before reading the payload, if not set, + // `length_field_len + length_field_offset` + num_skip: None, + + // Default to reading the length field in network (big) endian. + length_field_is_big_endian: true, + } + } + + /// Read the length field as a big endian integer + /// + /// This is the default setting. + /// + /// This configuration option applies to both encoding and decoding. 
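+    ///
+    /// For instance, with the default four byte length field, a two byte
+    /// payload is prefixed with `\x00\x00\x00\x02` on the wire. A sketch
+    /// using the standalone codec:
+    ///
+    /// ```
+    /// use bytes::{Bytes, BytesMut};
+    /// use tokio_util::codec::{Encoder, LengthDelimitedCodec};
+    ///
+    /// let mut codec = LengthDelimitedCodec::builder().big_endian().new_codec();
+    /// let mut buf = BytesMut::new();
+    /// codec.encode(Bytes::from("hi"), &mut buf).unwrap();
+    /// assert_eq!(&buf[..], &b"\x00\x00\x00\x02hi"[..]);
+    /// ```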
+ /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .big_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn big_endian(&mut self) -> &mut Self { + self.length_field_is_big_endian = true; + self + } + + /// Read the length field as a little endian integer + /// + /// The default setting is big endian. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .little_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn little_endian(&mut self) -> &mut Self { + self.length_field_is_big_endian = false; + self + } + + /// Read the length field as a native endian integer + /// + /// The default setting is big endian. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .native_endian() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn native_endian(&mut self) -> &mut Self { + if cfg!(target_endian = "big") { + self.big_endian() + } else { + self.little_endian() + } + } + + /// Sets the max frame length in bytes + /// + /// This configuration option applies to both encoding and decoding. The + /// default value is 8MB. + /// + /// When decoding, the length field read from the byte stream is checked + /// against this setting **before** any adjustments are applied. When + /// encoding, the length of the submitted payload is checked against this + /// setting. + /// + /// When frames exceed the max length, an `io::Error` with the custom value + /// of the `LengthDelimitedCodecError` type will be returned. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .max_frame_length(8 * 1024 * 1024) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn max_frame_length(&mut self, val: usize) -> &mut Self { + self.max_frame_len = val; + self + } + + /// Sets the unsigned integer type used to represent the length field. + /// + /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on + /// 64-bit targets). 
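+    ///
+    /// A `u16` length field, for instance, yields a two byte prefix. A
+    /// sketch using the standalone codec:
+    ///
+    /// ```
+    /// use bytes::{Bytes, BytesMut};
+    /// use tokio_util::codec::{Encoder, LengthDelimitedCodec};
+    ///
+    /// let mut codec = LengthDelimitedCodec::builder()
+    ///     .length_field_type::<u16>()
+    ///     .new_codec();
+    /// let mut buf = BytesMut::new();
+    /// codec.encode(Bytes::from("hi"), &mut buf).unwrap();
+    /// assert_eq!(&buf[..], &b"\x00\x02hi"[..]);
+    /// ```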
+ /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + /// + /// Unlike [`Builder::length_field_length`], this does not fail at runtime + /// and instead produces a compile error: + /// + /// ```compile_fail + /// # use tokio::io::AsyncRead; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_type(&mut self) -> &mut Self { + self.length_field_length(mem::size_of::()) + } + + /// Sets the number of bytes used to represent the length field + /// + /// The default value is `4`. The max value is `8`. + /// + /// This configuration option applies to both encoding and decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_length(4) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_length(&mut self, val: usize) -> &mut Self { + assert!(val > 0 && val <= 8, "invalid length field length"); + self.length_field_len = val; + self + } + + /// Sets the number of bytes in the header before the length field + /// + /// This configuration option only applies to decoding. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(1) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_offset(&mut self, val: usize) -> &mut Self { + self.length_field_offset = val; + self + } + + /// Delta between the payload length specified in the header and the real + /// payload length + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_adjustment(-2) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_adjustment(&mut self, val: isize) -> &mut Self { + self.length_adjustment = val; + self + } + + /// Sets the number of bytes to skip before reading the payload + /// + /// Default value is `length_field_len + length_field_offset` + /// + /// This configuration option only applies to decoding + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .num_skip(4) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn num_skip(&mut self, val: usize) -> &mut Self { + self.num_skip = Some(val); + self + } + + /// Create a configured length delimited `LengthDelimitedCodec` + /// + /// # Examples + /// + /// ``` + /// use tokio_util::codec::LengthDelimitedCodec; + /// # pub fn main() { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_codec(); + /// # } + /// ``` + pub fn new_codec(&self) -> LengthDelimitedCodec { + LengthDelimitedCodec { + builder: *self, + 
state: DecodeState::Head, + } + } + + /// Create a configured length delimited `FramedRead` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_offset(0) + /// .length_field_type::() + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_read(&self, upstream: T) -> FramedRead + where + T: AsyncRead, + { + FramedRead::new(upstream, self.new_codec()) + } + + /// Create a configured length delimited `FramedWrite` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncWrite; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn write_frame(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::() + /// .new_write(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_write(&self, inner: T) -> FramedWrite + where + T: AsyncWrite, + { + FramedWrite::new(inner, self.new_codec()) + } + + /// Create a configured length delimited `Framed` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn write_frame(io: T) { + /// # let _ = + /// LengthDelimitedCodec::builder() + /// .length_field_type::() + /// .new_framed(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_framed(&self, inner: T) -> Framed + where + T: AsyncRead + AsyncWrite, + { + Framed::new(inner, self.new_codec()) + } + + fn num_head_bytes(&self) -> usize { + let num = self.length_field_offset + self.length_field_len; + cmp::max(num, self.num_skip.unwrap_or(0)) + } + + fn get_num_skip(&self) -> usize { + self.num_skip + .unwrap_or(self.length_field_offset + self.length_field_len) + } +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +// ===== impl LengthDelimitedCodecError ===== + +impl fmt::Debug for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LengthDelimitedCodecError").finish() + } +} + +impl fmt::Display for LengthDelimitedCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("frame size too big") + } +} + +impl StdError for LengthDelimitedCodecError {} diff --git a/src/codec/lines_codec.rs b/src/codec/lines_codec.rs new file mode 100644 index 0000000..7a0a8f0 --- /dev/null +++ b/src/codec/lines_codec.rs @@ -0,0 +1,230 @@ +use crate::codec::decoder::Decoder; +use crate::codec::encoder::Encoder; + +use bytes::{Buf, BufMut, BytesMut}; +use std::{cmp, fmt, io, str, usize}; + +/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines. +/// +/// [`Decoder`]: crate::codec::Decoder +/// [`Encoder`]: crate::codec::Encoder +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct LinesCodec { + // Stored index of the next index to examine for a `\n` character. + // This is used to optimize searching. + // For example, if `decode` was called with `abc`, it would hold `3`, + // because that is the next index to examine. + // The next time `decode` is called with `abcde\n`, the method will + // only look at `de\n` before returning. + next_index: usize, + + /// The maximum length for a given line. If `usize::MAX`, lines will be + /// read until a `\n` character is reached. + max_length: usize, + + /// Are we currently discarding the remainder of a line which was over + /// the length limit? 
+ is_discarding: bool, +} + +impl LinesCodec { + /// Returns a `LinesCodec` for splitting up data into lines. + /// + /// # Note + /// + /// The returned `LinesCodec` will not have an upper bound on the length + /// of a buffered line. See the documentation for [`new_with_max_length`] + /// for information on why this could be a potential security risk. + /// + /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length() + pub fn new() -> LinesCodec { + LinesCodec { + next_index: 0, + max_length: usize::MAX, + is_discarding: false, + } + } + + /// Returns a `LinesCodec` with a maximum line length limit. + /// + /// If this is set, calls to `LinesCodec::decode` will return a + /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls + /// will discard up to `limit` bytes from that line until a newline + /// character is reached, returning `None` until the line over the limit + /// has been fully discarded. After that point, calls to `decode` will + /// function as normal. + /// + /// # Note + /// + /// Setting a length limit is highly recommended for any `LinesCodec` which + /// will be exposed to untrusted input. Otherwise, the size of the buffer + /// that holds the line currently being read is unbounded. An attacker could + /// exploit this unbounded buffer by sending an unbounded amount of input + /// without any `\n` characters, causing unbounded memory consumption. + /// + /// [`LinesCodecError`]: crate::codec::LinesCodecError + pub fn new_with_max_length(max_length: usize) -> Self { + LinesCodec { + max_length, + ..LinesCodec::new() + } + } + + /// Returns the maximum line length when decoding. + /// + /// ``` + /// use std::usize; + /// use tokio_util::codec::LinesCodec; + /// + /// let codec = LinesCodec::new(); + /// assert_eq!(codec.max_length(), usize::MAX); + /// ``` + /// ``` + /// use tokio_util::codec::LinesCodec; + /// + /// let codec = LinesCodec::new_with_max_length(256); + /// assert_eq!(codec.max_length(), 256); + /// ``` + pub fn max_length(&self) -> usize { + self.max_length + } +} + +fn utf8(buf: &[u8]) -> Result<&str, io::Error> { + str::from_utf8(buf) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8")) +} + +fn without_carriage_return(s: &[u8]) -> &[u8] { + if let Some(&b'\r') = s.last() { + &s[..s.len() - 1] + } else { + s + } +} + +impl Decoder for LinesCodec { + type Item = String; + type Error = LinesCodecError; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { + loop { + // Determine how far into the buffer we'll search for a newline. If + // there's no max_length set, we'll read to the end of the buffer. + let read_to = cmp::min(self.max_length.saturating_add(1), buf.len()); + + let newline_offset = buf[self.next_index..read_to] + .iter() + .position(|b| *b == b'\n'); + + match (self.is_discarding, newline_offset) { + (true, Some(offset)) => { + // If we found a newline, discard up to that offset and + // then stop discarding. On the next iteration, we'll try + // to read a line normally. + buf.advance(offset + self.next_index + 1); + self.is_discarding = false; + self.next_index = 0; + } + (true, None) => { + // Otherwise, we didn't find a newline, so we'll discard + // everything we read. On the next iteration, we'll continue + // discarding up to max_len bytes unless we find a newline. + buf.advance(read_to); + self.next_index = 0; + if buf.is_empty() { + return Ok(None); + } + } + (false, Some(offset)) => { + // Found a line! 
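+                    // The steps below split the line (including its newline)
+                    // out of the buffer, drop the trailing `\n` and any `\r`,
+                    // and validate that the bytes are UTF-8 before returning
+                    // the line as a `String`.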
+ let newline_index = offset + self.next_index; + self.next_index = 0; + let line = buf.split_to(newline_index + 1); + let line = &line[..line.len() - 1]; + let line = without_carriage_return(line); + let line = utf8(line)?; + return Ok(Some(line.to_string())); + } + (false, None) if buf.len() > self.max_length => { + // Reached the maximum length without finding a + // newline, return an error and start discarding on the + // next call. + self.is_discarding = true; + return Err(LinesCodecError::MaxLineLengthExceeded); + } + (false, None) => { + // We didn't find a line or reach the length limit, so the next + // call will resume searching at the current offset. + self.next_index = read_to; + return Ok(None); + } + } + } + } + + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { + Ok(match self.decode(buf)? { + Some(frame) => Some(frame), + None => { + // No terminating newline - return remaining data, if any + if buf.is_empty() || buf == &b"\r"[..] { + None + } else { + let line = buf.split_to(buf.len()); + let line = without_carriage_return(&line); + let line = utf8(line)?; + self.next_index = 0; + Some(line.to_string()) + } + } + }) + } +} + +impl Encoder for LinesCodec +where + T: AsRef, +{ + type Error = LinesCodecError; + + fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> { + let line = line.as_ref(); + buf.reserve(line.len() + 1); + buf.put(line.as_bytes()); + buf.put_u8(b'\n'); + Ok(()) + } +} + +impl Default for LinesCodec { + fn default() -> Self { + Self::new() + } +} + +/// An error occurred while encoding or decoding a line. +#[derive(Debug)] +pub enum LinesCodecError { + /// The maximum line length was exceeded. + MaxLineLengthExceeded, + /// An IO error occurred. + Io(io::Error), +} + +impl fmt::Display for LinesCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"), + LinesCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From for LinesCodecError { + fn from(e: io::Error) -> LinesCodecError { + LinesCodecError::Io(e) + } +} + +impl std::error::Error for LinesCodecError {} diff --git a/src/codec/mod.rs b/src/codec/mod.rs new file mode 100644 index 0000000..2295176 --- /dev/null +++ b/src/codec/mod.rs @@ -0,0 +1,290 @@ +//! Adaptors from AsyncRead/AsyncWrite to Stream/Sink +//! +//! Raw I/O objects work with byte sequences, but higher-level code usually +//! wants to batch these into meaningful chunks, called "frames". +//! +//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and +//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. +//! Framed streams are also known as transports. +//! +//! # The Decoder trait +//! +//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an +//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify +//! how sequences of bytes are turned into a sequence of frames, and to +//! determine where the boundaries between frames are. The job of the +//! `FramedRead` is to repeatedly switch between reading more data from the IO +//! resource, and asking the decoder whether we have received enough data to +//! decode another frame of data. +//! +//! The main method on the `Decoder` trait is the [`decode`] method. This method +//! takes as argument the data that has been read so far, and when it is called, +//! it will be in one of the following situations: +//! +//! 1. 
The buffer contains less than a full frame. +//! 2. The buffer contains exactly a full frame. +//! 3. The buffer contains more than a full frame. +//! +//! In the first situation, the decoder should return `Ok(None)`. +//! +//! In the second situation, the decoder should clear the provided buffer and +//! return `Ok(Some(the_decoded_frame))`. +//! +//! In the third situation, the decoder should use a method such as [`split_to`] +//! or [`advance`] to modify the buffer such that the frame is removed from the +//! buffer, but any data in the buffer after that frame should still remain in +//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in +//! this case. +//! +//! Finally the decoder may return an error if the data is invalid in some way. +//! The decoder should _not_ return an error just because it has yet to receive +//! a full frame. +//! +//! It is guaranteed that, from one call to `decode` to another, the provided +//! buffer will contain the exact same data as before, except that if more data +//! has arrived through the IO resource, that data will have been appended to +//! the buffer. This means that reading frames from a `FramedRead` is +//! essentially equivalent to the following loop: +//! +//! ```no_run +//! use tokio::io::AsyncReadExt; +//! # // This uses async_stream to create an example that compiles. +//! # fn foo() -> impl futures_core::Stream> { async_stream::try_stream! { +//! # use tokio_util::codec::Decoder; +//! # let mut decoder = tokio_util::codec::BytesCodec::new(); +//! # let io_resource = &mut &[0u8, 1, 2, 3][..]; +//! +//! let mut buf = bytes::BytesMut::new(); +//! loop { +//! // The read_buf call will append to buf rather than overwrite existing data. +//! let len = io_resource.read_buf(&mut buf).await?; +//! +//! if len == 0 { +//! while let Some(frame) = decoder.decode_eof(&mut buf)? { +//! yield frame; +//! } +//! break; +//! } +//! +//! while let Some(frame) = decoder.decode(&mut buf)? { +//! yield frame; +//! } +//! } +//! # }} +//! ``` +//! The example above uses `yield` whenever the `Stream` produces an item. +//! +//! ## Example decoder +//! +//! As an example, consider a protocol that can be used to send strings where +//! each frame is a four byte integer that contains the length of the frame, +//! followed by that many bytes of string data. The decoder fails with an error +//! if the string data is not valid utf-8 or too long. +//! +//! Such a decoder can be written like this: +//! ``` +//! use tokio_util::codec::Decoder; +//! use bytes::{BytesMut, Buf}; +//! +//! struct MyStringDecoder {} +//! +//! const MAX: usize = 8 * 1024 * 1024; +//! +//! impl Decoder for MyStringDecoder { +//! type Item = String; +//! type Error = std::io::Error; +//! +//! fn decode( +//! &mut self, +//! src: &mut BytesMut +//! ) -> Result, Self::Error> { +//! if src.len() < 4 { +//! // Not enough data to read length marker. +//! return Ok(None); +//! } +//! +//! // Read length marker. +//! let mut length_bytes = [0u8; 4]; +//! length_bytes.copy_from_slice(&src[..4]); +//! let length = u32::from_le_bytes(length_bytes) as usize; +//! +//! // Check that the length is not too large to avoid a denial of +//! // service attack where the server runs out of memory. +//! if length > MAX { +//! return Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! format!("Frame of length {} is too large.", length) +//! )); +//! } +//! +//! if src.len() < 4 + length { +//! // The full string has not yet arrived. +//! // +//! 
// We reserve more space in the buffer. This is not strictly
+//!             // necessary, but is a good idea performance-wise.
+//!             src.reserve(4 + length - src.len());
+//!
+//!             // We inform the Framed that we need more bytes to form the next
+//!             // frame.
+//!             return Ok(None);
+//!         }
+//!
+//!         // Use advance to modify src such that it no longer contains
+//!         // this frame.
+//!         let data = src[4..4 + length].to_vec();
+//!         src.advance(4 + length);
+//!
+//!         // Convert the data to a string, or fail if it is not valid utf-8.
+//!         match String::from_utf8(data) {
+//!             Ok(string) => Ok(Some(string)),
+//!             Err(utf8_error) => {
+//!                 Err(std::io::Error::new(
+//!                     std::io::ErrorKind::InvalidData,
+//!                     utf8_error.utf8_error(),
+//!                 ))
+//!             },
+//!         }
+//!     }
+//! }
+//! ```
+//!
+//! # The Encoder trait
+//!
+//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn
+//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to
+//! specify how frames are turned into a sequence of bytes. The job of the
+//! `FramedWrite` is to take the resulting sequence of bytes and write it to the
+//! IO resource.
+//!
+//! The main method on the `Encoder` trait is the [`encode`] method. This method
+//! takes an item that is being written, and a buffer to write the item to. The
+//! buffer may already contain data, and in this case, the encoder should append
+//! the new frame to the buffer rather than overwrite the existing data.
+//!
+//! It is guaranteed that, from one call to `encode` to another, the provided
+//! buffer will contain the exact same data as before, except that some of the
+//! data may have been removed from the front of the buffer. Writing to a
+//! `FramedWrite` is essentially equivalent to the following loop:
+//!
+//! ```no_run
+//! use tokio::io::AsyncWriteExt;
+//! use bytes::Buf; // for advance
+//! # use tokio_util::codec::Encoder;
+//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() }
+//! # async fn no_more_frames() { }
+//! # #[tokio::main] async fn main() -> std::io::Result<()> {
+//! # let mut io_resource = tokio::io::sink();
+//! # let mut encoder = tokio_util::codec::BytesCodec::new();
+//!
+//! const MAX: usize = 8192;
+//!
+//! let mut buf = bytes::BytesMut::new();
+//! loop {
+//!     tokio::select! {
+//!         num_written = io_resource.write(&buf), if !buf.is_empty() => {
+//!             buf.advance(num_written?);
+//!         },
+//!         frame = next_frame(), if buf.len() < MAX => {
+//!             encoder.encode(frame, &mut buf)?;
+//!         },
+//!         _ = no_more_frames() => {
+//!             io_resource.write_all(&buf).await?;
+//!             io_resource.shutdown().await?;
+//!             return Ok(());
+//!         },
+//!     }
+//! }
+//! # }
+//! ```
+//! Here the `next_frame` method corresponds to any frames you write to the
+//! `FramedWrite`. The `no_more_frames` method corresponds to closing the
+//! `FramedWrite` with [`SinkExt::close`].
+//!
+//! ## Example encoder
+//!
+//! As an example, consider a protocol that can be used to send strings where
+//! each frame is a four byte integer that contains the length of the frame,
+//! followed by that many bytes of string data. The encoder will fail if the
+//! string is too long.
+//!
+//! Such an encoder can be written like this:
+//! ```
+//! use tokio_util::codec::Encoder;
+//! use bytes::BytesMut;
+//!
+//! struct MyStringEncoder {}
+//!
+//! const MAX: usize = 8 * 1024 * 1024;
+//!
+//! impl Encoder<String> for MyStringEncoder {
+//!     type Error = std::io::Error;
+//!
+//!     fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
+//!
// Don't send a string if it is longer than the other end will +//! // accept. +//! if item.len() > MAX { +//! return Err(std::io::Error::new( +//! std::io::ErrorKind::InvalidData, +//! format!("Frame of length {} is too large.", item.len()) +//! )); +//! } +//! +//! // Convert the length into a byte array. +//! // The cast to u32 cannot overflow due to the length check above. +//! let len_slice = u32::to_le_bytes(item.len() as u32); +//! +//! // Reserve space in the buffer. +//! dst.reserve(4 + item.len()); +//! +//! // Write the length and string to the buffer. +//! dst.extend_from_slice(&len_slice); +//! dst.extend_from_slice(item.as_bytes()); +//! Ok(()) +//! } +//! } +//! ``` +//! +//! [`AsyncRead`]: tokio::io::AsyncRead +//! [`AsyncWrite`]: tokio::io::AsyncWrite +//! [`Stream`]: futures_core::Stream +//! [`Sink`]: futures_sink::Sink +//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close +//! [`FramedRead`]: struct@crate::codec::FramedRead +//! [`FramedWrite`]: struct@crate::codec::FramedWrite +//! [`Framed`]: struct@crate::codec::Framed +//! [`Decoder`]: trait@crate::codec::Decoder +//! [`decode`]: fn@crate::codec::Decoder::decode +//! [`encode`]: fn@crate::codec::Encoder::encode +//! [`split_to`]: fn@bytes::BytesMut::split_to +//! [`advance`]: fn@bytes::Buf::advance + +mod bytes_codec; +pub use self::bytes_codec::BytesCodec; + +mod decoder; +pub use self::decoder::Decoder; + +mod encoder; +pub use self::encoder::Encoder; + +mod framed_impl; +#[allow(unused_imports)] +pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; + +mod framed; +pub use self::framed::{Framed, FramedParts}; + +mod framed_read; +pub use self::framed_read::FramedRead; + +mod framed_write; +pub use self::framed_write::FramedWrite; + +pub mod length_delimited; +pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError}; + +mod lines_codec; +pub use self::lines_codec::{LinesCodec, LinesCodecError}; + +mod any_delimiter_codec; +pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError}; diff --git a/src/compat.rs b/src/compat.rs new file mode 100644 index 0000000..6a8802d --- /dev/null +++ b/src/compat.rs @@ -0,0 +1,274 @@ +//! Compatibility between the `tokio::io` and `futures-io` versions of the +//! `AsyncRead` and `AsyncWrite` traits. +use futures_core::ready; +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A compatibility layer that allows conversion between the + /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits. + #[derive(Copy, Clone, Debug)] + pub struct Compat { + #[pin] + inner: T, + seek_pos: Option, + } +} + +/// Extension trait that allows converting a type implementing +/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`. +pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead { + /// Wraps `self` with a compatibility layer that implements + /// `tokio_io::AsyncRead`. + fn compat(self) -> Compat + where + Self: Sized, + { + Compat::new(self) + } +} + +impl FuturesAsyncReadCompatExt for T {} + +/// Extension trait that allows converting a type implementing +/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`. +pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite { + /// Wraps `self` with a compatibility layer that implements + /// `tokio::io::AsyncWrite`. 
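+    ///
+    /// A minimal sketch of bridging a `futures-io` writer into Tokio. It is
+    /// marked `ignore` because it assumes the `futures` crate is available as
+    /// a dependency; the in-memory cursor is used purely for illustration.
+    ///
+    /// ```ignore
+    /// use tokio::io::AsyncWriteExt;
+    /// use tokio_util::compat::FuturesAsyncWriteCompatExt;
+    ///
+    /// # async fn example() -> std::io::Result<()> {
+    /// // `futures::io::Cursor` implements `futures_io::AsyncWrite`.
+    /// let mut writer = futures::io::Cursor::new(Vec::new()).compat_write();
+    ///
+    /// // After wrapping, Tokio's `AsyncWriteExt` methods are available.
+    /// writer.write_all(b"hello").await?;
+    /// assert_eq!(&writer.into_inner().into_inner()[..], b"hello");
+    /// # Ok(())
+    /// # }
+    /// ```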
+    fn compat_write(self) -> Compat<Self>
+    where
+        Self: Sized,
+    {
+        Compat::new(self)
+    }
+}
+
+impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
+pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
+    /// Wraps `self` with a compatibility layer that implements
+    /// `futures_io::AsyncRead`.
+    fn compat(self) -> Compat<Self>
+    where
+        Self: Sized,
+    {
+        Compat::new(self)
+    }
+}
+
+impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
+
+/// Extension trait that allows converting a type implementing
+/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
+pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
+    /// Wraps `self` with a compatibility layer that implements
+    /// `futures_io::AsyncWrite`.
+    fn compat_write(self) -> Compat<Self>
+    where
+        Self: Sized,
+    {
+        Compat::new(self)
+    }
+}
+
+impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
+
+// === impl Compat ===
+
+impl<T> Compat<T> {
+    fn new(inner: T) -> Self {
+        Self {
+            inner,
+            seek_pos: None,
+        }
+    }
+
+    /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
+    /// contained within.
+    pub fn get_ref(&self) -> &T {
+        &self.inner
+    }
+
+    /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
+    /// contained within.
+    pub fn get_mut(&mut self) -> &mut T {
+        &mut self.inner
+    }
+
+    /// Returns the wrapped item.
+    pub fn into_inner(self) -> T {
+        self.inner
+    }
+}
+
+impl<T> tokio::io::AsyncRead for Compat<T>
+where
+    T: futures_io::AsyncRead,
+{
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        // We can't trust the inner type to not peek at the bytes,
+        // so we must defensively initialize the buffer.
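+        //
+        // `initialize_unfilled` zero-fills the spare capacity and hands it out
+        // as a plain `&mut [u8]`, which is the buffer type the `futures-io`
+        // `poll_read` API expects.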
+ let slice = buf.initialize_unfilled(); + let n = ready!(futures_io::AsyncRead::poll_read( + self.project().inner, + cx, + slice + ))?; + buf.advance(n); + Poll::Ready(Ok(())) + } +} + +impl futures_io::AsyncRead for Compat +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + slice: &mut [u8], + ) -> Poll> { + let mut buf = tokio::io::ReadBuf::new(slice); + ready!(tokio::io::AsyncRead::poll_read( + self.project().inner, + cx, + &mut buf + ))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +impl tokio::io::AsyncBufRead for Compat +where + T: futures_io::AsyncBufRead, +{ + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + futures_io::AsyncBufRead::consume(self.project().inner, amt) + } +} + +impl futures_io::AsyncBufRead for Compat +where + T: tokio::io::AsyncBufRead, +{ + fn poll_fill_buf<'a>( + self: Pin<&'a mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + tokio::io::AsyncBufRead::consume(self.project().inner, amt) + } +} + +impl tokio::io::AsyncWrite for Compat +where + T: futures_io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + futures_io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + futures_io::AsyncWrite::poll_close(self.project().inner, cx) + } +} + +impl futures_io::AsyncWrite for Compat +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } +} + +impl futures_io::AsyncSeek for Compat { + fn poll_seek( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + pos: io::SeekFrom, + ) -> Poll> { + if self.seek_pos != Some(pos) { + self.as_mut().project().inner.start_seek(pos)?; + *self.as_mut().project().seek_pos = Some(pos); + } + let res = ready!(self.as_mut().project().inner.poll_complete(cx)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + +impl tokio::io::AsyncSeek for Compat { + fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> { + *self.as_mut().project().seek_pos = Some(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let pos = match self.seek_pos { + None => { + // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. 
+                return Poll::Ready(Ok(0));
+            }
+            Some(pos) => pos,
+        };
+        let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
+        *self.as_mut().project().seek_pos = None;
+        Poll::Ready(res.map(|p| p as u64))
+    }
+}
+
+#[cfg(unix)]
+impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
+    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
+        self.inner.as_raw_fd()
+    }
+}
+
+#[cfg(windows)]
+impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> {
+    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
+        self.inner.as_raw_handle()
+    }
+}
diff --git a/src/context.rs b/src/context.rs
new file mode 100644
index 0000000..a7a5e02
--- /dev/null
+++ b/src/context.rs
@@ -0,0 +1,190 @@
+//! Tokio context aware futures utilities.
+//!
+//! This module includes utilities around integrating tokio with other runtimes
+//! by allowing the context to be attached to futures. This allows spawning
+//! futures on other executors while still using tokio to drive them. This
+//! can be useful if you need to use a tokio based library in an executor/runtime
+//! that does not provide a tokio context.
+
+use pin_project_lite::pin_project;
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+use tokio::runtime::{Handle, Runtime};
+
+pin_project! {
+    /// `TokioContext` allows running futures that must be inside Tokio's
+    /// context on a non-Tokio runtime.
+    ///
+    /// It contains a [`Handle`] to the runtime. A handle to the runtime can be
+    /// obtained by calling the [`Runtime::handle()`] method.
+    ///
+    /// Note that the `TokioContext` wrapper only works if the `Runtime` it is
+    /// connected to has not yet been destroyed. You must keep the `Runtime`
+    /// alive until the future has finished executing.
+    ///
+    /// **Warning:** If `TokioContext` is used together with a [current thread]
+    /// runtime, that runtime must be inside a call to `block_on` for the
+    /// wrapped future to work. For this reason, it is recommended to use a
+    /// [multi thread] runtime, even if you configure it to only spawn one
+    /// worker thread.
+    ///
+    /// # Examples
+    ///
+    /// This example creates two runtimes, but only [enables time] on one of
+    /// them. It then uses the context of the runtime with the timer enabled to
+    /// execute a [`sleep`] future on the runtime with timing disabled.
+    /// ```
+    /// use tokio::time::{sleep, Duration};
+    /// use tokio_util::context::RuntimeExt;
+    ///
+    /// // This runtime has timers enabled.
+    /// let rt = tokio::runtime::Builder::new_multi_thread()
+    ///     .enable_all()
+    ///     .build()
+    ///     .unwrap();
+    ///
+    /// // This runtime has timers disabled.
+    /// let rt2 = tokio::runtime::Builder::new_multi_thread()
+    ///     .build()
+    ///     .unwrap();
+    ///
+    /// // Wrap the sleep future in the context of rt.
+    /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
+    ///
+    /// // Execute the future on rt2.
+    /// rt2.block_on(fut);
+    /// ```
+    ///
+    /// [`Handle`]: struct@tokio::runtime::Handle
+    /// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle
+    /// [`RuntimeExt`]: trait@crate::context::RuntimeExt
+    /// [`new_static`]: fn@Self::new_static
+    /// [`sleep`]: fn@tokio::time::sleep
+    /// [current thread]: fn@tokio::runtime::Builder::new_current_thread
+    /// [enables time]: fn@tokio::runtime::Builder::enable_time
+    /// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread
+    pub struct TokioContext<F> {
+        #[pin]
+        inner: F,
+        handle: Handle,
+    }
+}
+
+impl<F: Future> TokioContext<F> {
+    /// Associate the provided future with the context of the runtime behind
+    /// the provided `Handle`.
+ /// + /// This constructor uses a `'static` lifetime to opt-out of checking that + /// the runtime still exists. + /// + /// # Examples + /// + /// This is the same as the example above, but uses the `new` constructor + /// rather than [`RuntimeExt::wrap`]. + /// + /// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::TokioContext; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// let fut = TokioContext::new( + /// async { sleep(Duration::from_millis(2)).await }, + /// rt.handle().clone(), + /// ); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + pub fn new(future: F, handle: Handle) -> TokioContext { + TokioContext { + inner: future, + handle, + } + } + + /// Obtain a reference to the handle inside this `TokioContext`. + pub fn handle(&self) -> &Handle { + &self.handle + } + + /// Remove the association between the Tokio runtime and the wrapped future. + pub fn into_inner(self) -> F { + self.inner + } +} + +impl Future for TokioContext { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let me = self.project(); + let handle = me.handle; + let fut = me.inner; + + let _enter = handle.enter(); + fut.poll(cx) + } +} + +/// Extension trait that simplifies bundling a `Handle` with a `Future`. +pub trait RuntimeExt { + /// Create a [`TokioContext`] that wraps the provided future and runs it in + /// this runtime's context. + /// + /// # Examples + /// + /// This example creates two runtimes, but only [enables time] on one of + /// them. It then uses the context of the runtime with the timer enabled to + /// execute a [`sleep`] future on the runtime with timing disabled. + /// + /// ``` + /// use tokio::time::{sleep, Duration}; + /// use tokio_util::context::RuntimeExt; + /// + /// // This runtime has timers enabled. + /// let rt = tokio::runtime::Builder::new_multi_thread() + /// .enable_all() + /// .build() + /// .unwrap(); + /// + /// // This runtime has timers disabled. + /// let rt2 = tokio::runtime::Builder::new_multi_thread() + /// .build() + /// .unwrap(); + /// + /// // Wrap the sleep future in the context of rt. + /// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await }); + /// + /// // Execute the future on rt2. + /// rt2.block_on(fut); + /// ``` + /// + /// [`TokioContext`]: struct@crate::context::TokioContext + /// [`sleep`]: fn@tokio::time::sleep + /// [enables time]: fn@tokio::runtime::Builder::enable_time + fn wrap(&self, fut: F) -> TokioContext; +} + +impl RuntimeExt for Runtime { + fn wrap(&self, fut: F) -> TokioContext { + TokioContext { + inner: fut, + handle: self.handle().clone(), + } + } +} diff --git a/src/either.rs b/src/either.rs new file mode 100644 index 0000000..9225e53 --- /dev/null +++ b/src/either.rs @@ -0,0 +1,188 @@ +//! Module defining an Either type. +use std::{ + future::Future, + io::SeekFrom, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result}; + +/// Combines two different futures, streams, or sinks having the same associated types into a single type. +/// +/// This type implements common asynchronous traits such as [`Future`] and those in Tokio. 
+///
+/// [`Future`]: std::future::Future
+///
+/// # Example
+///
+/// The following code will not work:
+///
+/// ```compile_fail
+/// # fn some_condition() -> bool { true }
+/// # async fn some_async_function() -> u32 { 10 }
+/// # async fn other_async_function() -> u32 { 20 }
+/// #[tokio::main]
+/// async fn main() {
+///     let result = if some_condition() {
+///         some_async_function()
+///     } else {
+///         other_async_function() // <- Will print: "`if` and `else` have incompatible types"
+///     };
+///
+///     println!("Result is {}", result.await);
+/// }
+/// ```
+///
+// This is because although the output types for both futures are the same, the
+// exact future types are different, and the compiler must be able to choose a
+// single type for the `result` variable.
+///
+/// When the output type is the same, we can wrap each future in `Either` to avoid the
+/// issue:
+///
+/// ```
+/// use tokio_util::either::Either;
+/// # fn some_condition() -> bool { true }
+/// # async fn some_async_function() -> u32 { 10 }
+/// # async fn other_async_function() -> u32 { 20 }
+///
+/// #[tokio::main]
+/// async fn main() {
+///     let result = if some_condition() {
+///         Either::Left(some_async_function())
+///     } else {
+///         Either::Right(other_async_function())
+///     };
+///
+///     let value = result.await;
+///     println!("Result is {}", value);
+///     # assert_eq!(value, 10);
+/// }
+/// ```
#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
+#[derive(Debug, Clone)]
+pub enum Either<L, R> {
+    Left(L),
+    Right(R),
+}
+
+/// A small helper macro which reduces the amount of boilerplate in the actual trait method implementation.
+/// It takes an invocation of a method as an argument (e.g. `self.poll(cx)`), and redirects it to either
+/// enum variant held in `self`.
+macro_rules!
delegate_call { + ($self:ident.$method:ident($($args:ident),+)) => { + unsafe { + match $self.get_unchecked_mut() { + Self::Left(l) => Pin::new_unchecked(l).$method($($args),+), + Self::Right(r) => Pin::new_unchecked(r).$method($($args),+), + } + } + } +} + +impl Future for Either +where + L: Future, + R: Future, +{ + type Output = O; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + delegate_call!(self.poll(cx)) + } +} + +impl AsyncRead for Either +where + L: AsyncRead, + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + delegate_call!(self.poll_read(cx, buf)) + } +} + +impl AsyncBufRead for Either +where + L: AsyncBufRead, + R: AsyncBufRead, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_fill_buf(cx)) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + delegate_call!(self.consume(amt)) + } +} + +impl AsyncSeek for Either +where + L: AsyncSeek, + R: AsyncSeek, +{ + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> { + delegate_call!(self.start_seek(position)) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_complete(cx)) + } +} + +impl AsyncWrite for Either +where + L: AsyncWrite, + R: AsyncWrite, +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + delegate_call!(self.poll_write(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_flush(cx)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_shutdown(cx)) + } +} + +impl futures_core::stream::Stream for Either +where + L: futures_core::stream::Stream, + R: futures_core::stream::Stream, +{ + type Item = L::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + delegate_call!(self.poll_next(cx)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{repeat, AsyncReadExt, Repeat}; + use tokio_stream::{once, Once, StreamExt}; + + #[tokio::test] + async fn either_is_stream() { + let mut either: Either, Once> = Either::Left(once(1)); + + assert_eq!(Some(1u32), either.next().await); + } + + #[tokio::test] + async fn either_is_async_read() { + let mut buffer = [0; 3]; + let mut either: Either = Either::Right(repeat(0b101)); + + either.read_exact(&mut buffer).await.unwrap(); + assert_eq!(buffer, [0b101, 0b101, 0b101]); + } +} diff --git a/src/io/copy_to_bytes.rs b/src/io/copy_to_bytes.rs new file mode 100644 index 0000000..9509e71 --- /dev/null +++ b/src/io/copy_to_bytes.rs @@ -0,0 +1,68 @@ +use bytes::Bytes; +use futures_sink::Sink; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A helper that wraps a [`Sink`]`<`[`Bytes`]`>` and converts it into a + /// [`Sink`]`<&'a [u8]>` by copying each byte slice into an owned [`Bytes`]. + /// + /// See the documentation for [`SinkWriter`] for an example. + /// + /// [`Bytes`]: bytes::Bytes + /// [`SinkWriter`]: crate::io::SinkWriter + /// [`Sink`]: futures_sink::Sink + #[derive(Debug)] + pub struct CopyToBytes { + #[pin] + inner: S, + } +} + +impl CopyToBytes { + /// Creates a new [`CopyToBytes`]. + pub fn new(inner: S) -> Self { + Self { inner } + } + + /// Gets a reference to the underlying sink. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Gets a mutable reference to the underlying sink. 
+ pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consumes this [`CopyToBytes`], returning the underlying sink. + pub fn into_inner(self) -> S { + self.inner + } +} + +impl<'a, S> Sink<&'a [u8]> for CopyToBytes +where + S: Sink, +{ + type Error = S::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: &'a [u8]) -> Result<(), Self::Error> { + self.project() + .inner + .start_send(Bytes::copy_from_slice(item)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} diff --git a/src/io/inspect.rs b/src/io/inspect.rs new file mode 100644 index 0000000..ec5bb97 --- /dev/null +++ b/src/io/inspect.rs @@ -0,0 +1,134 @@ +use futures_core::ready; +use pin_project_lite::pin_project; +use std::io::{IoSlice, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pin_project! { + /// An adapter that lets you inspect the data that's being read. + /// + /// This is useful for things like hashing data as it's read in. + pub struct InspectReader { + #[pin] + reader: R, + f: F, + } +} + +impl InspectReader { + /// Create a new InspectReader, wrapping `reader` and calling `f` for the + /// new data supplied by each read call. + /// + /// The closure will only be called with an empty slice if the inner reader + /// returns without reading data into the buffer. This happens at EOF, or if + /// `poll_read` is called with a zero-size buffer. + pub fn new(reader: R, f: F) -> InspectReader + where + R: AsyncRead, + F: FnMut(&[u8]), + { + InspectReader { reader, f } + } + + /// Consumes the `InspectReader`, returning the wrapped reader + pub fn into_inner(self) -> R { + self.reader + } +} + +impl AsyncRead for InspectReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let me = self.project(); + let filled_length = buf.filled().len(); + ready!(me.reader.poll_read(cx, buf))?; + (me.f)(&buf.filled()[filled_length..]); + Poll::Ready(Ok(())) + } +} + +pin_project! { + /// An adapter that lets you inspect the data that's being written. + /// + /// This is useful for things like hashing data as it's written out. + pub struct InspectWriter { + #[pin] + writer: W, + f: F, + } +} + +impl InspectWriter { + /// Create a new InspectWriter, wrapping `write` and calling `f` for the + /// data successfully written by each write call. + /// + /// The closure `f` will never be called with an empty slice. A vectored + /// write can result in multiple calls to `f` - at most one call to `f` per + /// buffer supplied to `poll_write_vectored`. 
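+    ///
+    /// A small sketch of counting the bytes that reach the inner writer
+    /// (`tokio::io::sink()` is used here just as a throwaway destination):
+    ///
+    /// ```
+    /// use tokio_util::io::InspectWriter;
+    /// use tokio::io::AsyncWriteExt;
+    ///
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// # async fn main() -> std::io::Result<()> {
+    /// let mut written = 0usize;
+    /// {
+    ///     let mut writer = InspectWriter::new(tokio::io::sink(), |data: &[u8]| {
+    ///         written += data.len();
+    ///     });
+    ///     writer.write_all(b"hello world").await?;
+    /// }
+    /// assert_eq!(written, 11);
+    /// # Ok(())
+    /// # }
+    /// ```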
+ pub fn new(writer: W, f: F) -> InspectWriter + where + W: AsyncWrite, + F: FnMut(&[u8]), + { + InspectWriter { writer, f } + } + + /// Consumes the `InspectWriter`, returning the wrapped writer + pub fn into_inner(self) -> W { + self.writer + } +} + +impl AsyncWrite for InspectWriter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write(cx, buf); + if let Poll::Ready(Ok(count)) = res { + if count != 0 { + (me.f)(&buf[..count]); + } + } + res + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + me.writer.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let me = self.project(); + let res = me.writer.poll_write_vectored(cx, bufs); + if let Poll::Ready(Ok(mut count)) = res { + for buf in bufs { + if count == 0 { + break; + } + let size = count.min(buf.len()); + if size != 0 { + (me.f)(&buf[..size]); + count -= size; + } + } + } + res + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} diff --git a/src/io/mod.rs b/src/io/mod.rs new file mode 100644 index 0000000..6c40d73 --- /dev/null +++ b/src/io/mod.rs @@ -0,0 +1,31 @@ +//! Helpers for IO related tasks. +//! +//! The stream types are often used in combination with hyper or reqwest, as they +//! allow converting between a hyper [`Body`] and [`AsyncRead`]. +//! +//! The [`SyncIoBridge`] type converts from the world of async I/O +//! to synchronous I/O; this may often come up when using synchronous APIs +//! inside [`tokio::task::spawn_blocking`]. +//! +//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html +//! [`AsyncRead`]: tokio::io::AsyncRead + +mod copy_to_bytes; +mod inspect; +mod read_buf; +mod reader_stream; +mod sink_writer; +mod stream_reader; + +cfg_io_util! { + mod sync_bridge; + pub use self::sync_bridge::SyncIoBridge; +} + +pub use self::copy_to_bytes::CopyToBytes; +pub use self::inspect::{InspectReader, InspectWriter}; +pub use self::read_buf::read_buf; +pub use self::reader_stream::ReaderStream; +pub use self::sink_writer::SinkWriter; +pub use self::stream_reader::StreamReader; +pub use crate::util::{poll_read_buf, poll_write_buf}; diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs new file mode 100644 index 0000000..d7938a3 --- /dev/null +++ b/src/io/read_buf.rs @@ -0,0 +1,65 @@ +use bytes::BufMut; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. +/// +/// [`BufMut`]: bytes::BufMut +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio_stream as stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, read_buf}; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. 
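+/// // (Each `Bytes` item yielded by the stream becomes one chunk of the reader.)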
+/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = read_buf(&mut read, &mut buf).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +pub async fn read_buf(read: &mut R, buf: &mut B) -> io::Result +where + R: AsyncRead + Unpin, + B: BufMut, +{ + return ReadBufFn(read, buf).await; + + struct ReadBufFn<'a, R, B>(&'a mut R, &'a mut B); + + impl<'a, R, B> Future for ReadBufFn<'a, R, B> + where + R: AsyncRead + Unpin, + B: BufMut, + { + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + crate::util::poll_read_buf(Pin::new(this.0), cx, this.1) + } + } +} diff --git a/src/io/reader_stream.rs b/src/io/reader_stream.rs new file mode 100644 index 0000000..866c114 --- /dev/null +++ b/src/io/reader_stream.rs @@ -0,0 +1,118 @@ +use bytes::{Bytes, BytesMut}; +use futures_core::stream::Stream; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; + +const DEFAULT_CAPACITY: usize = 4096; + +pin_project! { + /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks. + /// + /// This stream is fused. It performs the inverse operation of + /// [`StreamReader`]. + /// + /// # Example + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// use tokio_stream::StreamExt; + /// use tokio_util::io::ReaderStream; + /// + /// // Create a stream of data. + /// let data = b"hello, world!"; + /// let mut stream = ReaderStream::new(&data[..]); + /// + /// // Read all of the chunks into a vector. + /// let mut stream_contents = Vec::new(); + /// while let Some(chunk) = stream.next().await { + /// stream_contents.extend_from_slice(&chunk?); + /// } + /// + /// // Once the chunks are concatenated, we should have the + /// // original data. + /// assert_eq!(stream_contents, data); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`StreamReader`]: crate::io::StreamReader + /// [`Stream`]: futures_core::Stream + #[derive(Debug)] + pub struct ReaderStream { + // Reader itself. + // + // This value is `None` if the stream has terminated. + #[pin] + reader: Option, + // Working buffer, used to optimize allocations. + buf: BytesMut, + capacity: usize, + } +} + +impl ReaderStream { + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result`. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn new(reader: R) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::new(), + capacity: DEFAULT_CAPACITY, + } + } + + /// Convert an [`AsyncRead`] into a [`Stream`] with item type + /// `Result`, + /// with a specific read buffer initial capacity. 
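+    ///
+    /// A short sketch (the 64 KiB figure is an arbitrary choice for
+    /// illustration, useful when larger chunks are expected):
+    ///
+    /// ```
+    /// use tokio_util::io::ReaderStream;
+    ///
+    /// let data = b"hello, world!";
+    /// // Chunks yielded by the stream will use this initial buffer size.
+    /// let stream = ReaderStream::with_capacity(&data[..], 64 * 1024);
+    /// ```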
+ /// + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Stream`]: futures_core::Stream + pub fn with_capacity(reader: R, capacity: usize) -> Self { + ReaderStream { + reader: Some(reader), + buf: BytesMut::with_capacity(capacity), + capacity, + } + } +} + +impl Stream for ReaderStream { + type Item = std::io::Result; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use crate::util::poll_read_buf; + + let mut this = self.as_mut().project(); + + let reader = match this.reader.as_pin_mut() { + Some(r) => r, + None => return Poll::Ready(None), + }; + + if this.buf.capacity() == 0 { + this.buf.reserve(*this.capacity); + } + + match poll_read_buf(reader, cx, &mut this.buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + self.project().reader.set(None); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(Ok(0)) => { + self.project().reader.set(None); + Poll::Ready(None) + } + Poll::Ready(Ok(_)) => { + let chunk = this.buf.split(); + Poll::Ready(Some(Ok(chunk.freeze()))) + } + } + } +} diff --git a/src/io/sink_writer.rs b/src/io/sink_writer.rs new file mode 100644 index 0000000..f2af262 --- /dev/null +++ b/src/io/sink_writer.rs @@ -0,0 +1,117 @@ +use futures_core::ready; +use futures_sink::Sink; + +use pin_project_lite::pin_project; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +pin_project! { + /// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`]. + /// + /// Whenever you write to this [`SinkWriter`], the supplied bytes are + /// forwarded to the inner [`Sink`]. When `shutdown` is called on this + /// [`SinkWriter`], the inner sink is closed. + /// + /// This adapter takes a `Sink<&[u8]>` and provides an [`AsyncWrite`] impl + /// for it. Because of the lifetime, this trait is relatively rarely + /// implemented. The main ways to get a `Sink<&[u8]>` that you can use with + /// this type are: + /// + /// * With the codec module by implementing the [`Encoder`]`<&[u8]>` trait. + /// * By wrapping a `Sink` in a [`CopyToBytes`]. + /// * Manually implementing `Sink<&[u8]>` directly. + /// + /// The opposite conversion of implementing `Sink<_>` for an [`AsyncWrite`] + /// is done using the [`codec`] module. + /// + /// # Example + /// + /// ``` + /// use bytes::Bytes; + /// use futures_util::SinkExt; + /// use std::io::{Error, ErrorKind}; + /// use tokio::io::AsyncWriteExt; + /// use tokio_util::io::{SinkWriter, CopyToBytes}; + /// use tokio_util::sync::PollSender; + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() -> Result<(), Error> { + /// // We use an mpsc channel as an example of a `Sink`. + /// let (tx, mut rx) = tokio::sync::mpsc::channel::(1); + /// let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe)); + /// + /// // Wrap it in `CopyToBytes` to get a `Sink<&[u8]>`. + /// let mut writer = SinkWriter::new(CopyToBytes::new(sink)); + /// + /// // Write data to our interface... + /// let data: [u8; 4] = [1, 2, 3, 4]; + /// let _ = writer.write(&data).await?; + /// + /// // ... and receive it. + /// assert_eq!(data.as_slice(), &*rx.recv().await.unwrap()); + /// # Ok(()) + /// # } + /// ``` + /// + /// [`AsyncWrite`]: tokio::io::AsyncWrite + /// [`CopyToBytes`]: crate::io::CopyToBytes + /// [`Encoder`]: crate::codec::Encoder + /// [`Sink`]: futures_sink::Sink + /// [`codec`]: tokio_util::codec + #[derive(Debug)] + pub struct SinkWriter { + #[pin] + inner: S, + } +} + +impl SinkWriter { + /// Creates a new [`SinkWriter`]. 
+ pub fn new(sink: S) -> Self { + Self { inner: sink } + } + + /// Gets a reference to the underlying sink. + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Gets a mutable reference to the underlying sink. + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consumes this [`SinkWriter`], returning the underlying sink. + pub fn into_inner(self) -> S { + self.inner + } +} +impl AsyncWrite for SinkWriter +where + for<'a> S: Sink<&'a [u8], Error = E>, + E: Into, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + + ready!(this.inner.as_mut().poll_ready(cx).map_err(Into::into))?; + match this.inner.as_mut().start_send(buf) { + Ok(()) => Poll::Ready(Ok(buf.len())), + Err(e) => Poll::Ready(Err(e.into())), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx).map_err(Into::into) + } +} diff --git a/src/io/stream_reader.rs b/src/io/stream_reader.rs new file mode 100644 index 0000000..3353722 --- /dev/null +++ b/src/io/stream_reader.rs @@ -0,0 +1,326 @@ +use bytes::Buf; +use futures_core::stream::Stream; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. +/// +/// This type performs the inverse operation of [`ReaderStream`]. +/// +/// This type also implements the [`AsyncBufRead`] trait, so you can use it +/// to read a `Stream` of byte chunks line-by-line. See the examples below. +/// +/// # Example +/// +/// ``` +/// use bytes::Bytes; +/// use tokio::io::{AsyncReadExt, Result}; +/// use tokio_util::io::StreamReader; +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a stream from an iterator. +/// let stream = tokio_stream::iter(vec![ +/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), +/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), +/// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), +/// ]); +/// +/// // Convert it to an AsyncRead. +/// let mut read = StreamReader::new(stream); +/// +/// // Read five bytes from the stream. +/// let mut buf = [0; 5]; +/// read.read_exact(&mut buf).await?; +/// assert_eq!(buf, [0, 1, 2, 3, 4]); +/// +/// // Read the rest of the current chunk. +/// assert_eq!(read.read(&mut buf).await?, 3); +/// assert_eq!(&buf[..3], [5, 6, 7]); +/// +/// // Read the next chunk. +/// assert_eq!(read.read(&mut buf).await?, 4); +/// assert_eq!(&buf[..4], [8, 9, 10, 11]); +/// +/// // We have now reached the end. +/// assert_eq!(read.read(&mut buf).await?, 0); +/// +/// # Ok(()) +/// # } +/// ``` +/// +/// If the stream produces errors which are not [`std::io::Error`], +/// the errors can be converted using [`StreamExt`] to map each +/// element. +/// +/// ``` +/// use bytes::Bytes; +/// use tokio::io::AsyncReadExt; +/// use tokio_util::io::StreamReader; +/// use tokio_stream::StreamExt; +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a stream from an iterator, including an error. 
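+/// // (The error below is a plain `&str`; it is mapped to `std::io::Error` further down.)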
+/// let stream = tokio_stream::iter(vec![ +/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), +/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), +/// Result::Err("Something bad happened!") +/// ]); +/// +/// // Use StreamExt to map the stream and error to a std::io::Error +/// let stream = stream.map(|result| result.map_err(|err| { +/// std::io::Error::new(std::io::ErrorKind::Other, err) +/// })); +/// +/// // Convert it to an AsyncRead. +/// let mut read = StreamReader::new(stream); +/// +/// // Read five bytes from the stream. +/// let mut buf = [0; 5]; +/// read.read_exact(&mut buf).await?; +/// assert_eq!(buf, [0, 1, 2, 3, 4]); +/// +/// // Read the rest of the current chunk. +/// assert_eq!(read.read(&mut buf).await?, 3); +/// assert_eq!(&buf[..3], [5, 6, 7]); +/// +/// // Reading the next chunk will produce an error +/// let error = read.read(&mut buf).await.unwrap_err(); +/// assert_eq!(error.kind(), std::io::ErrorKind::Other); +/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!"); +/// +/// // We have now reached the end. +/// assert_eq!(read.read(&mut buf).await?, 0); +/// +/// # Ok(()) +/// # } +/// ``` +/// +/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks +/// line-by-line. Note that you will usually also need to convert the error +/// type when doing this. See the second example for an explanation of how +/// to do this. +/// +/// ``` +/// use tokio::io::{Result, AsyncBufReadExt}; +/// use tokio_util::io::StreamReader; +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a stream of byte chunks. +/// let stream = tokio_stream::iter(vec![ +/// Result::Ok(b"The first line.\n".as_slice()), +/// Result::Ok(b"The second line.".as_slice()), +/// Result::Ok(b"\nThe third".as_slice()), +/// Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()), +/// ]); +/// +/// // Convert it to an AsyncRead. +/// let mut read = StreamReader::new(stream); +/// +/// // Loop through the lines from the `StreamReader`. +/// let mut line = String::new(); +/// let mut lines = Vec::new(); +/// loop { +/// line.clear(); +/// let len = read.read_line(&mut line).await?; +/// if len == 0 { break; } +/// lines.push(line.clone()); +/// } +/// +/// // Verify that we got the lines we expected. +/// assert_eq!( +/// lines, +/// vec![ +/// "The first line.\n", +/// "The second line.\n", +/// "The third line.\n", +/// "The fourth line.\n", +/// "The fifth line.\n", +/// ] +/// ); +/// # Ok(()) +/// # } +/// ``` +/// +/// [`AsyncRead`]: tokio::io::AsyncRead +/// [`AsyncBufRead`]: tokio::io::AsyncBufRead +/// [`Stream`]: futures_core::Stream +/// [`ReaderStream`]: crate::io::ReaderStream +/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html +#[derive(Debug)] +pub struct StreamReader { + // This field is pinned. + inner: S, + // This field is not pinned. + chunk: Option, +} + +impl StreamReader +where + S: Stream>, + B: Buf, + E: Into, +{ + /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). + /// + /// The item should be a [`Result`] with the ok variant being something that + /// implements the [`Buf`] trait (e.g. `Vec` or `Bytes`). The error + /// should be convertible into an [io error]. + /// + /// [`Result`]: std::result::Result + /// [`Buf`]: bytes::Buf + /// [io error]: std::io::Error + pub fn new(stream: S) -> Self { + Self { + inner: stream, + chunk: None, + } + } + + /// Do we have a chunk and is it non-empty? 
+    fn has_chunk(&self) -> bool {
+        if let Some(ref chunk) = self.chunk {
+            chunk.remaining() > 0
+        } else {
+            false
+        }
+    }
+
+    /// Consumes this `StreamReader`, returning a tuple consisting
+    /// of the underlying stream and an `Option` of the internal buffer,
+    /// which is `Some` in case the buffer contains elements.
+    pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
+        if self.has_chunk() {
+            (self.inner, self.chunk)
+        } else {
+            (self.inner, None)
+        }
+    }
+}
+
+impl<S, B> StreamReader<S, B> {
+    /// Gets a reference to the underlying stream.
+    ///
+    /// It is inadvisable to directly read from the underlying stream.
+    pub fn get_ref(&self) -> &S {
+        &self.inner
+    }
+
+    /// Gets a mutable reference to the underlying stream.
+    ///
+    /// It is inadvisable to directly read from the underlying stream.
+    pub fn get_mut(&mut self) -> &mut S {
+        &mut self.inner
+    }
+
+    /// Gets a pinned mutable reference to the underlying stream.
+    ///
+    /// It is inadvisable to directly read from the underlying stream.
+    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
+        self.project().inner
+    }
+
+    /// Consumes this `StreamReader`, returning the underlying stream.
+    ///
+    /// Note that any leftover data in the internal buffer is lost.
+    /// If you additionally want access to the internal buffer use
+    /// [`into_inner_with_chunk`].
+    ///
+    /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
+    pub fn into_inner(self) -> S {
+        self.inner
+    }
+}
+
+impl<S, B, E> AsyncRead for StreamReader<S, B>
+where
+    S: Stream<Item = Result<B, E>>,
+    B: Buf,
+    E: Into<std::io::Error>,
+{
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        if buf.remaining() == 0 {
+            return Poll::Ready(Ok(()));
+        }
+
+        let inner_buf = match self.as_mut().poll_fill_buf(cx) {
+            Poll::Ready(Ok(buf)) => buf,
+            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+            Poll::Pending => return Poll::Pending,
+        };
+        let len = std::cmp::min(inner_buf.len(), buf.remaining());
+        buf.put_slice(&inner_buf[..len]);
+
+        self.consume(len);
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl<S, B, E> AsyncBufRead for StreamReader<S, B>
+where
+    S: Stream<Item = Result<B, E>>,
+    B: Buf,
+    E: Into<std::io::Error>,
+{
+    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
+        loop {
+            if self.as_mut().has_chunk() {
+                // This unwrap is very sad, but it can't be avoided.
+                let buf = self.project().chunk.as_ref().unwrap().chunk();
+                return Poll::Ready(Ok(buf));
+            } else {
+                match self.as_mut().project().inner.poll_next(cx) {
+                    Poll::Ready(Some(Ok(chunk))) => {
+                        // Go around the loop in case the chunk is empty.
+                        *self.as_mut().project().chunk = Some(chunk);
+                    }
+                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
+                    Poll::Ready(None) => return Poll::Ready(Ok(&[])),
+                    Poll::Pending => return Poll::Pending,
+                }
+            }
+        }
+    }
+    fn consume(self: Pin<&mut Self>, amt: usize) {
+        if amt > 0 {
+            self.project()
+                .chunk
+                .as_mut()
+                .expect("No chunk present")
+                .advance(amt);
+        }
+    }
+}
+
+// The code below is a manual expansion of the code that pin-project-lite would
+// generate. This is done because pin-project-lite fails by hitting the recursion
+// limit on this struct. (Every line of documentation is handled recursively by
+// the macro.)
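+//
+// Safety of the projection: `inner` is the only structurally pinned field, so
+// the `Unpin` impl below only needs `S: Unpin`; `chunk` is ordinary unpinned
+// data that `project` hands out as a plain `&mut Option<B>`.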
+
+impl<S: Unpin, B> Unpin for StreamReader<S, B> {}
+
+struct StreamReaderProject<'a, S, B> {
+    inner: Pin<&'a mut S>,
+    chunk: &'a mut Option<B>,
+}
+
+impl<S, B> StreamReader<S, B> {
+    #[inline]
+    fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
+        // SAFETY: We define that only `inner` should be pinned when `Self` is
+        // and have an appropriate `impl Unpin` for this.
+        let me = unsafe { Pin::into_inner_unchecked(self) };
+        StreamReaderProject {
+            inner: unsafe { Pin::new_unchecked(&mut me.inner) },
+            chunk: &mut me.chunk,
+        }
+    }
+}
diff --git a/src/io/sync_bridge.rs b/src/io/sync_bridge.rs
new file mode 100644
index 0000000..f87bfbb
--- /dev/null
+++ b/src/io/sync_bridge.rs
@@ -0,0 +1,143 @@
+use std::io::{BufRead, Read, Write};
+use tokio::io::{
+    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
+};
+
+/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
+/// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
+#[derive(Debug)]
+pub struct SyncIoBridge<T> {
+    src: T,
+    rt: tokio::runtime::Handle,
+}
+
+impl<T: AsyncBufRead + Unpin> BufRead for SyncIoBridge<T> {
+    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
+        let src = &mut self.src;
+        self.rt.block_on(AsyncBufReadExt::fill_buf(src))
+    }
+
+    fn consume(&mut self, amt: usize) {
+        let src = &mut self.src;
+        AsyncBufReadExt::consume(src, amt)
+    }
+
+    fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt
+            .block_on(AsyncBufReadExt::read_until(src, byte, buf))
+    }
+    fn read_line(&mut self, buf: &mut String) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt.block_on(AsyncBufReadExt::read_line(src, buf))
+    }
+}
+
+impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> {
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt.block_on(AsyncReadExt::read(src, buf))
+    }
+
+    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt.block_on(src.read_to_end(buf))
+    }
+
+    fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt.block_on(src.read_to_string(buf))
+    }
+
+    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
+        let src = &mut self.src;
+        // The AsyncRead trait returns the count, synchronous doesn't.
+        let _n = self.rt.block_on(src.read_exact(buf))?;
+        Ok(())
+    }
+}
+
+impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt.block_on(src.write(buf))
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        let src = &mut self.src;
+        self.rt.block_on(src.flush())
+    }
+
+    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
+        let src = &mut self.src;
+        self.rt.block_on(src.write_all(buf))
+    }
+
+    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
+        let src = &mut self.src;
+        self.rt.block_on(src.write_vectored(bufs))
+    }
+}
+
+// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time
+// of this writing still unstable, we expose this as part of a standalone method.
+impl<T: AsyncWrite> SyncIoBridge<T> {
+    /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes.
+    ///
+    /// See [`tokio::io::AsyncWrite::is_write_vectored`].
+    pub fn is_write_vectored(&self) -> bool {
+        self.src.is_write_vectored()
+    }
+}
+
+impl<T: AsyncWrite + Unpin> SyncIoBridge<T> {
+    /// Shutdown this writer. This method provides a way to call the [`AsyncWriteExt::shutdown`]
+    /// function of the inner [`tokio::io::AsyncWrite`] instance.
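+    /// The call blocks the current thread until the shutdown of the inner
+    /// writer completes.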
+ /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::shutdown`]. + /// + /// [`AsyncWriteExt::shutdown`]: tokio::io::AsyncWriteExt::shutdown + pub fn shutdown(&mut self) -> std::io::Result<()> { + let src = &mut self.src; + self.rt.block_on(src.shutdown()) + } +} + +impl SyncIoBridge { + /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or + /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. + /// + /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`]. + /// It is hence OK to move this struct into a separate thread outside the runtime, as created + /// by e.g. [`tokio::task::spawn_blocking`]. + /// + /// Stated even more strongly: to make use of this bridge, you *must* move + /// it into a separate thread outside the runtime. The synchronous I/O will use the + /// underlying handle to block on the backing asynchronous source, via + /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that + /// function, an attempt to `block_on` from an asynchronous execution context + /// will panic. + /// + /// # Wrapping `!Unpin` types + /// + /// Use e.g. `SyncIoBridge::new(Box::pin(src))`. + /// + /// # Panics + /// + /// This will panic if called outside the context of a Tokio runtime. + #[track_caller] + pub fn new(src: T) -> Self { + Self::new_with_handle(src, tokio::runtime::Handle::current()) + } + + /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or + /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`]. + /// + /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may + /// be initially invoked outside of an asynchronous context. + pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self { + Self { src, rt } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..524fc47 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,205 @@ +#![allow(clippy::needless_doctest_main)] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +#![doc(test( + no_crate_inject, + attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) +))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Utilities for working with Tokio. +//! +//! This crate is not versioned in lockstep with the core +//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's +//! semantic versioning policy, especially with regard to breaking changes. +//! +//! [`tokio`]: https://docs.rs/tokio + +#[macro_use] +mod cfg; + +mod loom; + +cfg_codec! { + pub mod codec; +} + +cfg_net! { + #[cfg(not(target_arch = "wasm32"))] + pub mod udp; + pub mod net; +} + +cfg_compat! { + pub mod compat; +} + +cfg_io! { + pub mod io; +} + +cfg_rt! { + pub mod context; + pub mod task; +} + +cfg_time! { + pub mod time; +} + +pub mod sync; + +pub mod either; + +#[cfg(any(feature = "io", feature = "codec"))] +mod util { + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + + use bytes::{Buf, BufMut}; + use futures_core::ready; + use std::io::{self, IoSlice}; + use std::mem::MaybeUninit; + use std::pin::Pin; + use std::task::{Context, Poll}; + + /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. 
+ /// + /// [`BufMut`]: bytes::Buf + /// + /// # Example + /// + /// ``` + /// use bytes::{Bytes, BytesMut}; + /// use tokio_stream as stream; + /// use tokio::io::Result; + /// use tokio_util::io::{StreamReader, poll_read_buf}; + /// use futures::future::poll_fn; + /// use std::pin::Pin; + /// # #[tokio::main] + /// # async fn main() -> std::io::Result<()> { + /// + /// // Create a reader from an iterator. This particular reader will always be + /// // ready. + /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); + /// + /// let mut buf = BytesMut::new(); + /// let mut reads = 0; + /// + /// loop { + /// reads += 1; + /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; + /// + /// if n == 0 { + /// break; + /// } + /// } + /// + /// // one or more reads might be necessary. + /// assert!(reads >= 1); + /// assert_eq!(&buf[..], &[0, 1, 2, 3]); + /// # Ok(()) + /// # } + /// ``` + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_read_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.chunk_mut(); + + // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a + // transparent wrapper around `[MaybeUninit]`. + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(io.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) + } + + /// Try to write data from an implementer of the [`Buf`] trait to an + /// [`AsyncWrite`], advancing the buffer's internal cursor. + /// + /// This function will use [vectored writes] when the [`AsyncWrite`] supports + /// vectored writes. + /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements + /// [`Buf`]: + /// + /// ```no_run + /// use tokio_util::io::poll_write_buf; + /// use tokio::io; + /// use tokio::fs::File; + /// + /// use bytes::Buf; + /// use std::io::Cursor; + /// use std::pin::Pin; + /// use futures::future::poll_fn; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buf = Cursor::new(b"data to write"); + /// + /// // Loop until the entire contents of the buffer are written to + /// // the file. + /// while buf.has_remaining() { + /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; + /// } + /// + /// Ok(()) + /// } + /// ``` + /// + /// [`Buf`]: bytes::Buf + /// [`AsyncWrite`]: tokio::io::AsyncWrite + /// [`File`]: tokio::fs::File + /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored + #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] + pub fn poll_write_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> { + const MAX_BUFS: usize = 64; + + if !buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_BUFS]; + let cnt = buf.chunks_vectored(&mut slices); + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? 
+        } else {
+            ready!(io.poll_write(cx, buf.chunk()))?
+        };
+
+        buf.advance(n);
+
+        Poll::Ready(Ok(n))
+    }
+}
diff --git a/src/loom.rs b/src/loom.rs
new file mode 100644
index 0000000..dd03fea
--- /dev/null
+++ b/src/loom.rs
@@ -0,0 +1 @@
+pub(crate) use std::sync;
diff --git a/src/net/mod.rs b/src/net/mod.rs
new file mode 100644
index 0000000..4817e10
--- /dev/null
+++ b/src/net/mod.rs
@@ -0,0 +1,97 @@
+//! TCP/UDP/Unix helpers for tokio.
+
+use crate::either::Either;
+use std::future::Future;
+use std::io::Result;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+#[cfg(unix)]
+pub mod unix;
+
+/// A trait for a listener: `TcpListener` and `UnixListener`.
+pub trait Listener {
+    /// The stream's type of this listener.
+    type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite;
+    /// The socket address type of this listener.
+    type Addr;
+
+    /// Polls to accept a new incoming connection to this listener.
+    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>;
+
+    /// Accepts a new incoming connection from this listener.
+    fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
+    where
+        Self: Sized,
+    {
+        ListenerAcceptFut { listener: self }
+    }
+
+    /// Returns the local address that this listener is bound to.
+    fn local_addr(&self) -> Result<Self::Addr>;
+}
+
+impl Listener for tokio::net::TcpListener {
+    type Io = tokio::net::TcpStream;
+    type Addr = std::net::SocketAddr;
+
+    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
+        Self::poll_accept(self, cx)
+    }
+
+    fn local_addr(&self) -> Result<Self::Addr> {
+        self.local_addr().map(Into::into)
+    }
+}
+
+/// Future for accepting a new connection from a listener.
+#[derive(Debug)]
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+pub struct ListenerAcceptFut<'a, L> {
+    listener: &'a mut L,
+}
+
+impl<'a, L> Future for ListenerAcceptFut<'a, L>
+where
+    L: Listener,
+{
+    type Output = Result<(L::Io, L::Addr)>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.listener.poll_accept(cx)
+    }
+}
+
+impl<L, R> Either<L, R>
+where
+    L: Listener,
+    R: Listener,
+{
+    /// Accepts a new incoming connection from this listener.
+    pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> {
+        match self {
+            Either::Left(listener) => {
+                let (stream, addr) = listener.accept().await?;
+                Ok(Either::Left((stream, addr)))
+            }
+            Either::Right(listener) => {
+                let (stream, addr) = listener.accept().await?;
+                Ok(Either::Right((stream, addr)))
+            }
+        }
+    }
+
+    /// Returns the local address that this listener is bound to.
+    pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> {
+        match self {
+            Either::Left(listener) => {
+                let addr = listener.local_addr()?;
+                Ok(Either::Left(addr))
+            }
+            Either::Right(listener) => {
+                let addr = listener.local_addr()?;
+                Ok(Either::Right(addr))
+            }
+        }
+    }
+}
diff --git a/src/net/unix/mod.rs b/src/net/unix/mod.rs
new file mode 100644
index 0000000..0b522c9
--- /dev/null
+++ b/src/net/unix/mod.rs
@@ -0,0 +1,18 @@
+//! Unix domain socket helpers.
+
+use super::Listener;
+use std::io::Result;
+use std::task::{Context, Poll};
+
+impl Listener for tokio::net::UnixListener {
+    type Io = tokio::net::UnixStream;
+    type Addr = tokio::net::unix::SocketAddr;
+
+    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
+        Self::poll_accept(self, cx)
+    }
+
+    fn local_addr(&self) -> Result<Self::Addr> {
+        self.local_addr().map(Into::into)
+    }
+}
diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs
new file mode 100644
index 0000000..c44be69
--- /dev/null
+++ b/src/sync/cancellation_token.rs
@@ -0,0 +1,324 @@
+//!
An asynchronously awaitable `CancellationToken`. +//! The token allows to signal a cancellation request to one or more tasks. +pub(crate) mod guard; +mod tree_node; + +use crate::loom::sync::Arc; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use guard::DropGuard; +use pin_project_lite::pin_project; + +/// A token which can be used to signal a cancellation request to one or more +/// tasks. +/// +/// Tasks can call [`CancellationToken::cancelled()`] in order to +/// obtain a Future which will be resolved when cancellation is requested. +/// +/// Cancellation can be requested through the [`CancellationToken::cancel`] method. +/// +/// # Examples +/// +/// ```no_run +/// use tokio::select; +/// use tokio_util::sync::CancellationToken; +/// +/// #[tokio::main] +/// async fn main() { +/// let token = CancellationToken::new(); +/// let cloned_token = token.clone(); +/// +/// let join_handle = tokio::spawn(async move { +/// // Wait for either cancellation or a very long time +/// select! { +/// _ = cloned_token.cancelled() => { +/// // The token was cancelled +/// 5 +/// } +/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { +/// 99 +/// } +/// } +/// }); +/// +/// tokio::spawn(async move { +/// tokio::time::sleep(std::time::Duration::from_millis(10)).await; +/// token.cancel(); +/// }); +/// +/// assert_eq!(5, join_handle.await.unwrap()); +/// } +/// ``` +pub struct CancellationToken { + inner: Arc, +} + +impl std::panic::UnwindSafe for CancellationToken {} +impl std::panic::RefUnwindSafe for CancellationToken {} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFuture<'a> { + cancellation_token: &'a CancellationToken, + #[pin] + future: tokio::sync::futures::Notified<'a>, + } +} + +pin_project! { + /// A Future that is resolved once the corresponding [`CancellationToken`] + /// is cancelled. + /// + /// This is the counterpart to [`WaitForCancellationFuture`] that takes + /// [`CancellationToken`] by value instead of using a reference. + #[must_use = "futures do nothing unless polled"] + pub struct WaitForCancellationFutureOwned { + // Since `future` is the first field, it is dropped before the + // cancellation_token field. This ensures that the reference inside the + // `Notified` remains valid. + #[pin] + future: tokio::sync::futures::Notified<'static>, + cancellation_token: CancellationToken, + } +} + +// ===== impl CancellationToken ===== + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken") + .field("is_cancelled", &self.is_cancelled()) + .finish() + } +} + +impl Clone for CancellationToken { + fn clone(&self) -> Self { + tree_node::increase_handle_refcount(&self.inner); + CancellationToken { + inner: self.inner.clone(), + } + } +} + +impl Drop for CancellationToken { + fn drop(&mut self) { + tree_node::decrease_handle_refcount(&self.inner); + } +} + +impl Default for CancellationToken { + fn default() -> CancellationToken { + CancellationToken::new() + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the non-cancelled state. + pub fn new() -> CancellationToken { + CancellationToken { + inner: Arc::new(tree_node::TreeNode::new()), + } + } + + /// Creates a `CancellationToken` which will get cancelled whenever the + /// current token gets cancelled. 
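+    /// Cancelling a child token has no effect on the parent: children can be
+    /// cancelled independently.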
+ /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::select; + /// use tokio_util::sync::CancellationToken; + /// + /// #[tokio::main] + /// async fn main() { + /// let token = CancellationToken::new(); + /// let child_token = token.child_token(); + /// + /// let join_handle = tokio::spawn(async move { + /// // Wait for either cancellation or a very long time + /// select! { + /// _ = child_token.cancelled() => { + /// // The token was cancelled + /// 5 + /// } + /// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => { + /// 99 + /// } + /// } + /// }); + /// + /// tokio::spawn(async move { + /// tokio::time::sleep(std::time::Duration::from_millis(10)).await; + /// token.cancel(); + /// }); + /// + /// assert_eq!(5, join_handle.await.unwrap()); + /// } + /// ``` + pub fn child_token(&self) -> CancellationToken { + CancellationToken { + inner: tree_node::child_node(&self.inner), + } + } + + /// Cancel the [`CancellationToken`] and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + /// + /// Be aware that cancellation is not an atomic operation. It is possible + /// for another thread running in parallel with a call to `cancel` to first + /// receive `true` from `is_cancelled` on one child node, and then receive + /// `false` from `is_cancelled` on another child node. However, once the + /// call to `cancel` returns, all child nodes have been fully cancelled. + pub fn cancel(&self) { + tree_node::cancel(&self.inner); + } + + /// Returns `true` if the `CancellationToken` is cancelled. + pub fn is_cancelled(&self) -> bool { + tree_node::is_cancelled(&self.inner) + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + cancellation_token: self, + future: self.inner.notified(), + } + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// The function takes self by value and returns a future that owns the + /// token. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned { + WaitForCancellationFutureOwned::new(self) + } + + /// Creates a `DropGuard` for this token. + /// + /// Returned guard will cancel this token (and all its children) on drop + /// unless disarmed. 
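+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of the guard in action:
+    ///
+    /// ```
+    /// use tokio_util::sync::CancellationToken;
+    ///
+    /// let token = CancellationToken::new();
+    ///
+    /// {
+    ///     // Dropping the guard at the end of this scope cancels the token.
+    ///     let _guard = token.clone().drop_guard();
+    /// }
+    ///
+    /// assert!(token.is_cancelled());
+    /// ```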
+ pub fn drop_guard(self) -> DropGuard { + DropGuard { inner: Some(self) } + } +} + +// ===== impl WaitForCancellationFuture ===== + +impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. + if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + this.future.set(this.cancellation_token.inner.notified()); + } + } +} + +// ===== impl WaitForCancellationFutureOwned ===== + +impl core::fmt::Debug for WaitForCancellationFutureOwned { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFutureOwned").finish() + } +} + +impl WaitForCancellationFutureOwned { + fn new(cancellation_token: CancellationToken) -> Self { + WaitForCancellationFutureOwned { + // cancellation_token holds a heap allocation and is guaranteed to have a + // stable deref, thus it would be ok to move the cancellation_token while + // the future holds a reference to it. + // + // # Safety + // + // cancellation_token is dropped after future due to the field ordering. + future: unsafe { Self::new_future(&cancellation_token) }, + cancellation_token, + } + } + + /// # Safety + /// The returned future must be destroyed before the cancellation token is + /// destroyed. + unsafe fn new_future( + cancellation_token: &CancellationToken, + ) -> tokio::sync::futures::Notified<'static> { + let inner_ptr = Arc::as_ptr(&cancellation_token.inner); + // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains + // valid until the strong count of the Arc drops to zero, and the caller + // guarantees that they will drop the future before that happens. + (*inner_ptr).notified() + } +} + +impl Future for WaitForCancellationFutureOwned { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + + loop { + if this.cancellation_token.is_cancelled() { + return Poll::Ready(()); + } + + // No wakeups can be lost here because there is always a call to + // `is_cancelled` between the creation of the future and the call to + // `poll`, and the code that sets the cancelled flag does so before + // waking the `Notified`. + if this.future.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + + // # Safety + // + // cancellation_token is dropped after future due to the field ordering. + this.future + .set(unsafe { Self::new_future(this.cancellation_token) }); + } + } +} diff --git a/src/sync/cancellation_token/guard.rs b/src/sync/cancellation_token/guard.rs new file mode 100644 index 0000000..54ed7ea --- /dev/null +++ b/src/sync/cancellation_token/guard.rs @@ -0,0 +1,27 @@ +use crate::sync::CancellationToken; + +/// A wrapper for cancellation token which automatically cancels +/// it on drop. It is created using `drop_guard` method on the `CancellationToken`. 
+#[derive(Debug)]
+pub struct DropGuard {
+    pub(super) inner: Option<CancellationToken>,
+}
+
+impl DropGuard {
+    /// Returns the stored cancellation token and removes this drop guard instance
+    /// (i.e. it will no longer cancel the token). Other guards for this token
+    /// are not affected.
+    pub fn disarm(mut self) -> CancellationToken {
+        self.inner
+            .take()
+            .expect("`inner` can be only None in a destructor")
+    }
+}
+
+impl Drop for DropGuard {
+    fn drop(&mut self) {
+        if let Some(inner) = &self.inner {
+            inner.cancel();
+        }
+    }
+}
diff --git a/src/sync/cancellation_token/tree_node.rs b/src/sync/cancellation_token/tree_node.rs
new file mode 100644
index 0000000..8f97dee
--- /dev/null
+++ b/src/sync/cancellation_token/tree_node.rs
@@ -0,0 +1,373 @@
+//! This mod provides the logic for the inner tree structure of the CancellationToken.
+//!
+//! CancellationTokens are only light handles with references to TreeNode.
+//! All the logic is actually implemented in the TreeNode.
+//!
+//! A TreeNode is part of the cancellation tree and may have one parent and an arbitrary number of
+//! children.
+//!
+//! A TreeNode can receive the request to perform a cancellation through a CancellationToken.
+//! This cancellation request will cancel the node and all of its descendants.
+//!
+//! As soon as a node can no longer get cancelled (because it was already cancelled or it has no
+//! more CancellationTokens pointing to it), it gets removed from the tree, to keep the
+//! tree as small as possible.
+//!
+//! # Invariants
+//!
+//! Those invariants shall be true at any time.
+//!
+//! 1. A node that has no parents and no handles can no longer be cancelled.
+//!    This is important during both cancellation and refcounting.
+//!
+//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A.
+//!    This is important for deadlock safety, as it is used for lock order.
+//!    Node B can only become the child of node A in two ways:
+//!    - being created with `child_node()`, in which case it is trivially true that
+//!      node A already existed when node B was created
+//!    - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()`
+//!      or `cancel()`. In this case the invariant still holds, as B was younger than C, and C
+//!      was younger than A, therefore B is also younger than A.
+//!
+//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of
+//!    node A. It is important to always restore that invariant before dropping the lock of a node.
+//!
+//! # Deadlock safety
+//!
+//! We always lock in the order of creation time. We can prove this through invariant #2.
+//! Specifically, through invariant #2, we know that we always have to lock a parent
+//! before its child.
+//!
+use crate::loom::sync::{Arc, Mutex, MutexGuard};
+
+/// A node of the cancellation tree structure
+///
+/// The actual data it holds is wrapped inside a mutex for synchronization.
+pub(crate) struct TreeNode {
+    inner: Mutex<Inner>,
+    waker: tokio::sync::Notify,
+}
+impl TreeNode {
+    pub(crate) fn new() -> Self {
+        Self {
+            inner: Mutex::new(Inner {
+                parent: None,
+                parent_idx: 0,
+                children: vec![],
+                is_cancelled: false,
+                num_handles: 1,
+            }),
+            waker: tokio::sync::Notify::new(),
+        }
+    }
+
+    pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> {
+        self.waker.notified()
+    }
+}
+
+/// The data contained inside a TreeNode.
+///
+/// This struct exists so that the data of the node can be wrapped
+/// in a Mutex.
+struct Inner { + parent: Option>, + parent_idx: usize, + children: Vec>, + is_cancelled: bool, + num_handles: usize, +} + +/// Returns whether or not the node is cancelled +pub(crate) fn is_cancelled(node: &Arc) -> bool { + node.inner.lock().unwrap().is_cancelled +} + +/// Creates a child node +pub(crate) fn child_node(parent: &Arc) -> Arc { + let mut locked_parent = parent.inner.lock().unwrap(); + + // Do not register as child if we are already cancelled. + // Cancelled trees can never be uncancelled and therefore + // need no connection to parents or children any more. + if locked_parent.is_cancelled { + return Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: None, + parent_idx: 0, + children: vec![], + is_cancelled: true, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + } + + let child = Arc::new(TreeNode { + inner: Mutex::new(Inner { + parent: Some(parent.clone()), + parent_idx: locked_parent.children.len(), + children: vec![], + is_cancelled: false, + num_handles: 1, + }), + waker: tokio::sync::Notify::new(), + }); + + locked_parent.children.push(child.clone()); + + child +} + +/// Disconnects the given parent from all of its children. +/// +/// Takes a reference to [Inner] to make sure the parent is already locked. +fn disconnect_children(node: &mut Inner) { + for child in std::mem::take(&mut node.children) { + let mut locked_child = child.inner.lock().unwrap(); + locked_child.parent_idx = 0; + locked_child.parent = None; + } +} + +/// Figures out the parent of the node and locks the node and its parent atomically. +/// +/// The basic principle of preventing deadlocks in the tree is +/// that we always lock the parent first, and then the child. +/// For more info look at *deadlock safety* and *invariant #2*. +/// +/// Sadly, it's impossible to figure out the parent of a node without +/// locking it. To then achieve locking order consistency, the node +/// has to be unlocked before the parent gets locked. +/// This leaves a small window where we already assume that we know the parent, +/// but neither the parent nor the node is locked. Therefore, the parent could change. +/// +/// To prevent that this problem leaks into the rest of the code, it is abstracted +/// in this function. +/// +/// The locked child and optionally its locked parent, if a parent exists, get passed +/// to the `func` argument via (node, None) or (node, Some(parent)). +fn with_locked_node_and_parent(node: &Arc, func: F) -> Ret +where + F: FnOnce(MutexGuard<'_, Inner>, Option>) -> Ret, +{ + let mut potential_parent = { + let locked_node = node.inner.lock().unwrap(); + match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. + None => return func(locked_node, None), + } + }; + + loop { + // Deadlock safety: + // + // Due to invariant #2, we know that we have to lock the parent first, and then the child. + // This is true even if the potential_parent is no longer the current parent or even its + // sibling, as the invariant still holds. + let locked_parent = potential_parent.inner.lock().unwrap(); + let locked_node = node.inner.lock().unwrap(); + + let actual_parent = match locked_node.parent.clone() { + Some(parent) => parent, + // If we locked the node and its parent is `None`, we are in a valid state + // and can return. 
+ None => { + // Was the wrong parent, so unlock it before calling `func` + drop(locked_parent); + return func(locked_node, None); + } + }; + + // Loop until we managed to lock both the node and its parent + if Arc::ptr_eq(&actual_parent, &potential_parent) { + return func(locked_node, Some(locked_parent)); + } + + // Drop locked_parent before reassigning to potential_parent, + // as potential_parent is borrowed in it + drop(locked_node); + drop(locked_parent); + + potential_parent = actual_parent; + } +} + +/// Moves all children from `node` to `parent`. +/// +/// `parent` MUST have been a parent of the node when they both got locked, +/// otherwise there is a potential for a deadlock as invariant #2 would be violated. +/// +/// To acquire the locks for node and parent, use [with_locked_node_and_parent]. +fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) { + // Pre-allocate in the parent, for performance + parent.children.reserve(node.children.len()); + + for child in std::mem::take(&mut node.children) { + { + let mut child_locked = child.inner.lock().unwrap(); + child_locked.parent = node.parent.clone(); + child_locked.parent_idx = parent.children.len(); + } + parent.children.push(child); + } +} + +/// Removes a child from the parent. +/// +/// `parent` MUST be the parent of `node`. +/// To acquire the locks for node and parent, use [with_locked_node_and_parent]. +fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) { + // Query the position from where to remove a node + let pos = node.parent_idx; + node.parent = None; + node.parent_idx = 0; + + // Unlock node, so that only one child at a time is locked. + // Otherwise we would violate the lock order (see 'deadlock safety') as we + // don't know the creation order of the child nodes + drop(node); + + // If `node` is the last element in the list, we don't need any swapping + if parent.children.len() == pos + 1 { + parent.children.pop().unwrap(); + } else { + // If `node` is not the last element in the list, we need to + // replace it with the last element + let replacement_child = parent.children.pop().unwrap(); + replacement_child.inner.lock().unwrap().parent_idx = pos; + parent.children[pos] = replacement_child; + } + + let len = parent.children.len(); + if 4 * len <= parent.children.capacity() { + // equal to: + // parent.children.shrink_to(2 * len); + // but shrink_to was not yet stabilized in our minimal compatible version + let old_children = std::mem::replace(&mut parent.children, Vec::with_capacity(2 * len)); + parent.children.extend(old_children); + } +} + +/// Increases the reference count of handles. +pub(crate) fn increase_handle_refcount(node: &Arc) { + let mut locked_node = node.inner.lock().unwrap(); + + // Once no handles are left over, the node gets detached from the tree. + // There should never be a new handle once all handles are dropped. + assert!(locked_node.num_handles > 0); + + locked_node.num_handles += 1; +} + +/// Decreases the reference count of handles. +/// +/// Once no handle is left, we can remove the node from the +/// tree and connect its parent directly to its children. 
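+/// This keeps the tree as small as possible (see the module-level docs).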
+pub(crate) fn decrease_handle_refcount(node: &Arc) { + let num_handles = { + let mut locked_node = node.inner.lock().unwrap(); + locked_node.num_handles -= 1; + locked_node.num_handles + }; + + if num_handles == 0 { + with_locked_node_and_parent(node, |mut node, parent| { + // Remove the node from the tree + match parent { + Some(mut parent) => { + // As we want to remove ourselves from the tree, + // we have to move the children to the parent, so that + // they still receive the cancellation event without us. + // Moving them does not violate invariant #1. + move_children_to_parent(&mut node, &mut parent); + + // Remove the node from the parent + remove_child(&mut parent, node); + } + None => { + // Due to invariant #1, we can assume that our + // children can no longer be cancelled through us. + // (as we now have neither a parent nor handles) + // Therefore we can disconnect them. + disconnect_children(&mut node); + } + } + }); + } +} + +/// Cancels a node and its children. +pub(crate) fn cancel(node: &Arc) { + let mut locked_node = node.inner.lock().unwrap(); + + if locked_node.is_cancelled { + return; + } + + // One by one, adopt grandchildren and then cancel and detach the child + while let Some(child) = locked_node.children.pop() { + // This can't deadlock because the mutex we are already + // holding is the parent of child. + let mut locked_child = child.inner.lock().unwrap(); + + // Detach the child from node + // No need to modify node.children, as the child already got removed with `.pop` + locked_child.parent = None; + locked_child.parent_idx = 0; + + // If child is already cancelled, detaching is enough + if locked_child.is_cancelled { + continue; + } + + // Cancel or adopt grandchildren + while let Some(grandchild) = locked_child.children.pop() { + // This can't deadlock because the two mutexes we are already + // holding is the parent and grandparent of grandchild. + let mut locked_grandchild = grandchild.inner.lock().unwrap(); + + // Detach the grandchild + locked_grandchild.parent = None; + locked_grandchild.parent_idx = 0; + + // If grandchild is already cancelled, detaching is enough + if locked_grandchild.is_cancelled { + continue; + } + + // For performance reasons, only adopt grandchildren that have children. + // Otherwise, just cancel them right away, no need for another iteration. + if locked_grandchild.children.is_empty() { + // Cancel the grandchild + locked_grandchild.is_cancelled = true; + locked_grandchild.children = Vec::new(); + drop(locked_grandchild); + grandchild.waker.notify_waiters(); + } else { + // Otherwise, adopt grandchild + locked_grandchild.parent = Some(node.clone()); + locked_grandchild.parent_idx = locked_node.children.len(); + drop(locked_grandchild); + locked_node.children.push(grandchild); + } + } + + // Cancel the child + locked_child.is_cancelled = true; + locked_child.children = Vec::new(); + drop(locked_child); + child.waker.notify_waiters(); + + // Now the child is cancelled and detached and all its children are adopted. + // Just continue until all (including adopted) children are cancelled and detached. + } + + // Cancel the node itself. + locked_node.is_cancelled = true; + locked_node.children = Vec::new(); + drop(locked_node); + node.waker.notify_waiters(); +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs new file mode 100644 index 0000000..1ae1b7e --- /dev/null +++ b/src/sync/mod.rs @@ -0,0 +1,15 @@ +//! 
Synchronization primitives + +mod cancellation_token; +pub use cancellation_token::{ + guard::DropGuard, CancellationToken, WaitForCancellationFuture, WaitForCancellationFutureOwned, +}; + +mod mpsc; +pub use mpsc::{PollSendError, PollSender}; + +mod poll_semaphore; +pub use poll_semaphore::PollSemaphore; + +mod reusable_box; +pub use reusable_box::ReusableBoxFuture; diff --git a/src/sync/mpsc.rs b/src/sync/mpsc.rs new file mode 100644 index 0000000..55ed5c4 --- /dev/null +++ b/src/sync/mpsc.rs @@ -0,0 +1,284 @@ +use futures_sink::Sink; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, mem}; +use tokio::sync::mpsc::OwnedPermit; +use tokio::sync::mpsc::Sender; + +use super::ReusableBoxFuture; + +/// Error returned by the `PollSender` when the channel is closed. +#[derive(Debug)] +pub struct PollSendError(Option); + +impl PollSendError { + /// Consumes the stored value, if any. + /// + /// If this error was encountered when calling `start_send`/`send_item`, this will be the item + /// that the caller attempted to send. Otherwise, it will be `None`. + pub fn into_inner(self) -> Option { + self.0 + } +} + +impl fmt::Display for PollSendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +impl std::error::Error for PollSendError {} + +#[derive(Debug)] +enum State { + Idle(Sender), + Acquiring, + ReadyToSend(OwnedPermit), + Closed, +} + +/// A wrapper around [`mpsc::Sender`] that can be polled. +/// +/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender +#[derive(Debug)] +pub struct PollSender { + sender: Option>, + state: State, + acquire: ReusableBoxFuture<'static, Result, PollSendError>>, +} + +// Creates a future for acquiring a permit from the underlying channel. This is used to ensure +// there's capacity for a send to complete. +// +// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to +// ReusableBoxFuture has the same underlying type, and hence the same size and alignment. +async fn make_acquire_future( + data: Option>, +) -> Result, PollSendError> { + match data { + Some(sender) => sender + .reserve_owned() + .await + .map_err(|_| PollSendError(None)), + None => unreachable!("this future should not be pollable in this state"), + } +} + +impl PollSender { + /// Creates a new `PollSender`. + pub fn new(sender: Sender) -> Self { + Self { + sender: Some(sender.clone()), + state: State::Idle(sender), + acquire: ReusableBoxFuture::new(make_acquire_future(None)), + } + } + + fn take_state(&mut self) -> State { + mem::replace(&mut self.state, State::Closed) + } + + /// Attempts to prepare the sender to receive a value. + /// + /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to + /// `send_item`. + /// + /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, + /// by reserving a slot in the channel for the item to be sent. If this method returns + /// `Poll::Pending`, the current task is registered to be notified (via + /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + let (result, next_state) = match self.take_state() { + State::Idle(sender) => { + // Start trying to acquire a permit to reserve a slot for our send, and + // immediately loop back around to poll it the first time. 
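+                    // `make_acquire_future` takes the sender by value; it comes
+                    // back later through the `OwnedPermit` (both `send` and
+                    // `release` hand the `Sender` back).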
+ self.acquire.set(make_acquire_future(Some(sender))); + (None, State::Acquiring) + } + State::Acquiring => match self.acquire.poll(cx) { + // Channel has capacity. + Poll::Ready(Ok(permit)) => { + (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) + } + // Channel is closed. + Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), + // Channel doesn't have capacity yet, so we need to wait. + Poll::Pending => (Some(Poll::Pending), State::Acquiring), + }, + // We're closed, either by choice or because the underlying sender was closed. + s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), + // We're already ready to send an item. + s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), + }; + + self.state = next_state; + if let Some(result) = result { + return result; + } + } + } + + /// Sends an item to the channel. + /// + /// Before calling `send_item`, `poll_reserve` must be called with a successful return + /// value of `Poll::Ready(Ok(()))`. + /// + /// # Errors + /// + /// If the channel is closed, an error will be returned. This is a permanent state. + /// + /// # Panics + /// + /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method + /// will panic. + #[track_caller] + pub fn send_item(&mut self, value: T) -> Result<(), PollSendError> { + let (result, next_state) = match self.take_state() { + State::Idle(_) | State::Acquiring => { + panic!("`send_item` called without first calling `poll_reserve`") + } + // We have a permit to send our item, so go ahead, which gets us our sender back. + State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), + // We're closed, either by choice or because the underlying sender was closed. + State::Closed => (Err(PollSendError(Some(value))), State::Closed), + }; + + // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. + self.state = if self.sender.is_some() { + next_state + } else { + State::Closed + }; + result + } + + /// Checks whether this sender is been closed. + /// + /// The underlying channel that this sender was wrapping may still be open. + pub fn is_closed(&self) -> bool { + matches!(self.state, State::Closed) || self.sender.is_none() + } + + /// Gets a reference to the `Sender` of the underlying channel. + /// + /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender + /// was wrapping may still be open. + pub fn get_ref(&self) -> Option<&Sender> { + self.sender.as_ref() + } + + /// Closes this sender. + /// + /// No more messages will be able to be sent from this sender, but the underlying channel will + /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. + /// + /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made + /// to `send_item` in order to consume the reserved slot. After that, no further sends will be + /// possible. If you do not intend to send another item, you can release the reserved slot back + /// to the underlying sender by calling [`abort_send`]. + /// + /// [`abort_send`]: crate::sync::PollSender::abort_send + /// [`Receiver`]: tokio::sync::mpsc::Receiver + pub fn close(&mut self) { + // Mark ourselves officially closed by dropping our main sender. + self.sender = None; + + // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly + // transition to the closed state. 
Otherwise, leave the existing permit in place for the + // caller if they want to complete the send. + match self.state { + State::Idle(_) => self.state = State::Closed, + State::Acquiring => { + self.acquire.set(make_acquire_future(None)); + self.state = State::Closed; + } + _ => {} + } + } + + /// Aborts the current in-progress send, if any. + /// + /// Returns `true` if a send was aborted. If the sender was closed prior to calling + /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be + /// ready to attempt another send. + pub fn abort_send(&mut self) -> bool { + // We may have been closed in the meantime, after a call to `poll_reserve` already + // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the + // closed state when we actually abort a send, rather than resetting ourselves back to idle. + + let (result, next_state) = match self.take_state() { + // We're currently trying to reserve a slot to send into. + State::Acquiring => { + // Replacing the future drops the in-flight one. + self.acquire.set(make_acquire_future(None)); + + // If we haven't closed yet, we have to clone our stored sender since we have no way + // to get it back from the acquire future we just dropped. + let state = match self.sender.clone() { + Some(sender) => State::Idle(sender), + None => State::Closed, + }; + (true, state) + } + // We got the permit. If we haven't closed yet, get the sender back. + State::ReadyToSend(permit) => { + let state = if self.sender.is_some() { + State::Idle(permit.release()) + } else { + State::Closed + }; + (true, state) + } + s => (false, s), + }; + + self.state = next_state; + result + } +} + +impl Clone for PollSender { + /// Clones this `PollSender`. + /// + /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. + fn clone(&self) -> PollSender { + let (sender, state) = match self.sender.clone() { + Some(sender) => (Some(sender.clone()), State::Idle(sender)), + None => (None, State::Closed), + }; + + Self { + sender, + state, + // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not + // compatible with the transitive bounds required by `Sender`. + acquire: ReusableBoxFuture::new(async { unreachable!() }), + } + } +} + +impl Sink for PollSender { + type Error = PollSendError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::into_inner(self).poll_reserve(cx) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::into_inner(self).send_item(item) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Pin::into_inner(self).close(); + Poll::Ready(Ok(())) + } +} diff --git a/src/sync/poll_semaphore.rs b/src/sync/poll_semaphore.rs new file mode 100644 index 0000000..6b44574 --- /dev/null +++ b/src/sync/poll_semaphore.rs @@ -0,0 +1,171 @@ +use futures_core::{ready, Stream}; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +use super::ReusableBoxFuture; + +/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. +/// +/// [`Semaphore`]: tokio::sync::Semaphore +pub struct PollSemaphore { + semaphore: Arc, + permit_fut: Option<( + u32, // The number of permits requested. 
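+        // The reusable boxed future that waits on `acquire_many_owned`,
+        // kept around so that repeated polls do not reallocate.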
+ ReusableBoxFuture<'static, Result>, + )>, +} + +impl PollSemaphore { + /// Create a new `PollSemaphore`. + pub fn new(semaphore: Arc) -> Self { + Self { + semaphore, + permit_fut: None, + } + } + + /// Closes the semaphore. + pub fn close(&self) { + self.semaphore.close() + } + + /// Obtain a clone of the inner semaphore. + pub fn clone_inner(&self) -> Arc { + self.semaphore.clone() + } + + /// Get back the inner semaphore. + pub fn into_inner(self) -> Arc { + self.semaphore + } + + /// Poll to acquire a permit from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when a permit becomes available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_acquire_many(cx, 1) + } + + /// Poll to acquire many permits from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when the permits become available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire_many( + &mut self, + cx: &mut Context<'_>, + permits: u32, + ) -> Poll> { + let permit_future = match self.permit_fut.as_mut() { + Some((prev_permits, fut)) if *prev_permits == permits => fut, + Some((old_permits, fut_box)) => { + // We're requesting a different number of permits, so replace the future + // and record the new amount. + let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); + fut_box.set(fut); + *old_permits = permits; + fut_box + } + None => { + // avoid allocations completely if we can grab a permit immediately + match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) { + Ok(permit) => return Poll::Ready(Some(permit)), + Err(TryAcquireError::Closed) => return Poll::Ready(None), + Err(TryAcquireError::NoPermits) => {} + } + + let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); + &mut self + .permit_fut + .get_or_insert((permits, ReusableBoxFuture::new(next_fut))) + .1 + } + }; + + let result = ready!(permit_future.poll(cx)); + + // Assume we'll request the same amount of permits in a subsequent call. + let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); + permit_future.set(next_fut); + + match result { + Ok(permit) => Poll::Ready(Some(permit)), + Err(_closed) => { + self.permit_fut = None; + Poll::Ready(None) + } + } + } + + /// Returns the current number of available permits. + /// + /// This is equivalent to the [`Semaphore::available_permits`] method on the + /// `tokio::sync::Semaphore` type. 
+ /// + /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits + pub fn available_permits(&self) -> usize { + self.semaphore.available_permits() + } + + /// Adds `n` new permits to the semaphore. + /// + /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function + /// will panic if the limit is exceeded. + /// + /// This is equivalent to the [`Semaphore::add_permits`] method on the + /// `tokio::sync::Semaphore` type. + /// + /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits + pub fn add_permits(&self, n: usize) { + self.semaphore.add_permits(n); + } +} + +impl Stream for PollSemaphore { + type Item = OwnedSemaphorePermit; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::into_inner(self).poll_acquire(cx) + } +} + +impl Clone for PollSemaphore { + fn clone(&self) -> PollSemaphore { + PollSemaphore::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollSemaphore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollSemaphore") + .field("semaphore", &self.semaphore) + .finish() + } +} + +impl AsRef for PollSemaphore { + fn as_ref(&self) -> &Semaphore { + &self.semaphore + } +} diff --git a/src/sync/reusable_box.rs b/src/sync/reusable_box.rs new file mode 100644 index 0000000..1b8ef60 --- /dev/null +++ b/src/sync/reusable_box.rs @@ -0,0 +1,171 @@ +use std::alloc::Layout; +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop}; +use std::pin::Pin; +use std::ptr; +use std::task::{Context, Poll}; + +/// A reusable `Pin + Send + 'a>>`. +/// +/// This type lets you replace the future stored in the box without +/// reallocating when the size and alignment permits this. +pub struct ReusableBoxFuture<'a, T> { + boxed: Pin + Send + 'a>>, +} + +impl<'a, T> ReusableBoxFuture<'a, T> { + /// Create a new `ReusableBoxFuture` containing the provided future. + pub fn new(future: F) -> Self + where + F: Future + Send + 'a, + { + Self { + boxed: Box::pin(future), + } + } + + /// Replace the future currently stored in this box. + /// + /// This reallocates if and only if the layout of the provided future is + /// different from the layout of the currently stored future. + pub fn set(&mut self, future: F) + where + F: Future + Send + 'a, + { + if let Err(future) = self.try_set(future) { + *self = Self::new(future); + } + } + + /// Replace the future currently stored in this box. + /// + /// This function never reallocates, but returns an error if the provided + /// future has a different size or alignment from the currently stored + /// future. + pub fn try_set(&mut self, future: F) -> Result<(), F> + where + F: Future + Send + 'a, + { + // If we try to inline the contents of this function, the type checker complains because + // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an + // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound + // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`. + #[inline(always)] + fn real_try_set<'a, F>( + this: &mut ReusableBoxFuture<'a, F::Output>, + future: F, + ) -> Result<(), F> + where + F: Future + Send + 'a, + { + // future::Pending is a ZST so this never allocates. 
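+            // `reuse_pin_box` writes `future` into the old allocation when the
+            // layouts match, and hands it back via `Err` when they don't.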
+ let boxed = mem::replace(&mut this.boxed, Box::pin(Pending(PhantomData))); + reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed)) + } + + real_try_set(self, future) + } + + /// Get a pinned reference to the underlying future. + pub fn get_pin(&mut self) -> Pin<&mut (dyn Future + Send)> { + self.boxed.as_mut() + } + + /// Poll the future stored inside this box. + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll { + self.get_pin().poll(cx) + } +} + +impl Future for ReusableBoxFuture<'_, T> { + type Output = T; + + /// Poll the future stored inside this box. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::into_inner(self).get_pin().poll(cx) + } +} + +// The only method called on self.boxed is poll, which takes &mut self, so this +// struct being Sync does not permit any invalid access to the Future, even if +// the future is not Sync. +unsafe impl Sync for ReusableBoxFuture<'_, T> {} + +impl fmt::Debug for ReusableBoxFuture<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReusableBoxFuture").finish() + } +} + +fn reuse_pin_box(boxed: Pin>, new_value: U, callback: F) -> Result +where + F: FnOnce(Box) -> O, +{ + let layout = Layout::for_value::(&*boxed); + if layout != Layout::new::() { + return Err(new_value); + } + + // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we + // always drop the `T`. + let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) }); + + // When dropping the old value panics, we still want to call `callback` — so move the rest of + // the code into a guard type. + let guard = CallOnDrop::new(|| { + let raw: *mut U = raw.cast::(); + unsafe { raw.write(new_value) }; + + // SAFETY: + // - `T` and `U` have the same layout. + // - `raw` comes from a `Box` that uses the same allocator as this one. + // - `raw` points to a valid instance of `U` (we just wrote it in). + let boxed = unsafe { Box::from_raw(raw) }; + + callback(boxed) + }); + + // Drop the old value. + unsafe { ptr::drop_in_place(raw) }; + + // Run the rest of the code. + Ok(guard.call()) +} + +struct CallOnDrop O> { + f: ManuallyDrop, +} + +impl O> CallOnDrop { + fn new(f: F) -> Self { + let f = ManuallyDrop::new(f); + Self { f } + } + fn call(self) -> O { + let mut this = ManuallyDrop::new(self); + let f = unsafe { ManuallyDrop::take(&mut this.f) }; + f() + } +} + +impl O> Drop for CallOnDrop { + fn drop(&mut self) { + let f = unsafe { ManuallyDrop::take(&mut self.f) }; + f(); + } +} + +/// The same as `std::future::Pending`; we can't use that type directly because on rustc +/// versions <1.60 it didn't unconditionally implement `Send`. 
+// FIXME: use `std::future::Pending` once the MSRV is >=1.60 +struct Pending(PhantomData T>); + +impl Future for Pending { + type Output = T; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + Poll::Pending + } +} diff --git a/src/sync/tests/loom_cancellation_token.rs b/src/sync/tests/loom_cancellation_token.rs new file mode 100644 index 0000000..b57503d --- /dev/null +++ b/src/sync/tests/loom_cancellation_token.rs @@ -0,0 +1,176 @@ +use crate::sync::CancellationToken; + +use loom::{future::block_on, thread}; +use tokio_test::assert_ok; + +#[test] +fn cancel_token() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_token_owned() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled_owned().await; + }); + }); + + let th2 = thread::spawn(move || { + token.cancel(); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + +#[test] +fn cancel_with_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + let child_token = token.child_token(); + + let th1 = thread::spawn(move || { + block_on(async { + token1.cancelled().await; + }); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + block_on(async { + child_token.cancelled().await; + }); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_no_child() { + loom::model(|| { + let token = CancellationToken::new(); + let token1 = token.clone(); + let token2 = token.clone(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(token2); + }); + + let th3 = thread::spawn(move || { + drop(token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_token_with_children() { + loom::model(|| { + let token1 = CancellationToken::new(); + let child_token1 = token1.child_token(); + let child_token2 = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + drop(child_token1); + }); + + let th3 = thread::spawn(move || { + drop(child_token2); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn drop_and_cancel_token() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + drop(child_token); + }); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} + +#[test] +fn cancel_parent_and_child() { + loom::model(|| { + let token1 = CancellationToken::new(); + let token2 = token1.clone(); + let child_token = token1.child_token(); + + let th1 = thread::spawn(move || { + drop(token1); + }); + + let th2 = thread::spawn(move || { + token2.cancel(); + }); + + let th3 = thread::spawn(move || { + child_token.cancel(); + }); + + assert_ok!(th1.join()); + 
assert_ok!(th2.join()); + assert_ok!(th3.join()); + }); +} diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/sync/tests/mod.rs @@ -0,0 +1 @@ + diff --git a/src/task/join_map.rs b/src/task/join_map.rs new file mode 100644 index 0000000..c6bf5bc --- /dev/null +++ b/src/task/join_map.rs @@ -0,0 +1,807 @@ +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use std::borrow::Borrow; +use std::collections::hash_map::RandomState; +use std::fmt; +use std::future::Future; +use std::hash::{BuildHasher, Hash, Hasher}; +use tokio::runtime::Handle; +use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; + +/// A collection of tasks spawned on a Tokio runtime, associated with hash map +/// keys. +/// +/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the +/// addition of a set of keys associated with each task. These keys allow +/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the +/// `JoinMap` based on their keys, or [test whether a task corresponding to a +/// given key exists][contains] in the `JoinMap`. +/// +/// In addition, when tasks in the `JoinMap` complete, they will return the +/// associated key along with the value returned by the task, if any. +/// +/// A `JoinMap` can be used to await the completion of some or all of the tasks +/// in the map. The map is not ordered, and the tasks will be returned in the +/// order they complete. +/// +/// All of the tasks must have the same return type `V`. +/// +/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted. +/// +/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the +/// documentation on unstable features][unstable] for details on how to enable +/// Tokio's unstable features. +/// +/// # Examples +/// +/// Spawn multiple tasks and wait for them: +/// +/// ``` +/// use tokio_util::task::JoinMap; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = JoinMap::new(); +/// +/// for i in 0..10 { +/// // Spawn a task on the `JoinMap` with `i` as its key. +/// map.spawn(i, async move { /* ... */ }); +/// } +/// +/// let mut seen = [false; 10]; +/// +/// // When a task completes, `join_next` returns the task's key along +/// // with its output. +/// while let Some((key, res)) = map.join_next().await { +/// seen[key] = true; +/// assert!(res.is_ok(), "task {} completed successfully!", key); +/// } +/// +/// for i in 0..10 { +/// assert!(seen[i]); +/// } +/// } +/// ``` +/// +/// Cancel tasks based on their keys: +/// +/// ``` +/// use tokio_util::task::JoinMap; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = JoinMap::new(); +/// +/// map.spawn("hello world", async move { /* ... */ }); +/// map.spawn("goodbye world", async move { /* ... */}); +/// +/// // Look up the "goodbye world" task in the map and abort it. +/// let aborted = map.abort("goodbye world"); +/// +/// // `JoinMap::abort` returns `true` if a task existed for the +/// // provided key. +/// assert!(aborted); +/// +/// while let Some((key, res)) = map.join_next().await { +/// if key == "goodbye world" { +/// // The aborted task should complete with a cancelled `JoinError`. +/// assert!(res.unwrap_err().is_cancelled()); +/// } else { +/// // Other tasks should complete normally. 
+/// assert!(res.is_ok()); +/// } +/// } +/// } +/// ``` +/// +/// [`JoinSet`]: tokio::task::JoinSet +/// [unstable]: tokio#unstable-features +/// [abort]: fn@Self::abort +/// [abort_matching]: fn@Self::abort_matching +/// [contains]: fn@Self::contains_key +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +pub struct JoinMap { + /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`, + /// indexed by their keys and task IDs. + /// + /// The [`Key`] type contains both the task's `K`-typed key provided when + /// spawning tasks, and the task's IDs. The IDs are stored here to resolve + /// hash collisions when looking up tasks based on their pre-computed hash + /// (as stored in the `hashes_by_task` map). + tasks_by_key: HashMap, AbortHandle, S>, + + /// A map from task IDs to the hash of the key associated with that task. + /// + /// This map is used to perform reverse lookups of tasks in the + /// `tasks_by_key` map based on their task IDs. When a task terminates, the + /// ID is provided to us by the `JoinSet`, so we can look up the hash value + /// of that task's key, and then remove it from the `tasks_by_key` map using + /// the raw hash code, resolving collisions by comparing task IDs. + hashes_by_task: HashMap, + + /// The [`JoinSet`] that awaits the completion of tasks spawned on this + /// `JoinMap`. + tasks: JoinSet, +} + +/// A [`JoinMap`] key. +/// +/// This holds both a `K`-typed key (the actual key as seen by the user), _and_ +/// a task ID, so that hash collisions between `K`-typed keys can be resolved +/// using either `K`'s `Eq` impl *or* by checking the task IDs. +/// +/// This allows looking up a task using either an actual key (such as when the +/// user queries the map with a key), *or* using a task ID and a hash (such as +/// when removing completed tasks from the map). +#[derive(Debug)] +struct Key { + key: K, + id: Id, +} + +impl JoinMap { + /// Creates a new empty `JoinMap`. + /// + /// The `JoinMap` is initially created with a capacity of 0, so it will not + /// allocate until a task is first spawned on it. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// let map: JoinMap<&str, i32> = JoinMap::new(); + /// ``` + #[inline] + #[must_use] + pub fn new() -> Self { + Self::with_hasher(RandomState::new()) + } + + /// Creates an empty `JoinMap` with the specified capacity. + /// + /// The `JoinMap` will be able to hold at least `capacity` tasks without + /// reallocating. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10); + /// ``` + #[inline] + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + JoinMap::with_capacity_and_hasher(capacity, Default::default()) + } +} + +impl JoinMap { + /// Creates an empty `JoinMap` which will use the given hash builder to hash + /// keys. + /// + /// The created map has the default initial capacity. + /// + /// Warning: `hash_builder` is normally randomly generated, and + /// is designed to allow `JoinMap` to be resistant to attacks that + /// cause many collisions and very poor performance. Setting it + /// manually using this function can expose a DoS attack vector. + /// + /// The `hash_builder` passed should implement the [`BuildHasher`] trait for + /// the `JoinMap` to be useful, see its documentation for details. 
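+    ///
+    /// # Examples
+    ///
+    /// A small sketch, mirroring the `with_capacity_and_hasher` example below:
+    ///
+    /// ```
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// use tokio_util::task::JoinMap;
+    /// use std::collections::hash_map::RandomState;
+    ///
+    /// let s = RandomState::new();
+    /// let mut map = JoinMap::with_hasher(s);
+    /// map.spawn(1, async move { "hello world!" });
+    /// # }
+    /// ```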
+ #[inline] + #[must_use] + pub fn with_hasher(hash_builder: S) -> Self { + Self::with_capacity_and_hasher(0, hash_builder) + } + + /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder` + /// to hash the keys. + /// + /// The `JoinMap` will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate. + /// + /// Warning: `hash_builder` is normally randomly generated, and + /// is designed to allow HashMaps to be resistant to attacks that + /// cause many collisions and very poor performance. Setting it + /// manually using this function can expose a DoS attack vector. + /// + /// The `hash_builder` passed should implement the [`BuildHasher`] trait for + /// the `JoinMap`to be useful, see its documentation for details. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// use std::collections::hash_map::RandomState; + /// + /// let s = RandomState::new(); + /// let mut map = JoinMap::with_capacity_and_hasher(10, s); + /// map.spawn(1, async move { "hello world!" }); + /// # } + /// ``` + #[inline] + #[must_use] + pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self { + Self { + tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()), + hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder), + tasks: JoinSet::new(), + } + } + + /// Returns the number of tasks currently in the `JoinMap`. + pub fn len(&self) -> usize { + let len = self.tasks_by_key.len(); + debug_assert_eq!(len, self.hashes_by_task.len()); + len + } + + /// Returns whether the `JoinMap` is empty. + pub fn is_empty(&self) -> bool { + let empty = self.tasks_by_key.is_empty(); + debug_assert_eq!(empty, self.hashes_by_task.is_empty()); + empty + } + + /// Returns the number of tasks the map can hold without reallocating. + /// + /// This number is a lower bound; the `JoinMap` might be able to hold + /// more, but is guaranteed to be able to hold at least this many. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// let map: JoinMap = JoinMap::with_capacity(100); + /// assert!(map.capacity() >= 100); + /// ``` + #[inline] + pub fn capacity(&self) -> usize { + let capacity = self.tasks_by_key.capacity(); + debug_assert_eq!(capacity, self.hashes_by_task.capacity()); + capacity + } +} + +impl JoinMap +where + K: Hash + Eq, + V: 'static, + S: BuildHasher, +{ + /// Spawn the provided task and store it in this `JoinMap` with the provided + /// key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn(&mut self, key: K, task: F) + where + F: Future, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn(task); + self.insert(key, task) + } + + /// Spawn the provided task on the provided runtime and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. 
The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_on(&mut self, key: K, task: F, handle: &Handle) + where + F: Future, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_on(task, handle); + self.insert(key, task); + } + + /// Spawn the provided task on the current [`LocalSet`] and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// # Panics + /// + /// This method panics if it is called outside of a `LocalSet`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_local(&mut self, key: K, task: F) + where + F: Future, + F: 'static, + { + let task = self.tasks.spawn_local(task); + self.insert(key, task); + } + + /// Spawn the provided task on the provided [`LocalSet`] and store it in + /// this `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`join_next`]: Self::join_next + #[track_caller] + pub fn spawn_local_on(&mut self, key: K, task: F, local_set: &LocalSet) + where + F: Future, + F: 'static, + { + let task = self.tasks.spawn_local_on(task, local_set); + self.insert(key, task) + } + + fn insert(&mut self, key: K, abort: AbortHandle) { + let hash = self.hash(&key); + let id = abort.id(); + let map_key = Key { id, key }; + + // Insert the new key into the map of tasks by keys. + let entry = self + .tasks_by_key + .raw_entry_mut() + .from_hash(hash, |k| k.key == map_key.key); + match entry { + RawEntryMut::Occupied(mut occ) => { + // There was a previous task spawned with the same key! Cancel + // that task, and remove its ID from the map of hashes by task IDs. + let Key { id: prev_id, .. } = occ.insert_key(map_key); + occ.insert(abort).abort(); + let _prev_hash = self.hashes_by_task.remove(&prev_id); + debug_assert_eq!(Some(hash), _prev_hash); + } + RawEntryMut::Vacant(vac) => { + vac.insert(map_key, abort); + } + }; + + // Associate the key's hash with this task's ID, for looking up tasks by ID. + let _prev = self.hashes_by_task.insert(id, hash); + debug_assert!(_prev.is_none(), "no prior task should have had the same ID"); + } + + /// Waits until one of the tasks in the map completes and returns its + /// output, along with the key corresponding to that task. + /// + /// Returns `None` if the map is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`] + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinMap`. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has + /// completed. 
The `value` is the return value of that task, and `key` is
+    ///   the key associated with the task.
+    /// * `Some((key, Err(err)))` if one of the tasks in this `JoinMap` has
+    ///   panicked or been aborted. `key` is the key associated with the task
+    ///   that panicked or was aborted.
+    /// * `None` if the `JoinMap` is empty.
+    ///
+    /// [`tokio::select!`]: tokio::select
+    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
+        let (res, id) = match self.tasks.join_next_with_id().await {
+            Some(Ok((id, output))) => (Ok(output), id),
+            Some(Err(e)) => {
+                let id = e.id();
+                (Err(e), id)
+            }
+            None => return None,
+        };
+        let key = self.remove_by_id(id)?;
+        Some((key, res))
+    }
+
+    /// Aborts all tasks and waits for them to finish shutting down.
+    ///
+    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
+    /// a loop until it returns `None`.
+    ///
+    /// This method ignores any panics in the tasks shutting down. When this call returns, the
+    /// `JoinMap` will be empty.
+    ///
+    /// [`abort_all`]: fn@Self::abort_all
+    /// [`join_next`]: fn@Self::join_next
+    pub async fn shutdown(&mut self) {
+        self.abort_all();
+        while self.join_next().await.is_some() {}
+    }
+
+    /// Abort the task corresponding to the provided `key`.
+    ///
+    /// If this `JoinMap` contains a task corresponding to `key`, this method
+    /// will abort that task and return `true`. Otherwise, if no task exists for
+    /// `key`, this method returns `false`.
+    ///
+    /// # Examples
+    ///
+    /// Aborting a task by key:
+    ///
+    /// ```
+    /// use tokio_util::task::JoinMap;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let mut map = JoinMap::new();
+    ///
+    /// map.spawn("hello world", async move { /* ... */ });
+    /// map.spawn("goodbye world", async move { /* ... */});
+    ///
+    /// // Look up the "goodbye world" task in the map and abort it.
+    /// map.abort("goodbye world");
+    ///
+    /// while let Some((key, res)) = map.join_next().await {
+    ///     if key == "goodbye world" {
+    ///         // The aborted task should complete with a cancelled `JoinError`.
+    ///         assert!(res.unwrap_err().is_cancelled());
+    ///     } else {
+    ///         // Other tasks should complete normally.
+    ///         assert!(res.is_ok());
+    ///     }
+    /// }
+    /// # }
+    /// ```
+    ///
+    /// `abort` returns `true` if a task was aborted:
+    /// ```
+    /// use tokio_util::task::JoinMap;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let mut map = JoinMap::new();
+    ///
+    /// map.spawn("hello world", async move { /* ... */ });
+    /// map.spawn("goodbye world", async move { /* ... */});
+    ///
+    /// // A task for the key "goodbye world" should exist in the map:
+    /// assert!(map.abort("goodbye world"));
+    ///
+    /// // Aborting a key that does not exist will return `false`:
+    /// assert!(!map.abort("goodbye universe"));
+    /// # }
+    /// ```
+    pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
+    where
+        Q: Hash + Eq,
+        K: Borrow<Q>,
+    {
+        match self.get_by_key(key) {
+            Some((_, handle)) => {
+                handle.abort();
+                true
+            }
+            None => false,
+        }
+    }
+
+    /// Aborts all tasks with keys matching `predicate`.
+    ///
+    /// `predicate` is a function called with a reference to each key in the
+    /// map. If it returns `true` for a given key, the corresponding task will
+    /// be cancelled.
+    ///
+    /// # Examples
+    /// ```
+    /// use tokio_util::task::JoinMap;
+    ///
+    /// # // use the current thread rt so that spawned tasks don't
+    /// # // complete in the background before they can be aborted.
+ /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("goodbye world", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("hello san francisco", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("goodbye universe", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// + /// // Abort all tasks whose keys begin with "goodbye" + /// map.abort_matching(|key| key.starts_with("goodbye")); + /// + /// let mut seen = 0; + /// while let Some((key, res)) = map.join_next().await { + /// seen += 1; + /// if key.starts_with("goodbye") { + /// // The aborted task should complete with a cancelled `JoinError`. + /// assert!(res.unwrap_err().is_cancelled()); + /// } else { + /// // Other tasks should complete normally. + /// assert!(key.starts_with("hello")); + /// assert!(res.is_ok()); + /// } + /// } + /// + /// // All spawned tasks should have completed. + /// assert_eq!(seen, 4); + /// # } + /// ``` + pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) { + // Note: this method iterates over the tasks and keys *without* removing + // any entries, so that the keys from aborted tasks can still be + // returned when calling `join_next` in the future. + for (Key { ref key, .. }, task) in &self.tasks_by_key { + if predicate(key) { + task.abort(); + } + } + } + + /// Returns `true` if this `JoinMap` contains a task for the provided key. + /// + /// If the task has completed, but its output hasn't yet been consumed by a + /// call to [`join_next`], this method will still return `true`. + /// + /// [`join_next`]: fn@Self::join_next + pub fn contains_key(&self, key: &Q) -> bool + where + Q: Hash + Eq, + K: Borrow, + { + self.get_by_key(key).is_some() + } + + /// Returns `true` if this `JoinMap` contains a task with the provided + /// [task ID]. + /// + /// If the task has completed, but its output hasn't yet been consumed by a + /// call to [`join_next`], this method will still return `true`. + /// + /// [`join_next`]: fn@Self::join_next + /// [task ID]: tokio::task::Id + pub fn contains_task(&self, task: &Id) -> bool { + self.get_by_id(task).is_some() + } + + /// Reserves capacity for at least `additional` more tasks to be spawned + /// on this `JoinMap` without reallocating for the map of task keys. The + /// collection may reserve more space to avoid frequent reallocations. + /// + /// Note that spawning a task will still cause an allocation for the task + /// itself. + /// + /// # Panics + /// + /// Panics if the new allocation size overflows [`usize`]. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap<&str, i32> = JoinMap::new(); + /// map.reserve(10); + /// ``` + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.tasks_by_key.reserve(additional); + self.hashes_by_task.reserve(additional); + } + + /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop + /// down as much as possible while maintaining the internal rules + /// and possibly leaving some space in accordance with the resize policy. 
+ /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap = JoinMap::with_capacity(100); + /// map.spawn(1, async move { 2 }); + /// map.spawn(3, async move { 4 }); + /// assert!(map.capacity() >= 100); + /// map.shrink_to_fit(); + /// assert!(map.capacity() >= 2); + /// # } + /// ``` + #[inline] + pub fn shrink_to_fit(&mut self) { + self.hashes_by_task.shrink_to_fit(); + self.tasks_by_key.shrink_to_fit(); + } + + /// Shrinks the capacity of the map with a lower limit. It will drop + /// down no lower than the supplied limit while maintaining the internal rules + /// and possibly leaving some space in accordance with the resize policy. + /// + /// If the current capacity is less than the lower limit, this is a no-op. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap = JoinMap::with_capacity(100); + /// map.spawn(1, async move { 2 }); + /// map.spawn(3, async move { 4 }); + /// assert!(map.capacity() >= 100); + /// map.shrink_to(10); + /// assert!(map.capacity() >= 10); + /// map.shrink_to(0); + /// assert!(map.capacity() >= 2); + /// # } + /// ``` + #[inline] + pub fn shrink_to(&mut self, min_capacity: usize) { + self.hashes_by_task.shrink_to(min_capacity); + self.tasks_by_key.shrink_to(min_capacity) + } + + /// Look up a task in the map by its key, returning the key and abort handle. + fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key, &'map AbortHandle)> + where + Q: Hash + Eq, + K: Borrow, + { + let hash = self.hash(key); + self.tasks_by_key + .raw_entry() + .from_hash(hash, |k| k.key.borrow() == key) + } + + /// Look up a task in the map by its task ID, returning the key and abort handle. + fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key, &'map AbortHandle)> { + let hash = self.hashes_by_task.get(id)?; + self.tasks_by_key + .raw_entry() + .from_hash(*hash, |k| &k.id == id) + } + + /// Remove a task from the map by ID, returning the key for that task. + fn remove_by_id(&mut self, id: Id) -> Option { + // Get the hash for the given ID. + let hash = self.hashes_by_task.remove(&id)?; + + // Remove the entry for that hash. + let entry = self + .tasks_by_key + .raw_entry_mut() + .from_hash(hash, |k| k.id == id); + let (Key { id: _key_id, key }, handle) = match entry { + RawEntryMut::Occupied(entry) => entry.remove_entry(), + _ => return None, + }; + debug_assert_eq!(_key_id, id); + debug_assert_eq!(id, handle.id()); + self.hashes_by_task.remove(&id); + Some(key) + } + + /// Returns the hash for a given key. + #[inline] + fn hash(&self, key: &Q) -> u64 + where + Q: Hash, + { + let mut hasher = self.tasks_by_key.hasher().build_hasher(); + key.hash(&mut hasher); + hasher.finish() + } +} + +impl JoinMap +where + V: 'static, +{ + /// Aborts all tasks on this `JoinMap`. + /// + /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete + /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty. + pub fn abort_all(&mut self) { + self.tasks.abort_all() + } + + /// Removes all tasks from this `JoinMap` without aborting them. + /// + /// The tasks removed by this call will continue to run in the background even if the `JoinMap` + /// is dropped. They may still be aborted by key. 
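+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch (illustrative only): after `detach_all` the map no
+    /// longer tracks its tasks, even though they may keep running.
+    ///
+    /// ```
+    /// use tokio_util::task::JoinMap;
+    ///
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// # async fn main() {
+    /// let mut map = JoinMap::new();
+    ///
+    /// map.spawn("hello world", async move { /* ... */ });
+    /// map.spawn("goodbye world", async move { /* ... */ });
+    ///
+    /// map.detach_all();
+    /// assert!(map.is_empty());
+    /// # }
+    /// ```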
+ pub fn detach_all(&mut self) { + self.tasks.detach_all(); + self.tasks_by_key.clear(); + self.hashes_by_task.clear(); + } +} + +// Hand-written `fmt::Debug` implementation in order to avoid requiring `V: +// Debug`, since no value is ever actually stored in the map. +impl fmt::Debug for JoinMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // format the task keys and abort handles a little nicer by just + // printing the key and task ID pairs, without format the `Key` struct + // itself or the `AbortHandle`, which would just format the task's ID + // again. + struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap, AbortHandle, S>); + impl fmt::Debug for KeySet<'_, K, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.0.keys().map(|Key { key, id }| (key, id))) + .finish() + } + } + + f.debug_struct("JoinMap") + // The `tasks_by_key` map is the only one that contains information + // that's really worth formatting for the user, since it contains + // the tasks' keys and IDs. The other fields are basically + // implementation details. + .field("tasks", &KeySet(&self.tasks_by_key)) + .finish() + } +} + +impl Default for JoinMap { + fn default() -> Self { + Self::new() + } +} + +// === impl Key === + +impl Hash for Key { + // Don't include the task ID in the hash. + #[inline] + fn hash(&self, hasher: &mut H) { + self.key.hash(hasher); + } +} + +// Because we override `Hash` for this type, we must also override the +// `PartialEq` impl, so that all instances with the same hash are equal. +impl PartialEq for Key { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } +} + +impl Eq for Key {} diff --git a/src/task/mod.rs b/src/task/mod.rs new file mode 100644 index 0000000..de41dd5 --- /dev/null +++ b/src/task/mod.rs @@ -0,0 +1,12 @@ +//! Extra utilities for spawning tasks + +#[cfg(tokio_unstable)] +mod join_map; +#[cfg(not(target_os = "wasi"))] +mod spawn_pinned; +#[cfg(not(target_os = "wasi"))] +pub use spawn_pinned::LocalPoolHandle; + +#[cfg(tokio_unstable)] +#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] +pub use join_map::JoinMap; diff --git a/src/task/spawn_pinned.rs b/src/task/spawn_pinned.rs new file mode 100644 index 0000000..b4102ec --- /dev/null +++ b/src/task/spawn_pinned.rs @@ -0,0 +1,436 @@ +use futures_util::future::{AbortHandle, Abortable}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::future::Future; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::runtime::Builder; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio::sync::oneshot; +use tokio::task::{spawn_local, JoinHandle, LocalSet}; + +/// A cloneable handle to a local pool, used for spawning `!Send` tasks. +/// +/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread +/// in the pool. 
Consequently you can also use [`tokio::task::spawn_local`] (which will
+/// execute on the same thread) inside the Future you supply to the various spawn methods
+/// of `LocalPoolHandle`.
+///
+/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
+/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
+///
+/// # Examples
+///
+/// ```
+/// use std::rc::Rc;
+/// use tokio::{self, task };
+/// use tokio_util::task::LocalPoolHandle;
+///
+/// #[tokio::main(flavor = "current_thread")]
+/// async fn main() {
+///     let pool = LocalPoolHandle::new(5);
+///
+///     let output = pool.spawn_pinned(|| {
+///         // `data` is !Send + !Sync
+///         let data = Rc::new("local data");
+///         let data_clone = data.clone();
+///
+///         async move {
+///             task::spawn_local(async move {
+///                 println!("{}", data_clone);
+///             });
+///
+///             data.to_string()
+///         }
+///     }).await.unwrap();
+///     println!("output: {}", output);
+/// }
+/// ```
+///
+#[derive(Clone)]
+pub struct LocalPoolHandle {
+    pool: Arc<LocalPool>,
+}
+
+impl LocalPoolHandle {
+    /// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
+    /// pool via [`LocalPoolHandle::spawn_pinned`].
+    ///
+    /// # Panics
+    ///
+    /// Panics if the pool size is less than one.
+    #[track_caller]
+    pub fn new(pool_size: usize) -> LocalPoolHandle {
+        assert!(pool_size > 0);
+
+        let workers = (0..pool_size)
+            .map(|_| LocalWorkerHandle::new_worker())
+            .collect();
+
+        let pool = Arc::new(LocalPool { workers });
+
+        LocalPoolHandle { pool }
+    }
+
+    /// Returns the number of threads in the pool.
+    #[inline]
+    pub fn num_threads(&self) -> usize {
+        self.pool.workers.len()
+    }
+
+    /// Returns the number of tasks scheduled on each worker. The indices of the
+    /// worker threads correspond to the indices of the returned `Vec`.
+    pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
+        self.pool
+            .workers
+            .iter()
+            .map(|worker| worker.task_count.load(Ordering::SeqCst))
+            .collect::<Vec<usize>>()
+    }
+
+    /// Spawn a task onto a worker thread and pin it there so it can't be moved
+    /// off of the thread. Note that the future is not [`Send`], but the
+    /// [`FnOnce`] which creates it is.
+    ///
+    /// # Examples
+    /// ```
+    /// use std::rc::Rc;
+    /// use tokio_util::task::LocalPoolHandle;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     // Create the local pool
+    ///     let pool = LocalPoolHandle::new(1);
+    ///
+    ///     // Spawn a !Send future onto the pool and await it
+    ///     let output = pool
+    ///         .spawn_pinned(|| {
+    ///             // Rc is !Send + !Sync
+    ///             let local_data = Rc::new("test");
+    ///
+    ///             // This future holds an Rc, so it is !Send
+    ///             async move { local_data.to_string() }
+    ///         })
+    ///         .await
+    ///         .unwrap();
+    ///
+    ///     assert_eq!(output, "test");
+    /// }
+    /// ```
+    pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
+    where
+        F: FnOnce() -> Fut,
+        F: Send + 'static,
+        Fut: Future + 'static,
+        Fut::Output: Send + 'static,
+    {
+        self.pool
+            .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
+    }
+
+    /// Differs from `spawn_pinned` only in that you can choose a specific worker thread
+    /// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
+    /// number of tasks scheduled.
+    ///
+    /// A worker thread is chosen by index. Indices are 0 based and the largest index
+    /// is given by `num_threads() - 1`.
+    ///
+    /// # Panics
+    ///
+    /// This method panics if the index is out of bounds.
+ /// + /// # Examples + /// + /// This method can be used to spawn a task on all worker threads of the pool: + /// + /// ``` + /// use tokio_util::task::LocalPoolHandle; + /// + /// #[tokio::main] + /// async fn main() { + /// const NUM_WORKERS: usize = 3; + /// let pool = LocalPoolHandle::new(NUM_WORKERS); + /// let handles = (0..pool.num_threads()) + /// .map(|worker_idx| { + /// pool.spawn_pinned_by_idx( + /// || { + /// async { + /// "test" + /// } + /// }, + /// worker_idx, + /// ) + /// }) + /// .collect::>(); + /// + /// for handle in handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// + #[track_caller] + pub fn spawn_pinned_by_idx(&self, create_task: F, idx: usize) -> JoinHandle + where + F: FnOnce() -> Fut, + F: Send + 'static, + Fut: Future + 'static, + Fut::Output: Send + 'static, + { + self.pool + .spawn_pinned(create_task, WorkerChoice::ByIdx(idx)) + } +} + +impl Debug for LocalPoolHandle { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str("LocalPoolHandle") + } +} + +enum WorkerChoice { + LeastBurdened, + ByIdx(usize), +} + +struct LocalPool { + workers: Vec, +} + +impl LocalPool { + /// Spawn a `?Send` future onto a worker + #[track_caller] + fn spawn_pinned( + &self, + create_task: F, + worker_choice: WorkerChoice, + ) -> JoinHandle + where + F: FnOnce() -> Fut, + F: Send + 'static, + Fut: Future + 'static, + Fut::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + let (worker, job_guard) = match worker_choice { + WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(), + WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx), + }; + let worker_spawner = worker.spawner.clone(); + + // Spawn a future onto the worker's runtime so we can immediately return + // a join handle. + worker.runtime_handle.spawn(async move { + // Move the job guard into the task + let _job_guard = job_guard; + + // Propagate aborts via Abortable/AbortHandle + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let _abort_guard = AbortGuard(abort_handle); + + // Inside the future we can't run spawn_local yet because we're not + // in the context of a LocalSet. We need to send create_task to the + // LocalSet task for spawning. + let spawn_task = Box::new(move || { + // Once we're in the LocalSet context we can call spawn_local + let join_handle = + spawn_local( + async move { Abortable::new(create_task(), abort_registration).await }, + ); + + // Send the join handle back to the spawner. If sending fails, + // we assume the parent task was canceled, so cancel this task + // as well. + if let Err(join_handle) = sender.send(join_handle) { + join_handle.abort() + } + }); + + // Send the callback to the LocalSet task + if let Err(e) = worker_spawner.send(spawn_task) { + // Propagate the error as a panic in the join handle. + panic!("Failed to send job to worker: {}", e); + } + + // Wait for the task's join handle + let join_handle = match receiver.await { + Ok(handle) => handle, + Err(e) => { + // We sent the task successfully, but failed to get its + // join handle... We assume something happened to the worker + // and the task was not spawned. Propagate the error as a + // panic in the join handle. + panic!("Worker failed to send join handle: {}", e); + } + }; + + // Wait for the task to complete + let join_result = join_handle.await; + + match join_result { + Ok(Ok(output)) => output, + Ok(Err(_)) => { + // Pinned task was aborted. But that only happens if this + // task is aborted. 
So this is an impossible branch. + unreachable!( + "Reaching this branch means this task was previously \ + aborted but it continued running anyways" + ) + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else if e.is_cancelled() { + // No one else should have the join handle, so this is + // unexpected. Forward this error as a panic in the join + // handle. + panic!("spawn_pinned task was canceled: {}", e); + } else { + // Something unknown happened (not a panic or + // cancellation). Forward this error as a panic in the + // join handle. + panic!("spawn_pinned task failed: {}", e); + } + } + } + }) + } + + /// Find the worker with the least number of tasks, increment its task + /// count, and return its handle. Make sure to actually spawn a task on + /// the worker so the task count is kept consistent with load. + /// + /// A job count guard is also returned to ensure the task count gets + /// decremented when the job is done. + fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) { + loop { + let (worker, task_count) = self + .workers + .iter() + .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst))) + .min_by_key(|&(_, count)| count) + .expect("There must be more than one worker"); + + // Make sure the task count hasn't changed since when we choose this + // worker. Otherwise, restart the search. + if worker + .task_count + .compare_exchange( + task_count, + task_count + 1, + Ordering::SeqCst, + Ordering::Relaxed, + ) + .is_ok() + { + return (worker, JobCountGuard(Arc::clone(&worker.task_count))); + } + } + } + + #[track_caller] + fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) { + let worker = &self.workers[idx]; + worker.task_count.fetch_add(1, Ordering::SeqCst); + + (worker, JobCountGuard(Arc::clone(&worker.task_count))) + } +} + +/// Automatically decrements a worker's job count when a job finishes (when +/// this gets dropped). +struct JobCountGuard(Arc); + +impl Drop for JobCountGuard { + fn drop(&mut self) { + // Decrement the job count + let previous_value = self.0.fetch_sub(1, Ordering::SeqCst); + debug_assert!(previous_value >= 1); + } +} + +/// Calls abort on the handle when dropped. 
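+///
+/// This is an RAII guard: the wrapper task spawned by `spawn_pinned` holds
+/// one, so if that wrapper is dropped or cancelled before the pinned task
+/// completes, the pinned task is aborted rather than left running unobserved.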
+struct AbortGuard(AbortHandle); + +impl Drop for AbortGuard { + fn drop(&mut self) { + self.0.abort(); + } +} + +type PinnedFutureSpawner = Box; + +struct LocalWorkerHandle { + runtime_handle: tokio::runtime::Handle, + spawner: UnboundedSender, + task_count: Arc, +} + +impl LocalWorkerHandle { + /// Create a new worker for executing pinned tasks + fn new_worker() -> LocalWorkerHandle { + let (sender, receiver) = unbounded_channel(); + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to start a pinned worker thread runtime"); + let runtime_handle = runtime.handle().clone(); + let task_count = Arc::new(AtomicUsize::new(0)); + let task_count_clone = Arc::clone(&task_count); + + std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone)); + + LocalWorkerHandle { + runtime_handle, + spawner: sender, + task_count, + } + } + + fn run( + runtime: tokio::runtime::Runtime, + mut task_receiver: UnboundedReceiver, + task_count: Arc, + ) { + let local_set = LocalSet::new(); + local_set.block_on(&runtime, async { + while let Some(spawn_task) = task_receiver.recv().await { + // Calls spawn_local(future) + (spawn_task)(); + } + }); + + // If there are any tasks on the runtime associated with a LocalSet task + // that has already completed, but whose output has not yet been + // reported, let that task complete. + // + // Since the task_count is decremented when the runtime task exits, + // reading that counter lets us know if any such tasks completed during + // the call to `block_on`. + // + // Tasks on the LocalSet can't complete during this loop since they're + // stored on the LocalSet and we aren't accessing it. + let mut previous_task_count = task_count.load(Ordering::SeqCst); + loop { + // This call will also run tasks spawned on the runtime. + runtime.block_on(tokio::task::yield_now()); + let new_task_count = task_count.load(Ordering::SeqCst); + if new_task_count == previous_task_count { + break; + } else { + previous_task_count = new_task_count; + } + } + + // It's now no longer possible for a task on the runtime to be + // associated with a LocalSet task that has completed. Drop both the + // LocalSet and runtime to let tasks on the runtime be cancelled if and + // only if they are still on the LocalSet. + // + // Drop the LocalSet task first so that anyone awaiting the runtime + // JoinHandle will see the cancelled error after the LocalSet task + // destructor has completed. + drop(local_set); + drop(runtime); + } +} diff --git a/src/time/delay_queue.rs b/src/time/delay_queue.rs new file mode 100644 index 0000000..ee66adb --- /dev/null +++ b/src/time/delay_queue.rs @@ -0,0 +1,1271 @@ +//! A queue of delayed elements. +//! +//! See [`DelayQueue`] for more details. +//! +//! [`DelayQueue`]: struct@DelayQueue + +use crate::time::wheel::{self, Wheel}; + +use futures_core::ready; +use tokio::time::{sleep_until, Duration, Instant, Sleep}; + +use core::ops::{Index, IndexMut}; +use slab::Slab; +use std::cmp; +use std::collections::HashMap; +use std::convert::From; +use std::fmt; +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{self, Poll, Waker}; + +/// A queue of delayed elements. +/// +/// Once an element is inserted into the `DelayQueue`, it is yielded once the +/// specified deadline has been reached. +/// +/// # Usage +/// +/// Elements are inserted into `DelayQueue` using the [`insert`] or +/// [`insert_at`] methods. 
A deadline is provided with the item and a [`Key`] is +/// returned. The key is used to remove the entry or to change the deadline at +/// which it should be yielded back. +/// +/// Once delays have been configured, the `DelayQueue` is used via its +/// [`Stream`] implementation. [`poll_expired`] is called. If an entry has reached its +/// deadline, it is returned. If not, `Poll::Pending` is returned indicating that the +/// current task will be notified once the deadline has been reached. +/// +/// # `Stream` implementation +/// +/// Items are retrieved from the queue via [`DelayQueue::poll_expired`]. If no delays have +/// expired, no items are returned. In this case, `Poll::Pending` is returned and the +/// current task is registered to be notified once the next item's delay has +/// expired. +/// +/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll` +/// returns `Poll::Ready(None)`. This indicates that the stream has reached an end. +/// However, if a new item is inserted *after*, `poll` will once again start +/// returning items or `Poll::Pending`. +/// +/// Items are returned ordered by their expirations. Items that are configured +/// to expire first will be returned first. There are no ordering guarantees +/// for items configured to expire at the same instant. Also note that delays are +/// rounded to the closest millisecond. +/// +/// # Implementation +/// +/// The [`DelayQueue`] is backed by a separate instance of a timer wheel similar to that used internally +/// by Tokio's standalone timer utilities such as [`sleep`]. Because of this, it offers the same +/// performance and scalability benefits. +/// +/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation, +/// and allows reuse of the memory allocated for expired entires. +/// +/// Capacity can be checked using [`capacity`] and allocated preemptively by using +/// the [`reserve`] method. +/// +/// # Usage +/// +/// Using `DelayQueue` to manage cache entries. 
+/// +/// ```rust,no_run +/// use tokio_util::time::{DelayQueue, delay_queue}; +/// +/// use futures::ready; +/// use std::collections::HashMap; +/// use std::task::{Context, Poll}; +/// use std::time::Duration; +/// # type CacheKey = String; +/// # type Value = String; +/// +/// struct Cache { +/// entries: HashMap, +/// expirations: DelayQueue, +/// } +/// +/// const TTL_SECS: u64 = 30; +/// +/// impl Cache { +/// fn insert(&mut self, key: CacheKey, value: Value) { +/// let delay = self.expirations +/// .insert(key.clone(), Duration::from_secs(TTL_SECS)); +/// +/// self.entries.insert(key, (value, delay)); +/// } +/// +/// fn get(&self, key: &CacheKey) -> Option<&Value> { +/// self.entries.get(key) +/// .map(|&(ref v, _)| v) +/// } +/// +/// fn remove(&mut self, key: &CacheKey) { +/// if let Some((_, cache_key)) = self.entries.remove(key) { +/// self.expirations.remove(&cache_key); +/// } +/// } +/// +/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<()> { +/// while let Some(entry) = ready!(self.expirations.poll_expired(cx)) { +/// self.entries.remove(entry.get_ref()); +/// } +/// +/// Poll::Ready(()) +/// } +/// } +/// ``` +/// +/// [`insert`]: method@Self::insert +/// [`insert_at`]: method@Self::insert_at +/// [`Key`]: struct@Key +/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html +/// [`poll_expired`]: method@Self::poll_expired +/// [`Stream::poll_expired`]: method@Self::poll_expired +/// [`DelayQueue`]: struct@DelayQueue +/// [`sleep`]: fn@tokio::time::sleep +/// [`slab`]: slab +/// [`capacity`]: method@Self::capacity +/// [`reserve`]: method@Self::reserve +#[derive(Debug)] +pub struct DelayQueue { + /// Stores data associated with entries + slab: SlabStorage, + + /// Lookup structure tracking all delays in the queue + wheel: Wheel>, + + /// Delays that were inserted when already expired. These cannot be stored + /// in the wheel + expired: Stack, + + /// Delay expiring when the *first* item in the queue expires + delay: Option>>, + + /// Wheel polling state + wheel_now: u64, + + /// Instant at which the timer starts + start: Instant, + + /// Waker that is invoked when we potentially need to reset the timer. + /// Because we lazily create the timer when the first entry is created, we + /// need to awaken any poller that polled us before that point. + waker: Option, +} + +#[derive(Default)] +struct SlabStorage { + inner: Slab>, + + // A `compact` call requires a re-mapping of the `Key`s that were changed + // during the `compact` call of the `slab`. Since the keys that were given out + // cannot be changed retroactively we need to keep track of these re-mappings. + // The keys of `key_map` correspond to the old keys that were given out and + // the values to the `Key`s that were re-mapped by the `compact` call. + key_map: HashMap, + + // Index used to create new keys to hand out. + next_key_index: usize, + + // Whether `compact` has been called, necessary in order to decide whether + // to include keys in `key_map`. 
+ compact_called: bool, +} + +impl SlabStorage { + pub(crate) fn with_capacity(capacity: usize) -> SlabStorage { + SlabStorage { + inner: Slab::with_capacity(capacity), + key_map: HashMap::new(), + next_key_index: 0, + compact_called: false, + } + } + + // Inserts data into the inner slab and re-maps keys if necessary + pub(crate) fn insert(&mut self, val: Data) -> Key { + let mut key = KeyInternal::new(self.inner.insert(val)); + let key_contained = self.key_map.contains_key(&key.into()); + + if key_contained { + // It's possible that a `compact` call creates capacity in `self.inner` in + // such a way that a `self.inner.insert` call creates a `key` which was + // previously given out during an `insert` call prior to the `compact` call. + // If `key` is contained in `self.key_map`, we have encountered this exact situation, + // We need to create a new key `key_to_give_out` and include the relation + // `key_to_give_out` -> `key` in `self.key_map`. + let key_to_give_out = self.create_new_key(); + assert!(!self.key_map.contains_key(&key_to_give_out.into())); + self.key_map.insert(key_to_give_out.into(), key); + key = key_to_give_out; + } else if self.compact_called { + // Include an identity mapping in `self.key_map` in order to allow us to + // panic if a key that was handed out is removed more than once. + self.key_map.insert(key.into(), key); + } + + key.into() + } + + // Re-map the key in case compact was previously called. + // Note: Since we include identity mappings in key_map after compact was called, + // we have information about all keys that were handed out. In the case in which + // compact was called and we try to remove a Key that was previously removed + // we can detect invalid keys if no key is found in `key_map`. This is necessary + // in order to prevent situations in which a previously removed key + // corresponds to a re-mapped key internally and which would then be incorrectly + // removed from the slab. + // + // Example to illuminate this problem: + // + // Let's assume our `key_map` is {1 -> 2, 2 -> 1} and we call remove(1). If we + // were to remove 1 again, we would not find it inside `key_map` anymore. + // If we were to imply from this that no re-mapping was necessary, we would + // incorrectly remove 1 from `self.slab.inner`, which corresponds to the + // handed-out key 2. + pub(crate) fn remove(&mut self, key: &Key) -> Data { + let remapped_key = if self.compact_called { + match self.key_map.remove(key) { + Some(key_internal) => key_internal, + None => panic!("invalid key"), + } + } else { + (*key).into() + }; + + self.inner.remove(remapped_key.index) + } + + pub(crate) fn shrink_to_fit(&mut self) { + self.inner.shrink_to_fit(); + self.key_map.shrink_to_fit(); + } + + pub(crate) fn compact(&mut self) { + if !self.compact_called { + for (key, _) in self.inner.iter() { + self.key_map.insert(Key::new(key), KeyInternal::new(key)); + } + } + + let mut remapping = HashMap::new(); + self.inner.compact(|_, from, to| { + remapping.insert(from, to); + true + }); + + // At this point `key_map` contains a mapping for every element. + for internal_key in self.key_map.values_mut() { + if let Some(new_internal_key) = remapping.get(&internal_key.index) { + *internal_key = KeyInternal::new(*new_internal_key); + } + } + + if self.key_map.capacity() > 2 * self.key_map.len() { + self.key_map.shrink_to_fit(); + } + + self.compact_called = true; + } + + // Tries to re-map a `Key` that was given out to the user to its + // corresponding internal key. 
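+    //
+    // Returns `None` only when `compact` has been called and the key is not
+    // present in `key_map`, i.e. the key was never handed out or has already
+    // been removed.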
+ fn remap_key(&self, key: &Key) -> Option { + let key_map = &self.key_map; + if self.compact_called { + key_map.get(key).copied() + } else { + Some((*key).into()) + } + } + + fn create_new_key(&mut self) -> KeyInternal { + while self.key_map.contains_key(&Key::new(self.next_key_index)) { + self.next_key_index = self.next_key_index.wrapping_add(1); + } + + KeyInternal::new(self.next_key_index) + } + + pub(crate) fn len(&self) -> usize { + self.inner.len() + } + + pub(crate) fn capacity(&self) -> usize { + self.inner.capacity() + } + + pub(crate) fn clear(&mut self) { + self.inner.clear(); + self.key_map.clear(); + self.compact_called = false; + } + + pub(crate) fn reserve(&mut self, additional: usize) { + self.inner.reserve(additional); + + if self.compact_called { + self.key_map.reserve(additional); + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub(crate) fn contains(&self, key: &Key) -> bool { + let remapped_key = self.remap_key(key); + + match remapped_key { + Some(internal_key) => self.inner.contains(internal_key.index), + None => false, + } + } +} + +impl fmt::Debug for SlabStorage +where + T: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + if fmt.alternate() { + fmt.debug_map().entries(self.inner.iter()).finish() + } else { + fmt.debug_struct("Slab") + .field("len", &self.len()) + .field("cap", &self.capacity()) + .finish() + } + } +} + +impl Index for SlabStorage { + type Output = Data; + + fn index(&self, key: Key) -> &Self::Output { + let remapped_key = self.remap_key(&key); + + match remapped_key { + Some(internal_key) => &self.inner[internal_key.index], + None => panic!("Invalid index {}", key.index), + } + } +} + +impl IndexMut for SlabStorage { + fn index_mut(&mut self, key: Key) -> &mut Data { + let remapped_key = self.remap_key(&key); + + match remapped_key { + Some(internal_key) => &mut self.inner[internal_key.index], + None => panic!("Invalid index {}", key.index), + } + } +} + +/// An entry in `DelayQueue` that has expired and been removed. +/// +/// Values are returned by [`DelayQueue::poll_expired`]. +/// +/// [`DelayQueue::poll_expired`]: method@DelayQueue::poll_expired +#[derive(Debug)] +pub struct Expired { + /// The data stored in the queue + data: T, + + /// The expiration time + deadline: Instant, + + /// The key associated with the entry + key: Key, +} + +/// Token to a value stored in a `DelayQueue`. +/// +/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`] +/// documentation for more details. +/// +/// [`DelayQueue`]: struct@DelayQueue +/// [`DelayQueue::insert`]: method@DelayQueue::insert +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Key { + index: usize, +} + +// Whereas `Key` is given out to users that use `DelayQueue`, internally we use +// `KeyInternal` as the key type in order to make the logic of mapping between keys +// as a result of `compact` calls clearer. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct KeyInternal { + index: usize, +} + +#[derive(Debug)] +struct Stack { + /// Head of the stack + head: Option, + _p: PhantomData T>, +} + +#[derive(Debug)] +struct Data { + /// The data being stored in the queue and will be returned at the requested + /// instant. + inner: T, + + /// The instant at which the item is returned. 
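+    /// (Stored as milliseconds elapsed since the queue's `start` instant.)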
+ when: u64, + + /// Set to true when stored in the `expired` queue + expired: bool, + + /// Next entry in the stack + next: Option, + + /// Previous entry in the stack + prev: Option, +} + +/// Maximum number of entries the queue can handle +const MAX_ENTRIES: usize = (1 << 30) - 1; + +impl DelayQueue { + /// Creates a new, empty, `DelayQueue`. + /// + /// The queue will not allocate storage until items are inserted into it. + /// + /// # Examples + /// + /// ```rust + /// # use tokio_util::time::DelayQueue; + /// let delay_queue: DelayQueue = DelayQueue::new(); + /// ``` + pub fn new() -> DelayQueue { + DelayQueue::with_capacity(0) + } + + /// Creates a new, empty, `DelayQueue` with the specified capacity. + /// + /// The queue will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the queue will not allocate for + /// storage. + /// + /// # Examples + /// + /// ```rust + /// # use tokio_util::time::DelayQueue; + /// # use std::time::Duration; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::with_capacity(10); + /// + /// // These insertions are done without further allocation + /// for i in 0..10 { + /// delay_queue.insert(i, Duration::from_secs(i)); + /// } + /// + /// // This will make the queue allocate additional storage + /// delay_queue.insert(11, Duration::from_secs(11)); + /// # } + /// ``` + pub fn with_capacity(capacity: usize) -> DelayQueue { + DelayQueue { + wheel: Wheel::new(), + slab: SlabStorage::with_capacity(capacity), + expired: Stack::default(), + delay: None, + wheel_now: 0, + start: Instant::now(), + waker: None, + } + } + + /// Inserts `value` into the queue set to expire at a specific instant in + /// time. + /// + /// This function is identical to `insert`, but takes an `Instant` instead + /// of a `Duration`. + /// + /// `value` is stored in the queue until `when` is reached. At which point, + /// `value` will be returned from [`poll_expired`]. If `when` has already been + /// reached, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an argument to + /// [`remove`] and [`reset`]. Note that [`Key`] is a token and is reused once + /// `value` is removed from the queue either by calling [`poll_expired`] after + /// `when` is reached or by calling [`remove`]. At this point, the caller + /// must take care to not use the returned [`Key`] again as it may reference + /// a different item in the queue. + /// + /// See [type] level documentation for more details. + /// + /// # Panics + /// + /// This function panics if `when` is too far in the future. + /// + /// # Examples + /// + /// Basic usage + /// + /// ```rust + /// use tokio::time::{Duration, Instant}; + /// use tokio_util::time::DelayQueue; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut delay_queue = DelayQueue::new(); + /// let key = delay_queue.insert_at( + /// "foo", Instant::now() + Duration::from_secs(5)); + /// + /// // Remove the entry + /// let item = delay_queue.remove(&key); + /// assert_eq!(*item.get_ref(), "foo"); + /// # } + /// ``` + /// + /// [`poll_expired`]: method@Self::poll_expired + /// [`remove`]: method@Self::remove + /// [`reset`]: method@Self::reset + /// [`Key`]: struct@Key + /// [type]: # + #[track_caller] + pub fn insert_at(&mut self, value: T, when: Instant) -> Key { + assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded"); + + // Normalize the deadline. Values cannot be set to expire in the past. 
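+        // (The normalized deadline is the queue's internal representation:
+        // milliseconds elapsed since `self.start`. The conversion back to an
+        // `Instant` below is `self.start + Duration::from_millis(when)`.)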
+ let when = self.normalize_deadline(when); + + // Insert the value in the store + let key = self.slab.insert(Data { + inner: value, + when, + expired: false, + next: None, + prev: None, + }); + + self.insert_idx(when, key); + + // Set a new delay if the current's deadline is later than the one of the new item + let should_set_delay = if let Some(ref delay) = self.delay { + let current_exp = self.normalize_deadline(delay.deadline()); + current_exp > when + } else { + true + }; + + if should_set_delay { + if let Some(waker) = self.waker.take() { + waker.wake(); + } + + let delay_time = self.start + Duration::from_millis(when); + if let Some(ref mut delay) = &mut self.delay { + delay.as_mut().reset(delay_time); + } else { + self.delay = Some(Box::pin(sleep_until(delay_time))); + } + } + + key + } + + /// Attempts to pull out the next value of the delay queue, registering the + /// current task for wakeup if the value is not yet available, and returning + /// `None` if the queue is exhausted. + pub fn poll_expired(&mut self, cx: &mut task::Context<'_>) -> Poll>> { + if !self + .waker + .as_ref() + .map(|w| w.will_wake(cx.waker())) + .unwrap_or(false) + { + self.waker = Some(cx.waker().clone()); + } + + let item = ready!(self.poll_idx(cx)); + Poll::Ready(item.map(|key| { + let data = self.slab.remove(&key); + debug_assert!(data.next.is_none()); + debug_assert!(data.prev.is_none()); + + Expired { + key, + data: data.inner, + deadline: self.start + Duration::from_millis(data.when), + } + })) + } + + /// Inserts `value` into the queue set to expire after the requested duration + /// elapses. + /// + /// This function is identical to `insert_at`, but takes a `Duration` + /// instead of an `Instant`. + /// + /// `value` is stored in the queue until `timeout` duration has + /// elapsed after `insert` was called. At that point, `value` will + /// be returned from [`poll_expired`]. If `timeout` is a `Duration` of + /// zero, then `value` is immediately made available to poll. + /// + /// The return value represents the insertion and is used as an + /// argument to [`remove`] and [`reset`]. Note that [`Key`] is a + /// token and is reused once `value` is removed from the queue + /// either by calling [`poll_expired`] after `timeout` has elapsed + /// or by calling [`remove`]. At this point, the caller must not + /// use the returned [`Key`] again as it may reference a different + /// item in the queue. + /// + /// See [type] level documentation for more details. + /// + /// # Panics + /// + /// This function panics if `timeout` is greater than the maximum + /// duration supported by the timer in the current `Runtime`. 
+    ///
+    /// # Examples
+    ///
+    /// Basic usage
+    ///
+    /// ```rust
+    /// use tokio_util::time::DelayQueue;
+    /// use std::time::Duration;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let mut delay_queue = DelayQueue::new();
+    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+    ///
+    /// // Remove the entry
+    /// let item = delay_queue.remove(&key);
+    /// assert_eq!(*item.get_ref(), "foo");
+    /// # }
+    /// ```
+    ///
+    /// [`poll_expired`]: method@Self::poll_expired
+    /// [`remove`]: method@Self::remove
+    /// [`reset`]: method@Self::reset
+    /// [`Key`]: struct@Key
+    /// [type]: #
+    #[track_caller]
+    pub fn insert(&mut self, value: T, timeout: Duration) -> Key {
+        self.insert_at(value, Instant::now() + timeout)
+    }
+
+    #[track_caller]
+    fn insert_idx(&mut self, when: u64, key: Key) {
+        use self::wheel::{InsertError, Stack};
+
+        // Register the deadline with the timer wheel
+        match self.wheel.insert(when, key, &mut self.slab) {
+            Ok(_) => {}
+            Err((_, InsertError::Elapsed)) => {
+                self.slab[key].expired = true;
+                // The delay is already expired, store it in the expired queue
+                self.expired.push(key, &mut self.slab);
+            }
+            Err((_, err)) => panic!("invalid deadline; err={:?}", err),
+        }
+    }
+
+    /// Removes the key from the expired queue or the timer wheel
+    /// depending on its expiration status.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the key is not contained in the expired queue or the wheel.
+    #[track_caller]
+    fn remove_key(&mut self, key: &Key) {
+        use crate::time::wheel::Stack;
+
+        // Special case the `expired` queue
+        if self.slab[*key].expired {
+            self.expired.remove(key, &mut self.slab);
+        } else {
+            self.wheel.remove(key, &mut self.slab);
+        }
+    }
+
+    /// Removes the item associated with `key` from the queue.
+    ///
+    /// There must be an item associated with `key`. The function returns the
+    /// removed item as well as the `Instant` at which its delay would have
+    /// expired.
+    ///
+    /// # Panics
+    ///
+    /// The function panics if `key` is not contained by the queue.
+    ///
+    /// # Examples
+    ///
+    /// Basic usage
+    ///
+    /// ```rust
+    /// use tokio_util::time::DelayQueue;
+    /// use std::time::Duration;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let mut delay_queue = DelayQueue::new();
+    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+    ///
+    /// // Remove the entry
+    /// let item = delay_queue.remove(&key);
+    /// assert_eq!(*item.get_ref(), "foo");
+    /// # }
+    /// ```
+    #[track_caller]
+    pub fn remove(&mut self, key: &Key) -> Expired<T> {
+        let prev_deadline = self.next_deadline();
+
+        self.remove_key(key);
+        let data = self.slab.remove(key);
+
+        let next_deadline = self.next_deadline();
+        if prev_deadline != next_deadline {
+            match (next_deadline, &mut self.delay) {
+                (None, _) => self.delay = None,
+                (Some(deadline), Some(delay)) => delay.as_mut().reset(deadline),
+                (Some(deadline), None) => self.delay = Some(Box::pin(sleep_until(deadline))),
+            }
+        }
+
+        Expired {
+            key: Key::new(key.index),
+            data: data.inner,
+            deadline: self.start + Duration::from_millis(data.when),
+        }
+    }
+
+    /// Attempts to remove the item associated with `key` from the queue.
+    ///
+    /// Removes the item associated with `key`, and returns it along with the
+    /// `Instant` at which it would have expired, if it exists.
+    ///
+    /// Returns `None` if `key` is not in the queue.
+    ///
+    /// # Examples
+    ///
+    /// Basic usage
+    ///
+    /// ```rust
+    /// use tokio_util::time::DelayQueue;
+    /// use std::time::Duration;
+    ///
+    /// # #[tokio::main(flavor = "current_thread")]
+    /// # async fn main() {
+    /// let mut delay_queue = DelayQueue::new();
+    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+    ///
+    /// // The item is in the queue, `try_remove` returns `Some(Expired("foo"))`.
+    /// let item = delay_queue.try_remove(&key);
+    /// assert_eq!(item.unwrap().into_inner(), "foo");
+    ///
+    /// // The item is not in the queue anymore, `try_remove` returns `None`.
+    /// let item = delay_queue.try_remove(&key);
+    /// assert!(item.is_none());
+    /// # }
+    /// ```
+    pub fn try_remove(&mut self, key: &Key) -> Option<Expired<T>> {
+        if self.slab.contains(key) {
+            Some(self.remove(key))
+        } else {
+            None
+        }
+    }
+
+    /// Sets the delay of the item associated with `key` to expire at `when`.
+    ///
+    /// This function is identical to `reset` but takes an `Instant` instead of
+    /// a `Duration`.
+    ///
+    /// The item remains in the queue but the delay is set to expire at `when`.
+    /// If `when` is in the past, then the item is immediately made available to
+    /// the caller.
+    ///
+    /// # Panics
+    ///
+    /// This function panics if `when` is too far in the future or if `key` is
+    /// not contained by the queue.
+    ///
+    /// # Examples
+    ///
+    /// Basic usage
+    ///
+    /// ```rust
+    /// use tokio::time::{Duration, Instant};
+    /// use tokio_util::time::DelayQueue;
+    ///
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// let mut delay_queue = DelayQueue::new();
+    /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+    ///
+    /// // "foo" is scheduled to be returned in 5 seconds
+    ///
+    /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10));
+    ///
+    /// // "foo" is now scheduled to be returned in 10 seconds
+    /// # }
+    /// ```
+    #[track_caller]
+    pub fn reset_at(&mut self, key: &Key, when: Instant) {
+        self.remove_key(key);
+
+        // Normalize the deadline. Values cannot be set to expire in the past.
+        let when = self.normalize_deadline(when);
+
+        self.slab[*key].when = when;
+        self.slab[*key].expired = false;
+
+        self.insert_idx(when, *key);
+
+        let next_deadline = self.next_deadline();
+        if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) {
+            // This should awaken us if necessary (ie, if already expired)
+            delay.as_mut().reset(deadline);
+        }
+    }
+
+    /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation.
+    /// This function is not guaranteed to, and in most cases won't, decrease the capacity of the slab
+    /// to the number of elements still contained in it, because elements cannot be moved to a different
+    /// index. To decrease the capacity to the size of the slab use [`compact`].
+    ///
+    /// This function can take O(n) time even when the capacity cannot be reduced or the allocation is
+    /// shrunk in place. Repeated calls run in O(1) though.
+    ///
+    /// [`compact`]: method@Self::compact
+    pub fn shrink_to_fit(&mut self) {
+        self.slab.shrink_to_fit();
+    }
+
+    /// Shrink the capacity of the slab, which `DelayQueue` uses internally for storage allocation,
+    /// to the number of elements that are contained in it.
+    ///
+    /// This method runs in O(n).
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::with_capacity(10);
+ ///
+ /// let key1 = delay_queue.insert(5, Duration::from_secs(5));
+ /// let key2 = delay_queue.insert(10, Duration::from_secs(10));
+ /// let key3 = delay_queue.insert(15, Duration::from_secs(15));
+ ///
+ /// delay_queue.remove(&key2);
+ ///
+ /// delay_queue.compact();
+ /// assert_eq!(delay_queue.capacity(), 2);
+ /// # }
+ /// ```
+ pub fn compact(&mut self) {
+ self.slab.compact();
+ }
+
+ /// Returns the next time to poll as determined by the wheel.
+ fn next_deadline(&mut self) -> Option<Instant> {
+ self.wheel
+ .poll_at()
+ .map(|poll_at| self.start + Duration::from_millis(poll_at))
+ }
+
+ /// Sets the delay of the item associated with `key` to expire after
+ /// `timeout`.
+ ///
+ /// This function is identical to `reset_at` but takes a `Duration` instead
+ /// of an `Instant`.
+ ///
+ /// The item remains in the queue but the delay is set to expire after
+ /// `timeout`. If `timeout` is zero, then the item is immediately made
+ /// available to the caller.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if `timeout` is greater than the maximum supported
+ /// duration or if `key` is not contained by the queue.
+ ///
+ /// # Examples
+ ///
+ /// Basic usage
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// // "foo" is scheduled to be returned in 5 seconds
+ ///
+ /// delay_queue.reset(&key, Duration::from_secs(10));
+ ///
+ /// // "foo" is now scheduled to be returned in 10 seconds
+ /// # }
+ /// ```
+ #[track_caller]
+ pub fn reset(&mut self, key: &Key, timeout: Duration) {
+ self.reset_at(key, Instant::now() + timeout);
+ }
+
+ /// Clears the queue, removing all items.
+ ///
+ /// After calling `clear`, [`poll_expired`] will return `Ok(Ready(None))`.
+ ///
+ /// Note that this method has no effect on the allocated capacity.
+ ///
+ /// [`poll_expired`]: method@Self::poll_expired
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// assert!(!delay_queue.is_empty());
+ ///
+ /// delay_queue.clear();
+ ///
+ /// assert!(delay_queue.is_empty());
+ /// # }
+ /// ```
+ pub fn clear(&mut self) {
+ self.slab.clear();
+ self.expired = Stack::default();
+ self.wheel = Wheel::new();
+ self.delay = None;
+ }
+
+ /// Returns the number of elements the queue can hold without reallocating.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ ///
+ /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
+ /// assert_eq!(delay_queue.capacity(), 10);
+ /// ```
+ pub fn capacity(&self) -> usize {
+ self.slab.capacity()
+ }
+
+ /// Returns the number of elements currently in the queue.
+ ///
+ /// # Examples
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
+ /// assert_eq!(delay_queue.len(), 0);
+ /// delay_queue.insert(3, Duration::from_secs(5));
+ /// assert_eq!(delay_queue.len(), 1);
+ /// # }
+ /// ```
+ pub fn len(&self) -> usize {
+ self.slab.len()
+ }
+
+ /// Reserves capacity for at least `additional` more items to be queued
+ /// without allocating.
+ ///
+ /// `reserve` does nothing if the queue already has sufficient capacity for
+ /// `additional` more values. If more capacity is required, a new segment of
+ /// memory will be allocated and all existing values will be copied into it.
+ /// As such, if the queue is already very large, a call to `reserve` can end
+ /// up being expensive.
+ ///
+ /// The queue may reserve more than `additional` extra space in order to
+ /// avoid frequent reallocations.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the new capacity exceeds the maximum number of entries the
+ /// queue can contain.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ ///
+ /// delay_queue.insert("hello", Duration::from_secs(10));
+ /// delay_queue.reserve(10);
+ ///
+ /// assert!(delay_queue.capacity() >= 11);
+ /// # }
+ /// ```
+ #[track_caller]
+ pub fn reserve(&mut self, additional: usize) {
+ assert!(
+ self.slab.capacity() + additional <= MAX_ENTRIES,
+ "max queue capacity exceeded"
+ );
+ self.slab.reserve(additional);
+ }
+
+ /// Returns `true` if there are no items in the queue.
+ ///
+ /// Note that this function returns `false` even if all items have not yet
+ /// expired, in which case a call to `poll_expired` will return
+ /// `Poll::Pending`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// assert!(delay_queue.is_empty());
+ ///
+ /// delay_queue.insert("hello", Duration::from_secs(5));
+ /// assert!(!delay_queue.is_empty());
+ /// # }
+ /// ```
+ pub fn is_empty(&self) -> bool {
+ self.slab.is_empty()
+ }
+
+ /// Polls the queue, returning the index of the next slot in the slab that
+ /// should be returned.
+ ///
+ /// A slot should be returned when the associated deadline has been reached.
+ fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Key>> {
+ use self::wheel::Stack;
+
+ let expired = self.expired.pop(&mut self.slab);
+
+ if expired.is_some() {
+ return Poll::Ready(expired);
+ }
+
+ loop {
+ if let Some(ref mut delay) = self.delay {
+ if !delay.is_elapsed() {
+ ready!(Pin::new(&mut *delay).poll(cx));
+ }
+
+ let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down);
+
+ self.wheel_now = now;
+ }
+
+ // We poll the wheel to get the next value out before finding the next deadline.
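+ // If the wheel yields no key and no further deadline exists, the
+ // queue is drained and `Ready(None)` is returned below; otherwise
+ // `delay` is re-armed for the next deadline and the loop waits again.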
+ let wheel_idx = self.wheel.poll(self.wheel_now, &mut self.slab);
+
+ self.delay = self.next_deadline().map(|when| Box::pin(sleep_until(when)));
+
+ if let Some(idx) = wheel_idx {
+ return Poll::Ready(Some(idx));
+ }
+
+ if self.delay.is_none() {
+ return Poll::Ready(None);
+ }
+ }
+ }
+
+ fn normalize_deadline(&self, when: Instant) -> u64 {
+ let when = if when < self.start {
+ 0
+ } else {
+ crate::time::ms(when - self.start, crate::time::Round::Up)
+ };
+
+ cmp::max(when, self.wheel.elapsed())
+ }
+}
+
+// We never put `T` in a `Pin`...
+impl<T> Unpin for DelayQueue<T> {}
+
+impl<T> Default for DelayQueue<T> {
+ fn default() -> DelayQueue<T> {
+ DelayQueue::new()
+ }
+}
+
+impl<T> futures_core::Stream for DelayQueue<T> {
+ // Expiration cannot fail, so the stream yields the expired entries
+ // themselves; capacity violations panic at insertion time instead.
+ type Item = Expired<T>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
+ DelayQueue::poll_expired(self.get_mut(), cx)
+ }
+}
+
+impl<T> wheel::Stack for Stack<T> {
+ type Owned = Key;
+ type Borrowed = Key;
+ type Store = SlabStorage<T>;
+
+ fn is_empty(&self) -> bool {
+ self.head.is_none()
+ }
+
+ fn push(&mut self, item: Self::Owned, store: &mut Self::Store) {
+ // Ensure the entry is not already in a stack.
+ debug_assert!(store[item].next.is_none());
+ debug_assert!(store[item].prev.is_none());
+
+ // Remove the old head entry
+ let old = self.head.take();
+
+ if let Some(idx) = old {
+ store[idx].prev = Some(item);
+ }
+
+ store[item].next = old;
+ self.head = Some(item);
+ }
+
+ fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> {
+ if let Some(key) = self.head {
+ self.head = store[key].next;
+
+ if let Some(idx) = self.head {
+ store[idx].prev = None;
+ }
+
+ store[key].next = None;
+ debug_assert!(store[key].prev.is_none());
+
+ Some(key)
+ } else {
+ None
+ }
+ }
+
+ #[track_caller]
+ fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) {
+ let key = *item;
+ assert!(store.contains(item));
+
+ // Ensure that the entry is in fact contained by the stack
+ debug_assert!({
+ // This walks the full linked list even if an entry is found.
+ let mut next = self.head;
+ let mut contains = false;
+
+ while let Some(idx) = next {
+ let data = &store[idx];
+
+ if idx == *item {
+ debug_assert!(!contains);
+ contains = true;
+ }
+
+ next = data.next;
+ }
+
+ contains
+ });
+
+ if let Some(next) = store[key].next {
+ store[next].prev = store[key].prev;
+ }
+
+ if let Some(prev) = store[key].prev {
+ store[prev].next = store[key].next;
+ } else {
+ self.head = store[key].next;
+ }
+
+ store[key].next = None;
+ store[key].prev = None;
+ }
+
+ fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 {
+ store[*item].when
+ }
+}
+
+impl<T> Default for Stack<T> {
+ fn default() -> Stack<T> {
+ Stack {
+ head: None,
+ _p: PhantomData,
+ }
+ }
+}
+
+impl Key {
+ pub(crate) fn new(index: usize) -> Key {
+ Key { index }
+ }
+}
+
+impl KeyInternal {
+ pub(crate) fn new(index: usize) -> KeyInternal {
+ KeyInternal { index }
+ }
+}
+
+impl From<Key> for KeyInternal {
+ fn from(item: Key) -> Self {
+ KeyInternal::new(item.index)
+ }
+}
+
+impl From<KeyInternal> for Key {
+ fn from(item: KeyInternal) -> Self {
+ Key::new(item.index)
+ }
+}
+
+impl<T> Expired<T> {
+ /// Returns a reference to the inner value.
+ pub fn get_ref(&self) -> &T {
+ &self.data
+ }
+
+ /// Returns a mutable reference to the inner value.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.data
+ }
+
+ /// Consumes `self` and returns the inner value.
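+ ///
+ /// For example, a value stored with `insert` comes back out of the
+ /// queue intact:
+ ///
+ /// ```rust
+ /// use tokio_util::time::DelayQueue;
+ /// use std::time::Duration;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let mut delay_queue = DelayQueue::new();
+ /// let key = delay_queue.insert("foo", Duration::from_secs(5));
+ ///
+ /// assert_eq!(delay_queue.remove(&key).into_inner(), "foo");
+ /// # }
+ /// ```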
+ pub fn into_inner(self) -> T {
+ self.data
+ }
+
+ /// Returns the deadline that the expiration was set to.
+ pub fn deadline(&self) -> Instant {
+ self.deadline
+ }
+
+ /// Returns the key that the expiration is indexed by.
+ pub fn key(&self) -> Key {
+ self.key
+ }
+}
diff --git a/src/time/mod.rs b/src/time/mod.rs
new file mode 100644
index 0000000..2d34008
--- /dev/null
+++ b/src/time/mod.rs
@@ -0,0 +1,47 @@
+//! Additional utilities for tracking time.
+//!
+//! This module provides additional utilities for executing code after a set period
+//! of time. Currently there is only one:
+//!
+//! * `DelayQueue`: A queue where items are returned once the requested delay
+//! has expired.
+//!
+//! This type must be used from within the context of the `Runtime`.
+
+use std::time::Duration;
+
+mod wheel;
+
+pub mod delay_queue;
+
+#[doc(inline)]
+pub use delay_queue::DelayQueue;
+
+// ===== Internal utils =====
+
+enum Round {
+ Up,
+ Down,
+}
+
+/// Convert a `Duration` to milliseconds, rounding up and saturating at
+/// `u64::MAX`.
+///
+/// The saturating is fine because `u64::MAX` milliseconds are still many
+/// million years.
+#[inline]
+fn ms(duration: Duration, round: Round) -> u64 {
+ const NANOS_PER_MILLI: u32 = 1_000_000;
+ const MILLIS_PER_SEC: u64 = 1_000;
+
+ // Convert the subsecond nanoseconds to milliseconds, rounding in the
+ // requested direction.
+ let millis = match round {
+ Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI,
+ Round::Down => duration.subsec_millis(),
+ };
+
+ duration
+ .as_secs()
+ .saturating_mul(MILLIS_PER_SEC)
+ .saturating_add(u64::from(millis))
+}
diff --git a/src/time/wheel/level.rs b/src/time/wheel/level.rs
new file mode 100644
index 0000000..8ea30af
--- /dev/null
+++ b/src/time/wheel/level.rs
@@ -0,0 +1,253 @@
+use crate::time::wheel::Stack;
+
+use std::fmt;
+
+/// Wheel for a single level in the timer. This wheel contains 64 slots.
+pub(crate) struct Level<T> {
+ level: usize,
+
+ /// Bit field tracking which slots currently contain entries.
+ ///
+ /// Using a bit field to track slots that contain entries allows avoiding a
+ /// scan to find entries. This field is updated when entries are added or
+ /// removed from a slot.
+ ///
+ /// The least-significant bit represents slot zero.
+ occupied: u64,
+
+ /// Slots
+ slot: [T; LEVEL_MULT],
+}
+
+/// Indicates when a slot must be processed next.
+#[derive(Debug)]
+pub(crate) struct Expiration {
+ /// The level containing the slot.
+ pub(crate) level: usize,
+
+ /// The slot index.
+ pub(crate) slot: usize,
+
+ /// The instant at which the slot needs to be processed.
+ pub(crate) deadline: u64,
+}
+
+/// Level multiplier.
+///
+/// Being a power of 2 is very important: it lets slot indexing be done with
+/// shifts and masks instead of division.
+const LEVEL_MULT: usize = 64;
+
+impl<T: Stack> Level<T> {
+ pub(crate) fn new(level: usize) -> Level<T> {
+ // Rust's derived implementations for arrays require that the value
+ // contained by the array be `Copy`. So, here we have to manually
+ // initialize every single slot.
+ macro_rules! s {
+ () => {
+ T::default()
+ };
+ }
+
+ Level {
+ level,
+ occupied: 0,
+ slot: [
+ // It does not look like the necessary traits are
+ // derived for [T; 64].
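+ // On compilers with `array::map` (stable since Rust 1.55), something
+ // like `[(); LEVEL_MULT].map(|_| T::default())` could replace this
+ // list; the explicit form below avoids raising the MSRV.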
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ s!(),
+ ],
+ }
+ }
+
+ /// Finds the slot that needs to be processed next and returns the slot and
+ /// `Instant` at which this slot must be processed.
+ pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> {
+ // Use the `occupied` bit field to get the index of the next slot that
+ // needs to be processed.
+ let slot = self.next_occupied_slot(now)?;
+
+ // From the slot index, calculate the `Instant` at which it needs to be
+ // processed. This value *must* be in the future with respect to `now`.
+
+ let level_range = level_range(self.level);
+ let slot_range = slot_range(self.level);
+
+ // TODO: This can probably be simplified w/ power of 2 math
+ let level_start = now - (now % level_range);
+ let deadline = level_start + slot as u64 * slot_range;
+
+ debug_assert!(
+ deadline >= now,
+ "deadline={}; now={}; level={}; slot={}; occupied={:b}",
+ deadline,
+ now,
+ self.level,
+ slot,
+ self.occupied
+ );
+
+ Some(Expiration {
+ level: self.level,
+ slot,
+ deadline,
+ })
+ }
+
+ fn next_occupied_slot(&self, now: u64) -> Option<usize> {
+ if self.occupied == 0 {
+ return None;
+ }
+
+ // Get the slot containing `now`, then rotate the bit field so the
+ // scan starts at that slot.
+ let now_slot = (now / slot_range(self.level)) as usize;
+ let occupied = self.occupied.rotate_right(now_slot as u32);
+ let zeros = occupied.trailing_zeros() as usize;
+ let slot = (zeros + now_slot) % 64;
+
+ Some(slot)
+ }
+
+ pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) {
+ let slot = slot_for(when, self.level);
+
+ self.slot[slot].push(item, store);
+ self.occupied |= occupied_bit(slot);
+ }
+
+ pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) {
+ let slot = slot_for(when, self.level);
+
+ self.slot[slot].remove(item, store);
+
+ if self.slot[slot].is_empty() {
+ // The bit is currently set
+ debug_assert!(self.occupied & occupied_bit(slot) != 0);
+
+ // Unset the bit
+ self.occupied ^= occupied_bit(slot);
+ }
+ }
+
+ pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> {
+ let ret = self.slot[slot].pop(store);
+
+ if ret.is_some() && self.slot[slot].is_empty() {
+ // The bit is currently set
+ debug_assert!(self.occupied & occupied_bit(slot) != 0);
+
+ self.occupied ^= occupied_bit(slot);
+ }
+
+ ret
+ }
+}
+
+impl<T> fmt::Debug for Level<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Level")
+ .field("occupied", &self.occupied)
+ .finish()
+ }
+}
+
+fn occupied_bit(slot: usize) -> u64 {
+ 1 << slot
+}
+
+fn slot_range(level: usize) -> u64 {
+ LEVEL_MULT.pow(level as u32) as u64
+}
+
+fn level_range(level: usize) -> u64 {
+ LEVEL_MULT as u64 * slot_range(level)
+}
+
+/// Converts a duration (milliseconds) and a level to a slot position.
+fn slot_for(duration: u64, level: usize) -> usize {
+ ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize
+}
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_slot_for() {
+ for pos in 0..64 {
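+ // Level 0 slots are 1 ms wide, so a value maps to its own slot
+ // index (modulo 64).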
+ assert_eq!(pos as usize, slot_for(pos, 0));
+ }
+
+ for level in 1..5 {
+ for pos in level..64 {
+ let a = pos * 64_usize.pow(level as u32);
+ assert_eq!(pos as usize, slot_for(a as u64, level));
+ }
+ }
+ }
+}
diff --git a/src/time/wheel/mod.rs b/src/time/wheel/mod.rs
new file mode 100644
index 0000000..ffa05ab
--- /dev/null
+++ b/src/time/wheel/mod.rs
@@ -0,0 +1,315 @@
+mod level;
+pub(crate) use self::level::Expiration;
+use self::level::Level;
+
+mod stack;
+pub(crate) use self::stack::Stack;
+
+use std::borrow::Borrow;
+use std::fmt::Debug;
+use std::usize;
+
+/// Timing wheel implementation.
+///
+/// This type provides the hashed timing wheel implementation that backs `Timer`
+/// and `DelayQueue`.
+///
+/// The structure is generic over `T: Stack`. This allows handling timeout data
+/// being stored on the heap or in a slab. In order to support the latter case,
+/// the slab must be passed into each function allowing the implementation to
+/// look up timer entries.
+///
+/// See `Timer` documentation for some implementation notes.
+#[derive(Debug)]
+pub(crate) struct Wheel<T> {
+ /// The number of milliseconds elapsed since the wheel started.
+ elapsed: u64,
+
+ /// Timer wheel.
+ ///
+ /// Levels:
+ ///
+ /// * 1 ms slots / 64 ms range
+ /// * 64 ms slots / ~ 4 sec range
+ /// * ~ 4 sec slots / ~ 4 min range
+ /// * ~ 4 min slots / ~ 4 hr range
+ /// * ~ 4 hr slots / ~ 12 day range
+ /// * ~ 12 day slots / ~ 2 yr range
+ levels: Vec<Level<T>>,
+}
+
+/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots
+/// each, the timer is able to track time up to 2 years into the future with a
+/// precision of 1 millisecond.
+const NUM_LEVELS: usize = 6;
+
+/// The maximum duration of a delay.
+const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1;
+
+#[derive(Debug)]
+pub(crate) enum InsertError {
+ Elapsed,
+ Invalid,
+}
+
+impl<T> Wheel<T>
+where
+ T: Stack,
+{
+ /// Creates a new timing wheel.
+ pub(crate) fn new() -> Wheel<T> {
+ let levels = (0..NUM_LEVELS).map(Level::new).collect();
+
+ Wheel { elapsed: 0, levels }
+ }
+
+ /// Returns the number of milliseconds that have elapsed since the timing
+ /// wheel's creation.
+ pub(crate) fn elapsed(&self) -> u64 {
+ self.elapsed
+ }
+
+ /// Inserts an entry into the timing wheel.
+ ///
+ /// # Arguments
+ ///
+ /// * `when`: the instant at which the entry should be fired. It is
+ /// represented as the number of milliseconds since the creation
+ /// of the timing wheel.
+ ///
+ /// * `item`: The item to insert into the wheel.
+ ///
+ /// * `store`: The slab, or `()` when using heap storage.
+ ///
+ /// # Return
+ ///
+ /// Returns `Ok` when the item is successfully inserted, `Err` otherwise.
+ ///
+ /// `Err(Elapsed)` indicates that `when` represents an instant that has
+ /// already passed. In this case, the caller should fire the timeout
+ /// immediately.
+ ///
+ /// `Err(Invalid)` indicates an invalid `when` argument has been supplied.
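+ // A caller is expected to match on the error; simplified from
+ // `DelayQueue::insert_idx`, which treats `Elapsed` as "fire immediately"
+ // and any other error as a bug:
+ //
+ //     match wheel.insert(when, key, &mut slab) {
+ //         Ok(_) => {}
+ //         Err((key, InsertError::Elapsed)) => expired.push(key, &mut slab),
+ //         Err((_, err)) => panic!("invalid deadline; err={:?}", err),
+ //     }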
+ pub(crate) fn insert(
+ &mut self,
+ when: u64,
+ item: T::Owned,
+ store: &mut T::Store,
+ ) -> Result<(), (T::Owned, InsertError)> {
+ if when <= self.elapsed {
+ return Err((item, InsertError::Elapsed));
+ } else if when - self.elapsed > MAX_DURATION {
+ return Err((item, InsertError::Invalid));
+ }
+
+ // Get the level at which the entry should be stored
+ let level = self.level_for(when);
+
+ self.levels[level].add_entry(when, item, store);
+
+ debug_assert!({
+ self.levels[level]
+ .next_expiration(self.elapsed)
+ .map(|e| e.deadline >= self.elapsed)
+ .unwrap_or(true)
+ });
+
+ Ok(())
+ }
+
+ /// Removes `item` from the timing wheel.
+ #[track_caller]
+ pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) {
+ let when = T::when(item, store);
+
+ assert!(
+ self.elapsed <= when,
+ "elapsed={}; when={}",
+ self.elapsed,
+ when
+ );
+
+ let level = self.level_for(when);
+
+ self.levels[level].remove_entry(when, item, store);
+ }
+
+ /// Instant at which to poll.
+ pub(crate) fn poll_at(&self) -> Option<u64> {
+ self.next_expiration().map(|expiration| expiration.deadline)
+ }
+
+ /// Advances the timer up to the instant represented by `now`.
+ pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> {
+ loop {
+ let expiration = self.next_expiration().and_then(|expiration| {
+ if expiration.deadline > now {
+ None
+ } else {
+ Some(expiration)
+ }
+ });
+
+ match expiration {
+ Some(ref expiration) => {
+ if let Some(item) = self.poll_expiration(expiration, store) {
+ return Some(item);
+ }
+
+ self.set_elapsed(expiration.deadline);
+ }
+ None => {
+ // In this case the poll did not indicate an expiration
+ // *and* we were not able to find a next expiration in
+ // the current list of timers. Advance to the poll's
+ // current time and do nothing else.
+ self.set_elapsed(now);
+ return None;
+ }
+ }
+ }
+ }
+
+ /// Returns the instant at which the next timeout expires.
+ fn next_expiration(&self) -> Option<Expiration> {
+ // Check all levels
+ for level in 0..NUM_LEVELS {
+ if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) {
+ // There cannot be any expirations at a higher level that happen
+ // before this one.
+ debug_assert!(self.no_expirations_before(level + 1, expiration.deadline));
+
+ return Some(expiration);
+ }
+ }
+
+ None
+ }
+
+ /// Used for debug assertions.
+ fn no_expirations_before(&self, start_level: usize, before: u64) -> bool {
+ let mut res = true;
+
+ for l2 in start_level..NUM_LEVELS {
+ if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) {
+ if e2.deadline < before {
+ res = false;
+ }
+ }
+ }
+
+ res
+ }
+
+ /// Iteratively finds entries that are between the wheel's current
+ /// time and the expiration time. Each such entry is either returned
+ /// for notification (when it sits in the bottom level) or moved down
+ /// to the next level (in all other cases).
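+ ///
+ /// For example, an entry inserted with `when = 130` lands in level 1,
+ /// slot 2, whose deadline is 128. When the wheel reaches 128 the entry
+ /// is moved down to level 0, slot 2, and it is finally returned once the
+ /// wheel advances to 130 itself.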
+ pub(crate) fn poll_expiration(
+ &mut self,
+ expiration: &Expiration,
+ store: &mut T::Store,
+ ) -> Option<T::Owned> {
+ while let Some(item) = self.pop_entry(expiration, store) {
+ if expiration.level == 0 {
+ debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline);
+
+ return Some(item);
+ } else {
+ let when = T::when(item.borrow(), store);
+
+ let next_level = expiration.level - 1;
+
+ self.levels[next_level].add_entry(when, item, store);
+ }
+ }
+
+ None
+ }
+
+ fn set_elapsed(&mut self, when: u64) {
+ assert!(
+ self.elapsed <= when,
+ "elapsed={:?}; when={:?}",
+ self.elapsed,
+ when
+ );
+
+ if when > self.elapsed {
+ self.elapsed = when;
+ }
+ }
+
+ fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> {
+ self.levels[expiration.level].pop_entry_slot(expiration.slot, store)
+ }
+
+ fn level_for(&self, when: u64) -> usize {
+ level_for(self.elapsed, when)
+ }
+}
+
+fn level_for(elapsed: u64, when: u64) -> usize {
+ const SLOT_MASK: u64 = (1 << 6) - 1;
+
+ // Mask in the trailing bits ignored by the level calculation in order to cap
+ // the possible leading zeros
+ let masked = elapsed ^ when | SLOT_MASK;
+
+ let leading_zeros = masked.leading_zeros() as usize;
+ let significant = 63 - leading_zeros;
+ significant / 6
+}
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_level_for() {
+ for pos in 0..64 {
+ assert_eq!(
+ 0,
+ level_for(0, pos),
+ "level_for({}) -- binary = {:b}",
+ pos,
+ pos
+ );
+ }
+
+ for level in 1..5 {
+ for pos in level..64 {
+ let a = pos * 64_usize.pow(level as u32);
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+
+ if pos > level {
+ let a = a - 1;
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+ }
+
+ if pos < 64 {
+ let a = a + 1;
+ assert_eq!(
+ level,
+ level_for(0, a as u64),
+ "level_for({}) -- binary = {:b}",
+ a,
+ a
+ );
+ }
+ }
+ }
+ }
+}
diff --git a/src/time/wheel/stack.rs b/src/time/wheel/stack.rs
new file mode 100644
index 0000000..c87adca
--- /dev/null
+++ b/src/time/wheel/stack.rs
@@ -0,0 +1,28 @@
+use std::borrow::Borrow;
+use std::cmp::Eq;
+use std::hash::Hash;
+
+/// Abstracts the stack operations needed to track timeouts.
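+// In this crate the implementor is `DelayQueue`'s `Stack<T>`, an intrusive
+// doubly-linked list whose nodes live in a `SlabStorage<T>` that is threaded
+// through every call as the `Store` parameter.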
+pub(crate) trait Stack: Default {
+ /// Type of the item stored in the stack.
+ type Owned: Borrow<Self::Borrowed>;
+
+ /// Borrowed item.
+ type Borrowed: Eq + Hash;
+
+ /// Item storage; this allows a slab to be used instead of just the heap.
+ type Store;
+
+ /// Returns `true` if the stack is empty.
+ fn is_empty(&self) -> bool;
+
+ /// Pushes an item onto the stack.
+ fn push(&mut self, item: Self::Owned, store: &mut Self::Store);
+
+ /// Pops an item from the stack.
+ fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>;
+
+ /// Removes a specific item from the stack.
+ fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);
+
+ /// Returns the deadline (in wheel milliseconds) associated with an item.
+ fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
+}
diff --git a/src/udp/frame.rs b/src/udp/frame.rs
new file mode 100644
index 0000000..d094c04
--- /dev/null
+++ b/src/udp/frame.rs
@@ -0,0 +1,249 @@
+use crate::codec::{Decoder, Encoder};
+
+use futures_core::Stream;
+use tokio::{io::ReadBuf, net::UdpSocket};
+
+use bytes::{BufMut, BytesMut};
+use futures_core::ready;
+use futures_sink::Sink;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::{
+ borrow::Borrow,
+ net::{Ipv4Addr, SocketAddr, SocketAddrV4},
+};
+use std::{io, mem::MaybeUninit};
+
+/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
+/// the `Encoder` and `Decoder` traits to encode and decode frames.
+///
+/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
+/// batch these into meaningful chunks, called "frames". This adapter layers
+/// framing on top of the socket by using the `Encoder` and `Decoder` traits to
+/// handle encoding and decoding of message frames. Note that the incoming and
+/// outgoing frame types may be distinct.
+///
+/// `UdpFramed` is a *single* object that is both [`Stream`] and [`Sink`];
+/// grouping the two into a single object is often useful for layering things
+/// which require both read and write access to the underlying object.
+///
+/// If you want to work more directly with the stream and sink, consider
+/// calling [`split`] on the `UdpFramed`, which will break them into separate
+/// objects, allowing them to interact more easily.
+///
+/// [`Stream`]: futures_core::Stream
+/// [`Sink`]: futures_sink::Sink
+/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
+#[must_use = "sinks do nothing unless polled"]
+#[derive(Debug)]
+pub struct UdpFramed<C, T = UdpSocket> {
+ socket: T,
+ codec: C,
+ rd: BytesMut,
+ wr: BytesMut,
+ out_addr: SocketAddr,
+ flushed: bool,
+ is_readable: bool,
+ current_addr: Option<SocketAddr>,
+}
+
+const INITIAL_RD_CAPACITY: usize = 64 * 1024;
+const INITIAL_WR_CAPACITY: usize = 8 * 1024;
+
+impl<C, T> Unpin for UdpFramed<C, T> {}
+
+impl<C, T> Stream for UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+ C: Decoder,
+{
+ type Item = Result<(C::Item, SocketAddr), C::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let pin = self.get_mut();
+
+ pin.rd.reserve(INITIAL_RD_CAPACITY);
+
+ loop {
+ // Are there still bytes left in the read buffer to decode?
+ if pin.is_readable {
+ if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
+ let current_addr = pin
+ .current_addr
+ .expect("will always be set before this line is called");
+
+ return Poll::Ready(Some(Ok((frame, current_addr))));
+ }
+
+ // If this point is reached, `decode_eof` returned `None`.
+ pin.is_readable = false;
+ pin.rd.clear();
+ }
+
+ // We're out of data. Try and fetch more data to decode.
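+ // `poll_recv_from` writes straight into the spare capacity of `rd`;
+ // the pointer assertion below verifies that the callee filled the
+ // buffer it was handed rather than swapping it for another one.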
+ let addr = {
+ // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
+ // transparent wrapper around `[MaybeUninit<u8>]`.
+ let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut read = ReadBuf::uninit(buf);
+ let ptr = read.filled().as_ptr();
+ let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
+
+ assert_eq!(ptr, read.filled().as_ptr());
+ let addr = res?;
+
+ // Safety: This is guaranteed to be the number of initialized (and read) bytes due
+ // to the invariants provided by `ReadBuf::filled`.
+ unsafe { pin.rd.advance_mut(read.filled().len()) };
+
+ addr
+ };
+
+ pin.current_addr = Some(addr);
+ pin.is_readable = true;
+ }
+ }
+}
+
+impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+ C: Encoder<I>,
+{
+ type Error = C::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if !self.flushed {
+ match self.poll_flush(cx)? {
+ Poll::Ready(()) => {}
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
+ let (frame, out_addr) = item;
+
+ let pin = self.get_mut();
+
+ pin.codec.encode(frame, &mut pin.wr)?;
+ pin.out_addr = out_addr;
+ pin.flushed = false;
+
+ Ok(())
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if self.flushed {
+ return Poll::Ready(Ok(()));
+ }
+
+ let Self {
+ ref socket,
+ ref mut out_addr,
+ ref mut wr,
+ ..
+ } = *self;
+
+ let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
+
+ let wrote_all = n == self.wr.len();
+ self.wr.clear();
+ self.flushed = true;
+
+ let res = if wrote_all {
+ Ok(())
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to write entire datagram to socket",
+ )
+ .into())
+ };
+
+ Poll::Ready(res)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.poll_flush(cx))?;
+ Poll::Ready(Ok(()))
+ }
+}
+
+impl<C, T> UdpFramed<C, T>
+where
+ T: Borrow<UdpSocket>,
+{
+ /// Creates a new `UdpFramed` backed by the given socket and codec.
+ ///
+ /// See struct level documentation for more details.
+ pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
+ Self {
+ socket,
+ codec,
+ out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
+ rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
+ wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
+ flushed: true,
+ is_readable: false,
+ current_addr: None,
+ }
+ }
+
+ /// Returns a reference to the underlying I/O stream wrapped by `UdpFramed`.
+ ///
+ /// # Note
+ ///
+ /// Care should be taken to not tamper with the underlying stream of data
+ /// coming in as it may corrupt the stream of frames otherwise being worked
+ /// with.
+ pub fn get_ref(&self) -> &T {
+ &self.socket
+ }
+
+ /// Returns a mutable reference to the underlying I/O stream wrapped by `UdpFramed`.
+ ///
+ /// # Note
+ ///
+ /// Care should be taken to not tamper with the underlying stream of data
+ /// coming in as it may corrupt the stream of frames otherwise being worked
+ /// with.
+ pub fn get_mut(&mut self) -> &mut T {
+ &mut self.socket
+ }
+
+ /// Returns a reference to the underlying codec wrapped by
+ /// `UdpFramed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec(&self) -> &C {
+ &self.codec
+ }
+
+ /// Returns a mutable reference to the underlying codec wrapped by
+ /// `UdpFramed`.
+ ///
+ /// Note that care should be taken to not tamper with the underlying codec
+ /// as it may corrupt the stream of frames otherwise being worked with.
+ pub fn codec_mut(&mut self) -> &mut C {
+ &mut self.codec
+ }
+
+ /// Returns a reference to the read buffer.
+ pub fn read_buffer(&self) -> &BytesMut {
+ &self.rd
+ }
+
+ /// Returns a mutable reference to the read buffer.
+ pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
+ &mut self.rd
+ }
+
+ /// Consumes the `UdpFramed`, returning its underlying I/O stream.
+ pub fn into_inner(self) -> T {
+ self.socket
+ }
+}
diff --git a/src/udp/mod.rs b/src/udp/mod.rs
new file mode 100644
index 0000000..f88ea03
--- /dev/null
+++ b/src/udp/mod.rs
@@ -0,0 +1,4 @@
+//! UDP framing
+
+mod frame;
+pub use frame::UdpFramed;
diff --git a/tests/_require_full.rs b/tests/_require_full.rs
new file mode 100644
index 0000000..045934d
--- /dev/null
+++ b/tests/_require_full.rs
@@ -0,0 +1,2 @@
+#![cfg(not(feature = "full"))]
+compile_error!("run tokio-util tests with `--features full`");
diff --git a/tests/codecs.rs b/tests/codecs.rs
new file mode 100644
index 0000000..f9a7801
--- /dev/null
+++ b/tests/codecs.rs
@@ -0,0 +1,464 @@
+#![warn(rust_2018_idioms)]
+
+use tokio_util::codec::{AnyDelimiterCodec, BytesCodec, Decoder, Encoder, LinesCodec};
+
+use bytes::{BufMut, Bytes, BytesMut};
+
+#[test]
+fn bytes_decoder() {
+ let mut codec = BytesCodec::new();
+ let buf = &mut BytesMut::new();
+ buf.put_slice(b"abc");
+ assert_eq!("abc", codec.decode(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ buf.put_slice(b"a");
+ assert_eq!("a", codec.decode(buf).unwrap().unwrap());
+}
+
+#[test]
+fn bytes_encoder() {
+ let mut codec = BytesCodec::new();
+
+ // Default capacity of BytesMut
+ #[cfg(target_pointer_width = "64")]
+ const INLINE_CAP: usize = 4 * 8 - 1;
+ #[cfg(target_pointer_width = "32")]
+ const INLINE_CAP: usize = 4 * 4 - 1;
+
+ let mut buf = BytesMut::new();
+ codec
+ .encode(Bytes::from_static(&[0; INLINE_CAP + 1]), &mut buf)
+ .unwrap();
+
+ // Default capacity of Framed Read
+ const INITIAL_CAPACITY: usize = 8 * 1024;
+
+ let mut buf = BytesMut::with_capacity(INITIAL_CAPACITY);
+ codec
+ .encode(Bytes::from_static(&[0; INITIAL_CAPACITY + 1]), &mut buf)
+ .unwrap();
+ codec
+ .encode(BytesMut::from(&b"hello"[..]), &mut buf)
+ .unwrap();
+}
+
+#[test]
+fn lines_decoder() {
+ let mut codec = LinesCodec::new();
+ let buf = &mut BytesMut::new();
+ buf.reserve(200);
+ buf.put_slice(b"line 1\nline 2\r\nline 3\n\r\n\r");
+ assert_eq!("line 1", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("line 2", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("line 3", codec.decode(buf).unwrap().unwrap());
+ assert_eq!("", codec.decode(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+ buf.put_slice(b"k");
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!("\rk", codec.decode_eof(buf).unwrap().unwrap());
+ assert_eq!(None, codec.decode(buf).unwrap());
+ assert_eq!(None, codec.decode_eof(buf).unwrap());
+}
+
+#[test]
+fn lines_decoder_max_length() {
+ const MAX_LENGTH: usize = 6;
+
+ let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH);
+ let buf = &mut BytesMut::new();
+
+ buf.reserve(200);
+ buf.put_slice(b"line 1 is too long\nline 2\nline 3\r\nline 4\n\r\n\r");
+
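+ // The first line exceeds MAX_LENGTH, so decoding it must fail; the codec
+ // then discards input up to the next newline and resumes with "line 2".
+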
assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 2", line); + + assert!(codec.decode(buf).is_err()); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("line 4", line); + + let line = codec.decode(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let line = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + line.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + line, + MAX_LENGTH + ); + assert_eq!("\rk", line); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + + // Line that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. + buf.put_slice(b"aaabbbc"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_underrun() { + const MAX_LENGTH: usize = 6; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"ong\n"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"line 2"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"\n"); + assert_eq!("line 2", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn lines_decoder_max_length_bursts() { + const MAX_LENGTH: usize = 10; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_big_burst() { + const MAX_LENGTH: usize = 10; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too long!\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn lines_decoder_max_length_newline_between_decodes() { + const MAX_LENGTH: usize = 5; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"hello"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"\nworld"); + assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); +} + +// Regression test for [infinite loop bug](https://github.com/tokio-rs/tokio/issues/1483) +#[test] +fn lines_decoder_discard_repeat() { + const MAX_LENGTH: usize = 1; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert_eq!(None, codec.decode(buf).unwrap()); +} + +// Regression test for [subsequent calls to LinesCodec decode does not return the desired results 
bug](https://github.com/tokio-rs/tokio/issues/3555) +#[test] +fn lines_decoder_max_length_underrun_twice() { + const MAX_LENGTH: usize = 11; + + let mut codec = LinesCodec::new_with_max_length(MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"line "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too very l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\nshort\n"); + assert_eq!("short", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn lines_encoder() { + let mut codec = LinesCodec::new(); + let mut buf = BytesMut::new(); + + codec.encode("line 1", &mut buf).unwrap(); + assert_eq!("line 1\n", buf); + + codec.encode("line 2", &mut buf).unwrap(); + assert_eq!("line 1\nline 2\n", buf); +} + +#[test] +fn any_delimiters_decoder_any_character() { + let mut codec = AnyDelimiterCodec::new(b",;\n\r".to_vec(), b",".to_vec()); + let buf = &mut BytesMut::new(); + buf.reserve(200); + buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r"); + assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap()); + assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); + assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap()); + assert_eq!("", codec.decode(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!("k", codec.decode_eof(buf).unwrap().unwrap()); + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); +} + +#[test] +fn any_delimiters_decoder_max_length() { + const MAX_LENGTH: usize = 7; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk 1 is too long\nchunk 2\nchunk 3\r\nchunk 4\n\r\n"); + + assert!(codec.decode(buf).is_err()); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 2", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 3", chunk); + + // \r\n cause empty chunk + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("chunk 4", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + let chunk = codec.decode(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("", chunk); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, codec.decode_eof(buf).unwrap()); + buf.put_slice(b"k"); + assert_eq!(None, codec.decode(buf).unwrap()); + + let chunk = codec.decode_eof(buf).unwrap().unwrap(); + assert!( + chunk.len() <= MAX_LENGTH, + "{:?}.len() <= {:?}", + chunk, + MAX_LENGTH + ); + assert_eq!("k", chunk); + + assert_eq!(None, codec.decode(buf).unwrap()); + assert_eq!(None, 
codec.decode_eof(buf).unwrap()); + + // Delimiter that's one character too long. This could cause an out of bounds + // error if we peek at the next characters using slice indexing. + buf.put_slice(b"aaabbbcc"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_underrun() { + const MAX_LENGTH: usize = 7; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"ong\n"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b"chunk 2"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b","); + assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn any_delimiter_decoder_max_length_underrun_twice() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too very l"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"aaaaaaaaaaaaaaaaaaaaaaa"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\nshort\n"); + assert_eq!("short", codec.decode(buf).unwrap().unwrap()); +} +#[test] +fn any_delimiter_decoder_max_length_bursts() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too l"); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"ong\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_big_burst() { + const MAX_LENGTH: usize = 11; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"chunk "); + assert_eq!(None, codec.decode(buf).unwrap()); + buf.put_slice(b"too long!\n"); + assert!(codec.decode(buf).is_err()); +} + +#[test] +fn any_delimiter_decoder_max_length_delimiter_between_decodes() { + const MAX_LENGTH: usize = 5; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"hello"); + assert_eq!(None, codec.decode(buf).unwrap()); + + buf.put_slice(b",world"); + assert_eq!("hello", codec.decode(buf).unwrap().unwrap()); +} + +#[test] +fn any_delimiter_decoder_discard_repeat() { + const MAX_LENGTH: usize = 1; + + let mut codec = + AnyDelimiterCodec::new_with_max_length(b",;\n\r".to_vec(), b",".to_vec(), MAX_LENGTH); + let buf = &mut BytesMut::new(); + + buf.reserve(200); + buf.put_slice(b"aa"); + assert!(codec.decode(buf).is_err()); + buf.put_slice(b"a"); + assert_eq!(None, codec.decode(buf).unwrap()); +} + +#[test] +fn any_delimiter_encoder() { + let mut codec = AnyDelimiterCodec::new(b",".to_vec(), b";--;".to_vec()); + let mut buf = BytesMut::new(); + + codec.encode("chunk 1", &mut buf).unwrap(); + assert_eq!("chunk 1;--;", buf); + + codec.encode("chunk 2", &mut buf).unwrap(); + assert_eq!("chunk 1;--;chunk 2;--;", buf); +} diff --git 
a/tests/context.rs b/tests/context.rs new file mode 100644 index 0000000..2ec8258 --- /dev/null +++ b/tests/context.rs @@ -0,0 +1,25 @@ +#![cfg(feature = "rt")] +#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads +#![warn(rust_2018_idioms)] + +use tokio::runtime::Builder; +use tokio::time::*; +use tokio_util::context::RuntimeExt; + +#[test] +fn tokio_context_with_another_runtime() { + let rt1 = Builder::new_multi_thread() + .worker_threads(1) + // no timer! + .build() + .unwrap(); + let rt2 = Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + // Without the `HandleExt.wrap()` there would be a panic because there is + // no timer running, since it would be referencing runtime r1. + rt1.block_on(rt2.wrap(async move { sleep(Duration::from_millis(2)).await })); +} diff --git a/tests/framed.rs b/tests/framed.rs new file mode 100644 index 0000000..ec8cdf0 --- /dev/null +++ b/tests/framed.rs @@ -0,0 +1,152 @@ +#![warn(rust_2018_idioms)] + +use tokio_stream::StreamExt; +use tokio_test::assert_ok; +use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; + +use bytes::{Buf, BufMut, BytesMut}; +use std::io::{self, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +const INITIAL_CAPACITY: usize = 8 * 1024; + +/// Encode and decode u32 values. +#[derive(Default)] +struct U32Codec { + read_bytes: usize, +} + +impl Decoder for U32Codec { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + self.read_bytes += 4; + Ok(Some(n)) + } +} + +impl Encoder for U32Codec { + type Error = io::Error; + + fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(4); + dst.put_u32(item); + Ok(()) + } +} + +/// Encode and decode u64 values. 
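+// Used by `can_read_from_existing_buf_after_codec_changed` below: the framed
+// stream starts out decoding u32 frames and is switched to u64 frames
+// mid-stream via `map_codec`, carrying `read_bytes` across the swap.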
+#[derive(Default)] +struct U64Codec { + read_bytes: usize, +} + +impl Decoder for U64Codec { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + self.read_bytes += 8; + Ok(Some(n)) + } +} + +impl Encoder for U64Codec { + type Error = io::Error; + + fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> { + // Reserve space + dst.reserve(8); + dst.put_u64(item); + Ok(()) + } +} + +/// This value should never be used +struct DontReadIntoThis; + +impl Read for DontReadIntoThis { + fn read(&mut self, _: &mut [u8]) -> io::Result { + Err(io::Error::new( + io::ErrorKind::Other, + "Read into something you weren't supposed to.", + )) + } +} + +impl tokio::io::AsyncRead for DontReadIntoThis { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + unreachable!() + } +} + +#[tokio::test] +async fn can_read_from_existing_buf() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); +} + +#[tokio::test] +async fn can_read_from_existing_buf_after_codec_changed() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84][..]); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); + + let mut framed = framed.map_codec(|codec| U64Codec { + read_bytes: codec.read_bytes, + }); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 84); + assert_eq!(framed.codec().read_bytes, 12); +} + +#[test] +fn external_buf_grows_to_init() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&[0, 0, 0, 42][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY); +} + +#[test] +fn external_buf_does_not_shrink() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from(&vec![0; INITIAL_CAPACITY * 2][..]); + + let framed = Framed::from_parts(parts); + let FramedParts { read_buf, .. } = framed.into_parts(); + + assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); +} diff --git a/tests/framed_read.rs b/tests/framed_read.rs new file mode 100644 index 0000000..2a9e27e --- /dev/null +++ b/tests/framed_read.rs @@ -0,0 +1,339 @@ +#![warn(rust_2018_idioms)] + +use tokio::io::{AsyncRead, ReadBuf}; +use tokio_test::assert_ready; +use tokio_test::task; +use tokio_util::codec::{Decoder, FramedRead}; + +use bytes::{Buf, BytesMut}; +use futures::Stream; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +macro_rules! mock { + ($($x:expr,)*) => {{ + let mut v = VecDeque::new(); + v.extend(vec![$($x),*]); + Mock { calls: v } + }}; +} + +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +macro_rules! 
pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + +struct U32Decoder; + +impl Decoder for U32Decoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 4 { + return Ok(None); + } + + let n = buf.split_to(4).get_u32(); + Ok(Some(n)) + } +} + +struct U64Decoder; + +impl Decoder for U64Decoder { + type Item = u64; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 8 { + return Ok(None); + } + + let n = buf.split_to(8).get_u64(); + Ok(Some(n)) + } +} + +#[test] +fn read_multi_frame_in_packet() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_across_packets() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + Ok(b"\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_multi_frame_in_packet_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0x04); + + let mut framed = framed.map_decoder(|_| U64Decoder); + assert_read!(pin!(framed).poll_next(cx), 0x08); + + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00".to_vec()), + Ok(b"\x00\x00\x00\x01".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_partial_then_not_ready() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_err() { + let mut task = task::spawn(()); + let mock = mock! 
{ + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn read_partial_would_block_then_err() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00".to_vec()), + Err(io::Error::new(io::ErrorKind::WouldBlock, "")), + Err(io::Error::new(io::ErrorKind::Other, "")), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); +} + +#[test] +fn huge_size() { + let mut task = task::spawn(()); + let data = &[0; 32 * 1024][..]; + let mut framed = FramedRead::new(data, BigDecoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); + + struct BigDecoder; + + impl Decoder for BigDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> io::Result> { + if buf.len() < 32 * 1024 { + return Ok(None); + } + buf.advance(32 * 1024); + Ok(Some(0)) + } + } +} + +#[test] +fn data_remaining_is_error() { + let mut task = task::spawn(()); + let slice = &[0; 5][..]; + let mut framed = FramedRead::new(slice, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err()); + }); +} + +#[test] +fn multi_frames_on_eof() { + let mut task = task::spawn(()); + struct MyDecoder(Vec); + + impl Decoder for MyDecoder { + type Item = u32; + type Error = io::Error; + + fn decode(&mut self, _buf: &mut BytesMut) -> io::Result> { + unreachable!(); + } + + fn decode_eof(&mut self, _buf: &mut BytesMut) -> io::Result> { + if self.0.is_empty() { + return Ok(None); + } + + Ok(Some(self.0.remove(0))) + } + } + + let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3])); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + +#[test] +fn read_eof_then_resume() { + let mut task = task::spawn(()); + let mock = mock! 
{
+        Ok(b"\x00\x00\x00\x01".to_vec()),
+        Ok(b"".to_vec()),
+        Ok(b"\x00\x00\x00\x02".to_vec()),
+        Ok(b"".to_vec()),
+        Ok(b"\x00\x00\x00\x03".to_vec()),
+    };
+    let mut framed = FramedRead::new(mock, U32Decoder);
+
+    task.enter(|cx, _| {
+        assert_read!(pin!(framed).poll_next(cx), 1);
+        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+        assert_read!(pin!(framed).poll_next(cx), 2);
+        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+        assert_read!(pin!(framed).poll_next(cx), 3);
+        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+        assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none());
+    });
+}
+
+// ===== Mock ======
+
+struct Mock {
+    calls: VecDeque<io::Result<Vec<u8>>>,
+}
+
+impl AsyncRead for Mock {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        use io::ErrorKind::WouldBlock;
+
+        match self.calls.pop_front() {
+            Some(Ok(data)) => {
+                debug_assert!(buf.remaining() >= data.len());
+                buf.put_slice(&data);
+                Ready(Ok(()))
+            }
+            Some(Err(ref e)) if e.kind() == WouldBlock => Pending,
+            Some(Err(e)) => Ready(Err(e)),
+            None => Ready(Ok(())),
+        }
+    }
+}
diff --git a/tests/framed_stream.rs b/tests/framed_stream.rs
new file mode 100644
index 0000000..76d8af7
--- /dev/null
+++ b/tests/framed_stream.rs
@@ -0,0 +1,38 @@
+use futures_core::stream::Stream;
+use std::{io, pin::Pin};
+use tokio_test::{assert_ready, io::Builder, task};
+use tokio_util::codec::{BytesCodec, FramedRead};
+
+macro_rules! pin {
+    ($id:ident) => {
+        Pin::new(&mut $id)
+    };
+}
+
+macro_rules! assert_read {
+    ($e:expr, $n:expr) => {{
+        let val = assert_ready!($e);
+        assert_eq!(val.unwrap().unwrap(), $n);
+    }};
+}
+
+#[tokio::test]
+async fn return_none_after_error() {
+    let mut io = FramedRead::new(
+        Builder::new()
+            .read(b"abcdef")
+            .read_error(io::Error::new(io::ErrorKind::Other, "Resource errored out"))
+            .read(b"more data")
+            .build(),
+        BytesCodec::new(),
+    );
+
+    let mut task = task::spawn(());
+
+    task.enter(|cx, _| {
+        assert_read!(pin!(io).poll_next(cx), b"abcdef".to_vec());
+        assert!(assert_ready!(pin!(io).poll_next(cx)).unwrap().is_err());
+        assert!(assert_ready!(pin!(io).poll_next(cx)).is_none());
+        assert_read!(pin!(io).poll_next(cx), b"more data".to_vec());
+    })
+}
diff --git a/tests/framed_write.rs b/tests/framed_write.rs
new file mode 100644
index 0000000..39091c0
--- /dev/null
+++ b/tests/framed_write.rs
@@ -0,0 +1,212 @@
+#![warn(rust_2018_idioms)]
+
+use tokio::io::AsyncWrite;
+use tokio_test::{assert_ready, task};
+use tokio_util::codec::{Encoder, FramedWrite};
+
+use bytes::{BufMut, BytesMut};
+use futures_sink::Sink;
+use std::collections::VecDeque;
+use std::io::{self, Write};
+use std::pin::Pin;
+use std::task::Poll::{Pending, Ready};
+use std::task::{Context, Poll};
+
+macro_rules! mock {
+    ($($x:expr,)*) => {{
+        let mut v = VecDeque::new();
+        v.extend(vec![$($x),*]);
+        Mock { calls: v }
+    }};
+}
+
+macro_rules! pin {
+    ($id:ident) => {
+        Pin::new(&mut $id)
+    };
+}
+
+struct U32Encoder;
+
+impl Encoder<u32> for U32Encoder {
+    type Error = io::Error;
+
+    fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
+        // Reserve space
+        dst.reserve(4);
+        dst.put_u32(item);
+        Ok(())
+    }
+}
+
+struct U64Encoder;
+
+impl Encoder<u64> for U64Encoder {
+    type Error = io::Error;
+
+    fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
+        // Reserve space
+        dst.reserve(8);
+        dst.put_u64(item);
+        Ok(())
+    }
+}
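+
+// A minimal sketch (not part of the upstream tests) of how `U32Encoder`
+// drives `FramedWrite` outside the poll-based harness below. It assumes the
+// `futures` dev-dependency for `SinkExt` and relies on `Vec<u8>: AsyncWrite`.
+#[allow(dead_code)]
+async fn u32_encoder_sketch() -> io::Result<()> {
+    use futures::SinkExt;
+
+    let mut framed = FramedWrite::new(Vec::new(), U32Encoder);
+    // `send` runs `U32Encoder::encode` and then flushes the big-endian bytes.
+    framed.send(1u32).await?;
+    assert_eq!(framed.get_ref().as_slice(), &b"\x00\x00\x00\x01"[..]);
+    Ok(())
+}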
{ + Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(1).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(2).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_multi_frame_after_codec_changed() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x04).is_ok()); + + let mut framed = framed.map_encoder(|_| U64Encoder); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x08).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + +#[test] +fn write_hits_backpressure() { + const ITER: usize = 2 * 1024; + + let mut mock = mock! { + // Block the `ITER*2`th write + Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")), + Ok(b"".to_vec()), + }; + + for i in 0..=ITER * 2 { + let mut b = BytesMut::with_capacity(4); + b.put_u32(i as u32); + + // Append to the end + match mock.calls.back_mut().unwrap() { + Ok(ref mut data) => { + // Write in 2kb chunks + if data.len() < ITER { + data.extend_from_slice(&b[..]); + continue; + } // else fall through and create a new buffer + } + _ => unreachable!(), + } + + // Push a new chunk + mock.calls.push_back(Ok(b[..].to_vec())); + } + // 1 'wouldblock', 8 * 2KB buffers, 1 b-byte buffer + assert_eq!(mock.calls.len(), 10); + + let mut task = task::spawn(()); + let mut framed = FramedWrite::new(mock, U32Encoder); + framed.set_backpressure_boundary(ITER * 8); + task.enter(|cx, _| { + // Send 16KB. This fills up FramedWrite buffer + for i in 0..ITER * 2 { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(i as u32).is_ok()); + } + + // Now we poll_ready which forces a flush. The mock pops the front message + // and decides to block. + assert!(pin!(framed).poll_ready(cx).is_pending()); + + // We poll again, forcing another flush, which this time succeeds + // The whole 16KB buffer is flushed + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + + // Send more data. 
+        assert!(pin!(framed).start_send((ITER * 2) as u32).is_ok());
+
+        // Flush the rest of the buffer
+        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());
+
+        // Ensure the mock is empty
+        assert_eq!(0, framed.get_ref().calls.len());
+    })
+}
+
+// ===== Mock ======
+
+struct Mock {
+    calls: VecDeque<io::Result<Vec<u8>>>,
+}
+
+impl Write for Mock {
+    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
+        match self.calls.pop_front() {
+            Some(Ok(data)) => {
+                assert!(src.len() >= data.len());
+                assert_eq!(&data[..], &src[..data.len()]);
+                Ok(data.len())
+            }
+            Some(Err(e)) => Err(e),
+            None => panic!("unexpected write; {:?}", src),
+        }
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+}
+
+impl AsyncWrite for Mock {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match Pin::get_mut(self).write(buf) {
+            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
+            other => Ready(other),
+        }
+    }
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match Pin::get_mut(self).flush() {
+            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
+            other => Ready(other),
+        }
+    }
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        unimplemented!()
+    }
+}
diff --git a/tests/io_inspect.rs b/tests/io_inspect.rs
new file mode 100644
index 0000000..e6319af
--- /dev/null
+++ b/tests/io_inspect.rs
@@ -0,0 +1,194 @@
+use futures::future::poll_fn;
+use std::{
+    io::IoSlice,
+    pin::Pin,
+    task::{Context, Poll},
+};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
+use tokio_util::io::{InspectReader, InspectWriter};
+
+/// An AsyncRead implementation that works byte-by-byte, to catch out callers
+/// who don't allow for `buf` being part-filled before the call
+struct SmallReader {
+    contents: Vec<u8>,
+}
+
+impl Unpin for SmallReader {}
+
+impl AsyncRead for SmallReader {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        if let Some(byte) = self.contents.pop() {
+            buf.put_slice(&[byte])
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test]
+async fn read_tee() {
+    let contents = b"This could be really long, you know".to_vec();
+    let reader = SmallReader {
+        contents: contents.clone(),
+    };
+    let mut altout: Vec<u8> = Vec::new();
+    let mut teeout = Vec::new();
+    {
+        let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
+        tee.read_to_end(&mut teeout).await.unwrap();
+    }
+    assert_eq!(teeout, altout);
+    assert_eq!(altout.len(), contents.len());
+}
+
+/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
+/// that reads the whole of the first buffer plus one byte from the second in
+/// poll_write_vectored.
+///
+/// This is designed to catch bugs in handling partially written buffers
+#[derive(Debug)]
+struct SmallWriter {
+    contents: Vec<u8>,
+}
+
+impl Unpin for SmallWriter {}
+
+impl AsyncWrite for SmallWriter {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        // Just write one byte at a time
+        if buf.is_empty() {
+            return Poll::Ready(Ok(0));
+        }
+        self.contents.push(buf[0]);
+        Poll::Ready(Ok(1))
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Result<(), std::io::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_write_vectored(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        bufs: &[IoSlice<'_>],
+    ) -> Poll<Result<usize, std::io::Error>> {
+        // Write all of the first buffer, then one byte from the second buffer
+        // This should trip up anything that doesn't correctly handle multiple
+        // buffers.
+        if bufs.is_empty() {
+            return Poll::Ready(Ok(0));
+        }
+        let mut written_len = bufs[0].len();
+        self.contents.extend_from_slice(&bufs[0]);
+
+        if bufs.len() > 1 {
+            let buf = bufs[1];
+            if !buf.is_empty() {
+                written_len += 1;
+                self.contents.push(buf[0]);
+            }
+        }
+        Poll::Ready(Ok(written_len))
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        true
+    }
+}
+
+#[tokio::test]
+async fn write_tee() {
+    let mut altout: Vec<u8> = Vec::new();
+    let mut writeout = SmallWriter {
+        contents: Vec::new(),
+    };
+    {
+        let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
+        tee.write_all(b"A testing string, very testing")
+            .await
+            .unwrap();
+    }
+    assert_eq!(altout, writeout.contents);
+}
+
+// This is inefficient, but works well enough for test use.
+// If you want something similar for real code, you'll want to avoid all the
+// fun of manipulating `bufs` - ideally, by the time you read this,
+// IoSlice::advance_slices will be stable, and you can use that.
+async fn write_all_vectored<W: AsyncWrite + Unpin>(
+    mut writer: W,
+    mut bufs: Vec<Vec<u8>>,
+) -> Result<usize, std::io::Error> {
+    let mut res = 0;
+    while !bufs.is_empty() {
+        let mut written = poll_fn(|cx| {
+            let bufs: Vec<IoSlice<'_>> = bufs.iter().map(|v| IoSlice::new(v)).collect();
+            Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
+        })
+        .await?;
+        res += written;
+        while written > 0 {
+            let buf_len = bufs[0].len();
+            if buf_len <= written {
+                bufs.remove(0);
+                written -= buf_len;
+            } else {
+                let buf = &mut bufs[0];
+                let drain_len = written.min(buf.len());
+                buf.drain(..drain_len);
+                written -= drain_len;
+            }
+        }
+    }
+    Ok(res)
+}
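+
+// A sketch (not upstream code) of the same helper written against
+// `IoSlice::advance_slices`, which has since been stabilized; the function
+// name and signature are illustrative assumptions, not part of this test
+// suite.
+#[allow(dead_code)]
+async fn write_all_vectored_advance<W: AsyncWrite + Unpin>(
+    mut writer: W,
+    mut bufs: &mut [IoSlice<'_>],
+) -> Result<usize, std::io::Error> {
+    let mut res = 0;
+    while !bufs.is_empty() {
+        let written = poll_fn(|cx| Pin::new(&mut writer).poll_write_vectored(cx, bufs)).await?;
+        res += written;
+        // Drops fully written slices and advances a partially written one in
+        // place (panics if `written` exceeds the bytes remaining).
+        IoSlice::advance_slices(&mut bufs, written);
+    }
+    Ok(res)
+}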
+
+#[tokio::test]
+async fn write_tee_vectored() {
+    let mut altout: Vec<u8> = Vec::new();
+    let mut writeout = SmallWriter {
+        contents: Vec::new(),
+    };
+    let original = b"A very long string split up";
+    let bufs: Vec<Vec<u8>> = original
+        .split(|b| b.is_ascii_whitespace())
+        .map(Vec::from)
+        .collect();
+    assert!(bufs.len() > 1);
+    let expected: Vec<u8> = {
+        let mut out = Vec::new();
+        for item in &bufs {
+            out.extend_from_slice(item)
+        }
+        out
+    };
+    {
+        let mut bufcount = 0;
+        let tee = InspectWriter::new(&mut writeout, |bytes| {
+            bufcount += 1;
+            altout.extend(bytes)
+        });
+
+        assert!(tee.is_write_vectored());
+
+        write_all_vectored(tee, bufs.clone()).await.unwrap();
+
+        assert!(bufcount >= bufs.len());
+    }
+    assert_eq!(altout, writeout.contents);
+    assert_eq!(writeout.contents, expected);
+}
diff --git a/tests/io_reader_stream.rs b/tests/io_reader_stream.rs
new file mode 100644
index 0000000..e30cd85
--- /dev/null
+++ b/tests/io_reader_stream.rs
@@ -0,0 +1,65 @@
+#![warn(rust_2018_idioms)]
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, ReadBuf};
+use tokio_stream::StreamExt;
+
+/// Produces at most `remaining` zeros, then returns an error.
+/// Each read yields at most 31 bytes.
+struct Reader {
+    remaining: usize,
+}
+
+impl AsyncRead for Reader {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let this = Pin::into_inner(self);
+        assert_ne!(buf.remaining(), 0);
+        if this.remaining > 0 {
+            let n = std::cmp::min(this.remaining, buf.remaining());
+            let n = std::cmp::min(n, 31);
+            for x in &mut buf.initialize_unfilled_to(n)[..n] {
+                *x = 0;
+            }
+            buf.advance(n);
+            this.remaining -= n;
+            Poll::Ready(Ok(()))
+        } else {
+            Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
+        }
+    }
+}
+
+#[tokio::test]
+async fn correct_behavior_on_errors() {
+    let reader = Reader { remaining: 8000 };
+    let mut stream = tokio_util::io::ReaderStream::new(reader);
+    let mut zeros_received = 0;
+    let mut had_error = false;
+    loop {
+        let item = stream.next().await.unwrap();
+        println!("{:?}", item);
+        match item {
+            Ok(bytes) => {
+                let bytes = &*bytes;
+                for byte in bytes {
+                    assert_eq!(*byte, 0);
+                    zeros_received += 1;
+                }
+            }
+            Err(_) => {
+                assert!(!had_error);
+                had_error = true;
+                break;
+            }
+        }
+    }
+
+    assert!(had_error);
+    assert_eq!(zeros_received, 8000);
+    assert!(stream.next().await.is_none());
+}
diff --git a/tests/io_sink_writer.rs b/tests/io_sink_writer.rs
new file mode 100644
index 0000000..e76870b
--- /dev/null
+++ b/tests/io_sink_writer.rs
@@ -0,0 +1,72 @@
+#![warn(rust_2018_idioms)]
+
+use bytes::Bytes;
+use futures_util::SinkExt;
+use std::io::{self, Error, ErrorKind};
+use tokio::io::AsyncWriteExt;
+use tokio_util::codec::{Encoder, FramedWrite};
+use tokio_util::io::{CopyToBytes, SinkWriter};
+use tokio_util::sync::PollSender;
+
+#[tokio::test]
+async fn test_copied_sink_writer() -> Result<(), Error> {
+    // Construct a channel pair to send data across and wrap a pollable sink.
+    // Note that the sink must mimic a writable object, e.g. have `std::io::Error`
+    // as its error type.
+    // As `PollSender` requires an owned copy of the buffer, we wrap it additionally
+    // with a `CopyToBytes` helper.
+    let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
+    let mut writer = SinkWriter::new(CopyToBytes::new(
+        PollSender::new(tx).sink_map_err(|_| io::Error::from(ErrorKind::BrokenPipe)),
+    ));
+
+    // Write data to our interface...
+    let data: [u8; 4] = [1, 2, 3, 4];
+    let _ = writer.write(&data).await;
+
+    // ... and receive it.
+    assert_eq!(data.to_vec(), rx.recv().await.unwrap().to_vec());
+
+    Ok(())
+}
+
+/// A trivial encoder.
+struct SliceEncoder;
+
+impl SliceEncoder {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+impl<'a> Encoder<&'a [u8]> for SliceEncoder {
+    type Error = Error;
+
+    fn encode(&mut self, item: &'a [u8], dst: &mut bytes::BytesMut) -> Result<(), Self::Error> {
+        // This is where we'd write packet headers, lengths, etc. in a real encoder.
+        // For simplicity and demonstration purposes, we just pack a copy of
+        // the slice at the end of a buffer.
+        dst.extend_from_slice(item);
+        Ok(())
+    }
+}
+
+#[tokio::test]
+async fn test_direct_sink_writer() -> Result<(), Error> {
+    // We define a framed writer which accepts byte slices
+    // and 'reverse' this construction immediately.
+    let framed_byte_lc = FramedWrite::new(Vec::new(), SliceEncoder::new());
+    let mut writer = SinkWriter::new(framed_byte_lc);
+
+    // Write multiple slices to the sink...
+    let _ = writer.write(&[1, 2, 3]).await;
+    let _ = writer.write(&[4, 5, 6]).await;
+
+    // ... and compare it with the buffer.
+    assert_eq!(
+        writer.into_inner().write_buffer().to_vec().as_slice(),
+        &[1, 2, 3, 4, 5, 6]
+    );
+
+    Ok(())
+}
diff --git a/tests/io_stream_reader.rs b/tests/io_stream_reader.rs
new file mode 100644
index 0000000..5975994
--- /dev/null
+++ b/tests/io_stream_reader.rs
@@ -0,0 +1,35 @@
+#![warn(rust_2018_idioms)]
+
+use bytes::Bytes;
+use tokio::io::AsyncReadExt;
+use tokio_stream::iter;
+use tokio_util::io::StreamReader;
+
+#[tokio::test]
+async fn test_stream_reader() -> std::io::Result<()> {
+    let stream = iter(vec![
+        std::io::Result::Ok(Bytes::from_static(&[])),
+        Ok(Bytes::from_static(&[0, 1, 2, 3])),
+        Ok(Bytes::from_static(&[])),
+        Ok(Bytes::from_static(&[4, 5, 6, 7])),
+        Ok(Bytes::from_static(&[])),
+        Ok(Bytes::from_static(&[8, 9, 10, 11])),
+        Ok(Bytes::from_static(&[])),
+    ]);
+
+    let mut read = StreamReader::new(stream);
+
+    let mut buf = [0; 5];
+    read.read_exact(&mut buf).await?;
+    assert_eq!(buf, [0, 1, 2, 3, 4]);
+
+    assert_eq!(read.read(&mut buf).await?, 3);
+    assert_eq!(&buf[..3], [5, 6, 7]);
+
+    assert_eq!(read.read(&mut buf).await?, 4);
+    assert_eq!(&buf[..4], [8, 9, 10, 11]);
+
+    assert_eq!(read.read(&mut buf).await?, 0);
+
+    Ok(())
+}
diff --git a/tests/io_sync_bridge.rs b/tests/io_sync_bridge.rs
new file mode 100644
index 0000000..76bbd0b
--- /dev/null
+++ b/tests/io_sync_bridge.rs
@@ -0,0 +1,62 @@
+#![cfg(feature = "io-util")]
+#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
+
+use std::error::Error;
+use std::io::{Cursor, Read, Result as IoResult, Write};
+use tokio::io::{AsyncRead, AsyncReadExt};
+use tokio_util::io::SyncIoBridge;
+
+async fn test_reader_len(
+    r: impl AsyncRead + Unpin + Send + 'static,
+    expected_len: usize,
+) -> IoResult<()> {
+    let mut r = SyncIoBridge::new(r);
+    let res = tokio::task::spawn_blocking(move || {
+        let mut buf = Vec::new();
+        r.read_to_end(&mut buf)?;
+        Ok::<_, std::io::Error>(buf)
+    })
+    .await?;
+    assert_eq!(res?.len(), expected_len);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_async_read_to_sync() -> Result<(), Box<dyn Error>> {
+    test_reader_len(tokio::io::empty(), 0).await?;
+    let buf = b"hello world";
+    test_reader_len(Cursor::new(buf), buf.len()).await?;
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_async_write_to_sync() -> Result<(), Box<dyn Error>> {
+    let mut dest = Vec::new();
+    let src = b"hello world";
+    let dest = tokio::task::spawn_blocking(move || -> Result<_, String> {
+        let mut w = SyncIoBridge::new(Cursor::new(&mut dest));
+        std::io::copy(&mut Cursor::new(src), &mut w).map_err(|e| e.to_string())?;
+        Ok(dest)
+    })
+    .await??;
+    assert_eq!(dest.as_slice(), src);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_shutdown() -> Result<(), Box<dyn Error>> {
+    let (s1, mut s2) = tokio::io::duplex(1024);
+    let (_rh, wh) = tokio::io::split(s1);
+    tokio::task::spawn_blocking(move || -> std::io::Result<_> {
+        let mut wh = SyncIoBridge::new(wh);
+        wh.write_all(b"hello")?;
+        wh.shutdown()?;
+        assert!(wh.write_all(b" world").is_err());
+        Ok(())
+    })
+    .await??;
+    let mut buf = vec![];
+    s2.read_to_end(&mut buf).await?;
+    assert_eq!(buf, b"hello");
+    Ok(())
+}
diff --git a/tests/length_delimited.rs b/tests/length_delimited.rs
new file mode 100644
index 0000000..126e41b
--- /dev/null
+++ b/tests/length_delimited.rs
@@ -0,0 +1,779 @@
+#![warn(rust_2018_idioms)]
+
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_test::task;
+use tokio_test::{
+    assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
+};
+use tokio_util::codec::*;
+
+use bytes::{BufMut, Bytes, BytesMut};
+use futures::{pin_mut, Sink,
Stream};
+use std::collections::VecDeque;
+use std::io;
+use std::pin::Pin;
+use std::task::Poll::*;
+use std::task::{Context, Poll};
+
+macro_rules! mock {
+    ($($x:expr,)*) => {{
+        let mut v = VecDeque::new();
+        v.extend(vec![$($x),*]);
+        Mock { calls: v }
+    }};
+}
+
+macro_rules! assert_next_eq {
+    ($io:ident, $expect:expr) => {{
+        task::spawn(()).enter(|cx, _| {
+            let res = assert_ready!($io.as_mut().poll_next(cx));
+            match res {
+                Some(Ok(v)) => assert_eq!(v, $expect.as_ref()),
+                Some(Err(e)) => panic!("error = {:?}", e),
+                None => panic!("none"),
+            }
+        });
+    }};
+}
+
+macro_rules! assert_next_pending {
+    ($io:ident) => {{
+        task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) {
+            Ready(Some(Ok(v))) => panic!("value = {:?}", v),
+            Ready(Some(Err(e))) => panic!("error = {:?}", e),
+            Ready(None) => panic!("done"),
+            Pending => {}
+        });
+    }};
+}
+
+macro_rules! assert_next_err {
+    ($io:ident) => {{
+        task::spawn(()).enter(|cx, _| match $io.as_mut().poll_next(cx) {
+            Ready(Some(Ok(v))) => panic!("value = {:?}", v),
+            Ready(Some(Err(_))) => {}
+            Ready(None) => panic!("done"),
+            Pending => panic!("pending"),
+        });
+    }};
+}
+
+macro_rules! assert_done {
+    ($io:ident) => {{
+        task::spawn(()).enter(|cx, _| {
+            let res = assert_ready!($io.as_mut().poll_next(cx));
+            match res {
+                Some(Ok(v)) => panic!("value = {:?}", v),
+                Some(Err(e)) => panic!("error = {:?}", e),
+                None => {}
+            }
+        });
+    }};
+}
+
+#[test]
+fn read_empty_io_yields_nothing() {
+    let io = Box::pin(FramedRead::new(mock!(), LengthDelimitedCodec::new()));
+    pin_mut!(io);
+
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_one_packet() {
+    let io = FramedRead::new(
+        mock! {
+            data(b"\x00\x00\x00\x09abcdefghi"),
+        },
+        LengthDelimitedCodec::new(),
+    );
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_one_packet_little_endian() {
+    let io = length_delimited::Builder::new()
+        .little_endian()
+        .new_read(mock! {
+            data(b"\x09\x00\x00\x00abcdefghi"),
+        });
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_one_packet_native_endian() {
+    let d = if cfg!(target_endian = "big") {
+        b"\x00\x00\x00\x09abcdefghi"
+    } else {
+        b"\x09\x00\x00\x00abcdefghi"
+    };
+    let io = length_delimited::Builder::new()
+        .native_endian()
+        .new_read(mock! {
+            data(d),
+        });
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_multi_frame_one_packet() {
+    let mut d: Vec<u8> = vec![];
+    d.extend_from_slice(b"\x00\x00\x00\x09abcdefghi");
+    d.extend_from_slice(b"\x00\x00\x00\x03123");
+    d.extend_from_slice(b"\x00\x00\x00\x0bhello world");
+
+    let io = FramedRead::new(
+        mock! {
+            data(&d),
+        },
+        LengthDelimitedCodec::new(),
+    );
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_next_eq!(io, b"123");
+    assert_next_eq!(io, b"hello world");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_multi_packet() {
+    let io = FramedRead::new(
+        mock! {
+            data(b"\x00\x00"),
+            data(b"\x00\x09abc"),
+            data(b"defghi"),
+        },
+        LengthDelimitedCodec::new(),
+    );
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_done!(io);
+}
+
+#[test]
+fn read_multi_frame_multi_packet() {
+    let io = FramedRead::new(
+        mock!
{ + data(b"\x00\x00"), + data(b"\x00\x09abc"), + data(b"defghi"), + data(b"\x00\x00\x00\x0312"), + data(b"3\x00\x00\x00\x0bhello world"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_done!(io); +} + +#[test] +fn read_single_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_multi_frame_multi_packet_wait() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + Pending, + data(b"\x00\x09abc"), + Pending, + data(b"defghi"), + Pending, + data(b"\x00\x00\x00\x0312"), + Pending, + data(b"3\x00\x00\x00\x0bhello world"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"abcdefghi"); + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_eq!(io, b"123"); + assert_next_eq!(io, b"hello world"); + assert_next_pending!(io); + assert_done!(io); +} + +#[test] +fn read_incomplete_head() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00"), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_incomplete_head_multi() { + let io = FramedRead::new( + mock! { + Pending, + data(b"\x00"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_incomplete_payload() { + let io = FramedRead::new( + mock! { + data(b"\x00\x00\x00\x09ab"), + Pending, + data(b"cd"), + Pending, + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + assert_next_pending!(io); + assert_next_pending!(io); + assert_next_err!(io); +} + +#[test] +fn read_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_at_rest() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcdefghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + io.decoder_mut().set_max_frame_length(5); + assert_next_err!(io); +} + +#[test] +fn read_update_max_frame_len_in_flight() { + let io = length_delimited::Builder::new().new_read(mock! { + data(b"\x00\x00\x00\x09abcd"), + Pending, + data(b"efghi"), + data(b"\x00\x00\x00\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_pending!(io); + io.decoder_mut().set_max_frame_length(5); + assert_next_eq!(io, b"abcdefghi"); + assert_next_err!(io); +} + +#[test] +fn read_one_byte_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_read(mock! { + data(b"\x09abcdefghi"), + }); + pin_mut!(io); + + assert_next_eq!(io, b"abcdefghi"); + assert_done!(io); +} + +#[test] +fn read_header_offset() { + let io = length_delimited::Builder::new() + .length_field_length(2) + .length_field_offset(4) + .new_read(mock! 
{
+            data(b"zzzz\x00\x09abcdefghi"),
+        });
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_multi_frame_one_packet_skip_none_adjusted() {
+    let mut d: Vec<u8> = vec![];
+    d.extend_from_slice(b"xx\x00\x09abcdefghi");
+    d.extend_from_slice(b"yy\x00\x03123");
+    d.extend_from_slice(b"zz\x00\x0bhello world");
+
+    let io = length_delimited::Builder::new()
+        .length_field_length(2)
+        .length_field_offset(2)
+        .num_skip(0)
+        .length_adjustment(4)
+        .new_read(mock! {
+            data(&d),
+        });
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"xx\x00\x09abcdefghi");
+    assert_next_eq!(io, b"yy\x00\x03123");
+    assert_next_eq!(io, b"zz\x00\x0bhello world");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_frame_length_adjusted() {
+    let mut d: Vec<u8> = vec![];
+    d.extend_from_slice(b"\x00\x00\x0b\x0cHello world");
+
+    let io = length_delimited::Builder::new()
+        .length_field_offset(0)
+        .length_field_length(3)
+        .length_adjustment(0)
+        .num_skip(4)
+        .new_read(mock! {
+            data(&d),
+        });
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"Hello world");
+    assert_done!(io);
+}
+
+#[test]
+fn read_single_multi_frame_one_packet_length_includes_head() {
+    let mut d: Vec<u8> = vec![];
+    d.extend_from_slice(b"\x00\x0babcdefghi");
+    d.extend_from_slice(b"\x00\x05123");
+    d.extend_from_slice(b"\x00\x0dhello world");
+
+    let io = length_delimited::Builder::new()
+        .length_field_length(2)
+        .length_adjustment(-2)
+        .new_read(mock! {
+            data(&d),
+        });
+    pin_mut!(io);
+
+    assert_next_eq!(io, b"abcdefghi");
+    assert_next_eq!(io, b"123");
+    assert_next_eq!(io, b"hello world");
+    assert_done!(io);
+}
+
+#[test]
+fn write_single_frame_length_adjusted() {
+    let io = length_delimited::Builder::new()
+        .length_adjustment(-2)
+        .new_write(mock! {
+            data(b"\x00\x00\x00\x0b"),
+            data(b"abcdefghi"),
+            flush(),
+        });
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.as_mut().poll_ready(cx));
+        assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+        assert_ready_ok!(io.as_mut().poll_flush(cx));
+        assert!(io.get_ref().calls.is_empty());
+    });
+}
+
+#[test]
+fn write_nothing_yields_nothing() {
+    let io = FramedWrite::new(mock!(), LengthDelimitedCodec::new());
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.poll_flush(cx));
+    });
+}
+
+#[test]
+fn write_single_frame_one_packet() {
+    let io = FramedWrite::new(
+        mock! {
+            data(b"\x00\x00\x00\x09"),
+            data(b"abcdefghi"),
+            flush(),
+        },
+        LengthDelimitedCodec::new(),
+    );
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.as_mut().poll_ready(cx));
+        assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi")));
+        assert_ready_ok!(io.as_mut().poll_flush(cx));
+        assert!(io.get_ref().calls.is_empty());
+    });
+}
+
+#[test]
+fn write_single_multi_frame_one_packet() {
+    let io = FramedWrite::new(
+        mock!
{ + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + data(b"\x00\x00\x00\x03"), + data(b"123"), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_multi_frame_multi_packet() { + let io = FramedWrite::new( + mock! { + data(b"\x00\x00\x00\x09"), + data(b"abcdefghi"), + flush(), + data(b"\x00\x00\x00\x03"), + data(b"123"), + flush(), + data(b"\x00\x00\x00\x0b"), + data(b"hello world"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("123"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("hello world"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_would_block() { + let io = FramedWrite::new( + mock! { + Pending, + data(b"\x00\x00"), + Pending, + data(b"\x00\x09"), + data(b"abcdefghi"), + flush(), + }, + LengthDelimitedCodec::new(), + ); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_pending!(io.as_mut().poll_flush(cx)); + assert_pending!(io.as_mut().poll_flush(cx)); + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_little_endian() { + let io = length_delimited::Builder::new() + .little_endian() + .new_write(mock! { + data(b"\x09\x00\x00\x00"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_single_frame_with_short_length_field() { + let io = length_delimited::Builder::new() + .length_field_length(1) + .new_write(mock! { + data(b"\x09"), + data(b"abcdefghi"), + flush(), + }); + pin_mut!(io); + + task::spawn(()).enter(|cx, _| { + assert_ready_ok!(io.as_mut().poll_ready(cx)); + assert_ok!(io.as_mut().start_send(Bytes::from("abcdefghi"))); + + assert_ready_ok!(io.as_mut().poll_flush(cx)); + + assert!(io.get_ref().calls.is_empty()); + }); +} + +#[test] +fn write_max_frame_len() { + let io = length_delimited::Builder::new() + .max_frame_length(5) + .new_write(mock! 
{});
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.as_mut().poll_ready(cx));
+        assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+        assert!(io.get_ref().calls.is_empty());
+    });
+}
+
+#[test]
+fn write_update_max_frame_len_at_rest() {
+    let io = length_delimited::Builder::new().new_write(mock! {
+        data(b"\x00\x00\x00\x06"),
+        data(b"abcdef"),
+        flush(),
+    });
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.as_mut().poll_ready(cx));
+        assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+        assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+        io.encoder_mut().set_max_frame_length(5);
+
+        assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+        assert!(io.get_ref().calls.is_empty());
+    });
+}
+
+#[test]
+fn write_update_max_frame_len_in_flight() {
+    let io = length_delimited::Builder::new().new_write(mock! {
+        data(b"\x00\x00\x00\x06"),
+        data(b"ab"),
+        Pending,
+        data(b"cdef"),
+        flush(),
+    });
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.as_mut().poll_ready(cx));
+        assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+        assert_pending!(io.as_mut().poll_flush(cx));
+
+        io.encoder_mut().set_max_frame_length(5);
+
+        assert_ready_ok!(io.as_mut().poll_flush(cx));
+
+        assert_err!(io.as_mut().start_send(Bytes::from("abcdef")));
+        assert!(io.get_ref().calls.is_empty());
+    });
+}
+
+#[test]
+fn write_zero() {
+    let io = length_delimited::Builder::new().new_write(mock! {});
+    pin_mut!(io);
+
+    task::spawn(()).enter(|cx, _| {
+        assert_ready_ok!(io.as_mut().poll_ready(cx));
+        assert_ok!(io.as_mut().start_send(Bytes::from("abcdef")));
+
+        assert_ready_err!(io.as_mut().poll_flush(cx));
+
+        assert!(io.get_ref().calls.is_empty());
+    });
+}
+
+#[test]
+fn encode_overflow() {
+    // Test reproducing tokio-rs/tokio#681.
+    let mut codec = length_delimited::Builder::new().new_codec();
+    let mut buf = BytesMut::with_capacity(1024);
+
+    // Put some data into the buffer without resizing it to hold more.
+    let some_as = std::iter::repeat(b'a').take(1024).collect::<Vec<u8>>();
+    buf.put_slice(&some_as[..]);
+
+    // Trying to encode the length header should resize the buffer if it won't fit.
+    codec.encode(Bytes::from("hello"), &mut buf).unwrap();
+}
+
+// ===== Test utils =====
+
+struct Mock {
+    calls: VecDeque<Poll<io::Result<Op>>>,
+}
+
+enum Op {
+    Data(Vec<u8>),
+    Flush,
+}
+
+use self::Op::*;
+
+impl AsyncRead for Mock {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        dst: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match self.calls.pop_front() {
+            Some(Ready(Ok(Op::Data(data)))) => {
+                debug_assert!(dst.remaining() >= data.len());
+                dst.put_slice(&data);
+                Ready(Ok(()))
+            }
+            Some(Ready(Ok(_))) => panic!(),
+            Some(Ready(Err(e))) => Ready(Err(e)),
+            Some(Pending) => Pending,
+            None => Ready(Ok(())),
+        }
+    }
+}
+
+impl AsyncWrite for Mock {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        src: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match self.calls.pop_front() {
+            Some(Ready(Ok(Op::Data(data)))) => {
+                let len = data.len();
+                assert!(src.len() >= len, "expect={:?}; actual={:?}", data, src);
+                assert_eq!(&data[..], &src[..len]);
+                Ready(Ok(len))
+            }
+            Some(Ready(Ok(_))) => panic!(),
+            Some(Ready(Err(e))) => Ready(Err(e)),
+            Some(Pending) => Pending,
+            None => Ready(Ok(0)),
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match self.calls.pop_front() {
+            Some(Ready(Ok(Op::Flush))) => Ready(Ok(())),
+            Some(Ready(Ok(_))) => panic!(),
+            Some(Ready(Err(e))) => Ready(Err(e)),
+            Some(Pending) => Pending,
+            None => Ready(Ok(())),
+        }
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Ready(Ok(()))
+    }
+}
+
+impl<'a> From<&'a [u8]> for Op {
+    fn from(src: &'a [u8]) -> Op {
+        Op::Data(src.into())
+    }
+}
+
+impl From<Vec<u8>> for Op {
+    fn from(src: Vec<u8>) -> Op {
+        Op::Data(src)
+    }
+}
+
+fn data(bytes: &[u8]) -> Poll<io::Result<Op>> {
+    Ready(Ok(bytes.into()))
+}
+
+fn flush() -> Poll<io::Result<Op>> {
+    Ready(Ok(Flush))
+}
diff --git a/tests/mpsc.rs b/tests/mpsc.rs
new file mode 100644
index 0000000..a3c164d
--- /dev/null
+++ b/tests/mpsc.rs
@@ -0,0 +1,239 @@
+use futures::future::poll_fn;
+use tokio::sync::mpsc::channel;
+use tokio_test::task::spawn;
+use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
+use tokio_util::sync::PollSender;
+
+#[tokio::test]
+async fn simple() {
+    let (send, mut recv) = channel(3);
+    let mut send = PollSender::new(send);
+
+    for i in 1..=3i32 {
+        let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+        assert_ready_ok!(reserve.poll());
+        send.send_item(i).unwrap();
+    }
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_pending!(reserve.poll());
+
+    assert_eq!(recv.recv().await.unwrap(), 1);
+    assert!(reserve.is_woken());
+    assert_ready_ok!(reserve.poll());
+
+    drop(recv);
+
+    send.send_item(42).unwrap();
+}
+
+#[tokio::test]
+async fn repeated_poll_reserve() {
+    let (send, mut recv) = channel::<i32>(1);
+    let mut send = PollSender::new(send);
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+    assert_ready_ok!(reserve.poll());
+    send.send_item(1).unwrap();
+
+    assert_eq!(recv.recv().await.unwrap(), 1);
+}
+
+#[tokio::test]
+async fn abort_send() {
+    let (send, mut recv) = channel(3);
+    let mut send = PollSender::new(send);
+    let send2 = send.get_ref().cloned().unwrap();
+
+    for i in 1..=3i32 {
+        let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+        assert_ready_ok!(reserve.poll());
+        send.send_item(i).unwrap();
+    }
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_pending!(reserve.poll());
+    assert_eq!(recv.recv().await.unwrap(), 1);
+    assert!(reserve.is_woken());
+    assert_ready_ok!(reserve.poll());
+
+    let mut send2_send = spawn(send2.send(5));
+    assert_pending!(send2_send.poll());
+    assert!(send.abort_send());
+    assert!(send2_send.is_woken());
+    assert_ready_ok!(send2_send.poll());
+
+    assert_eq!(recv.recv().await.unwrap(), 2);
+    assert_eq!(recv.recv().await.unwrap(), 3);
+    assert_eq!(recv.recv().await.unwrap(), 5);
+}
+
+#[tokio::test]
+async fn close_sender_last() {
+    let (send, mut recv) = channel::<i32>(3);
+    let mut send = PollSender::new(send);
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    send.close();
+
+    assert!(recv_task.is_woken());
+    assert!(assert_ready!(recv_task.poll()).is_none());
+}
+
+#[tokio::test]
+async fn close_sender_not_last() {
+    let (send, mut recv) = channel::<i32>(3);
+    let mut send = PollSender::new(send);
+    let send2 = send.get_ref().cloned().unwrap();
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    send.close();
+
+    assert!(!recv_task.is_woken());
+    assert_pending!(recv_task.poll());
+
+    drop(send2);
+
+    assert!(recv_task.is_woken());
+    assert!(assert_ready!(recv_task.poll()).is_none());
+}
+
+#[tokio::test]
+async fn close_sender_before_reserve() {
+    let (send, mut recv) = channel::<i32>(3);
+    let mut send = PollSender::new(send);
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    send.close();
+
+    assert!(recv_task.is_woken());
+    assert!(assert_ready!(recv_task.poll()).is_none());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_err!(reserve.poll());
+}
+
+#[tokio::test]
+async fn close_sender_after_pending_reserve() {
+    let (send, mut recv) = channel::<i32>(1);
+    let mut send = PollSender::new(send);
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+    send.send_item(1).unwrap();
+
+    assert!(recv_task.is_woken());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_pending!(reserve.poll());
+    drop(reserve);
+
+    send.close();
+
+    assert!(send.is_closed());
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_err!(reserve.poll());
+}
+
+#[tokio::test]
+async fn close_sender_after_successful_reserve() {
+    let (send, mut recv) = channel::<i32>(3);
+    let mut send = PollSender::new(send);
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+    drop(reserve);
+
+    send.close();
+    assert!(send.is_closed());
+    assert!(!recv_task.is_woken());
+    assert_pending!(recv_task.poll());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+}
+
+#[tokio::test]
+async fn abort_send_after_pending_reserve() {
+    let (send, mut recv) = channel::<i32>(1);
+    let mut send = PollSender::new(send);
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+    send.send_item(1).unwrap();
+
+    assert_eq!(send.get_ref().unwrap().capacity(), 0);
+    assert!(!send.abort_send());
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_pending!(reserve.poll());
+
+    assert!(send.abort_send());
+    assert_eq!(send.get_ref().unwrap().capacity(), 0);
+}
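+
+// `abort_send` releases a permit that `poll_reserve` has claimed (or is still
+// claiming) and returns `true`; once `send_item` has consumed the permit there
+// is nothing to abort, so it returns `false`. The tests above and below
+// exercise both outcomes.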
+
+#[tokio::test]
+async fn abort_send_after_successful_reserve() {
+    let (send, mut recv) = channel::<i32>(1);
+    let mut send = PollSender::new(send);
+
+    let mut recv_task = spawn(recv.recv());
+    assert_pending!(recv_task.poll());
+
+    assert_eq!(send.get_ref().unwrap().capacity(), 1);
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+    assert_eq!(send.get_ref().unwrap().capacity(), 0);
+
+    assert!(send.abort_send());
+    assert_eq!(send.get_ref().unwrap().capacity(), 1);
+}
+
+#[tokio::test]
+async fn closed_when_receiver_drops() {
+    let (send, _) = channel::<i32>(1);
+    let mut send = PollSender::new(send);
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_err!(reserve.poll());
+}
+
+#[should_panic]
+#[test]
+fn start_send_panics_when_idle() {
+    let (send, _) = channel::<i32>(3);
+    let mut send = PollSender::new(send);
+
+    send.send_item(1).unwrap();
+}
+
+#[should_panic]
+#[test]
+fn start_send_panics_when_acquiring() {
+    let (send, _) = channel::<i32>(1);
+    let mut send = PollSender::new(send);
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_ready_ok!(reserve.poll());
+    send.send_item(1).unwrap();
+
+    let mut reserve = spawn(poll_fn(|cx| send.poll_reserve(cx)));
+    assert_pending!(reserve.poll());
+    send.send_item(2).unwrap();
+}
diff --git a/tests/panic.rs b/tests/panic.rs
new file mode 100644
index 0000000..e4fcb47
--- /dev/null
+++ b/tests/panic.rs
@@ -0,0 +1,226 @@
+#![warn(rust_2018_idioms)]
+#![cfg(all(feature = "full", not(target_os = "wasi")))] // Wasi doesn't support panic recovery
+
+use parking_lot::{const_mutex, Mutex};
+use std::error::Error;
+use std::panic;
+use std::sync::Arc;
+use tokio::runtime::Runtime;
+use tokio::sync::mpsc::channel;
+use tokio::time::{Duration, Instant};
+use tokio_test::task;
+use tokio_util::io::SyncIoBridge;
+use tokio_util::sync::PollSender;
+use tokio_util::task::LocalPoolHandle;
+use tokio_util::time::DelayQueue;
+
+// Taken from tokio-util::time::wheel; if that constant changes, this one
+// should change with it.
+const MAX_DURATION_MS: u64 = (1 << (36)) - 1;
+
+fn test_panic<Func: FnOnce() + panic::UnwindSafe>(func: Func) -> Option<String> {
+    static PANIC_MUTEX: Mutex<()> = const_mutex(());
+
+    {
+        let _guard = PANIC_MUTEX.lock();
+        let panic_file: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
+
+        let prev_hook = panic::take_hook();
+        {
+            let panic_file = panic_file.clone();
+            panic::set_hook(Box::new(move |panic_info| {
+                let panic_location = panic_info.location().unwrap();
+                panic_file
+                    .lock()
+                    .clone_from(&Some(panic_location.file().to_string()));
+            }));
+        }
+
+        let result = panic::catch_unwind(func);
+        // Return to the previously set panic hook (maybe default) so that we get nice error
+        // messages in the tests.
+        panic::set_hook(prev_hook);
+
+        if result.is_err() {
+            panic_file.lock().clone()
+        } else {
+            None
+        }
+    }
+}
+
+#[test]
+fn sync_bridge_new_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let _ = SyncIoBridge::new(tokio::io::empty());
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn poll_sender_send_item_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let (send, _) = channel::<u32>(3);
+        let mut send = PollSender::new(send);
+
+        let _ = send.send_item(42);
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn local_pool_handle_new_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let _ = LocalPoolHandle::new(0);
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn local_pool_handle_spawn_pinned_by_idx_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+
+        rt.block_on(async {
+            let handle = LocalPoolHandle::new(2);
+            handle.spawn_pinned_by_idx(|| async { "test" }, 3);
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn delay_queue_insert_at_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+        rt.block_on(async {
+            let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+            //let st = std::time::Instant::from(SystemTime::UNIX_EPOCH);
+            let _k = queue.insert_at(
+                "1",
+                Instant::now() + Duration::from_millis(MAX_DURATION_MS + 1),
+            );
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn delay_queue_insert_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+        rt.block_on(async {
+            let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+            let _k = queue.insert("1", Duration::from_millis(MAX_DURATION_MS + 1));
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn delay_queue_remove_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+        rt.block_on(async {
+            let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+            let key = queue.insert_at("1", Instant::now());
+            queue.remove(&key);
+            queue.remove(&key);
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn delay_queue_reset_at_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+        rt.block_on(async {
+            let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+            let key = queue.insert_at("1", Instant::now());
+            queue.reset_at(
+                &key,
+                Instant::now() + Duration::from_millis(MAX_DURATION_MS + 1),
+            );
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn delay_queue_reset_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+        rt.block_on(async {
+            let mut queue = task::spawn(DelayQueue::with_capacity(3));
+
+            let key = queue.insert_at("1", Instant::now());
+            queue.reset(&key, Duration::from_millis(MAX_DURATION_MS + 1));
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+#[test]
+fn delay_queue_reserve_panic_caller() -> Result<(), Box<dyn Error>> {
+    let panic_location_file = test_panic(|| {
+        let rt = basic();
+        rt.block_on(async {
+            let mut queue = task::spawn(DelayQueue::<u32>::with_capacity(3));
+
+            queue.reserve((1 << 30) as usize);
+        });
+    });
+
+    // The panic location should be in this file
+    assert_eq!(&panic_location_file.unwrap(), file!());
+
+    Ok(())
+}
+
+fn basic() -> Runtime {
+    tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .unwrap()
+}
diff --git a/tests/poll_semaphore.rs b/tests/poll_semaphore.rs
new file mode 100644
index 0000000..28beca1
--- /dev/null
+++ b/tests/poll_semaphore.rs
@@ -0,0 +1,84 @@
+use std::future::Future;
+use std::sync::Arc;
+use std::task::Poll;
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio_util::sync::PollSemaphore;
+
+type SemRet = Option<OwnedSemaphorePermit>;
+
+fn semaphore_poll(
+    sem: &mut PollSemaphore,
+) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
+    let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx));
+    tokio_test::task::spawn(fut)
+}
+
+fn semaphore_poll_many(
+    sem: &mut PollSemaphore,
+    permits: u32,
+) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
+    let fut = futures::future::poll_fn(move |cx| sem.poll_acquire_many(cx, permits));
+    tokio_test::task::spawn(fut)
+}
+
+#[tokio::test]
+async fn it_works() {
+    let sem = Arc::new(Semaphore::new(1));
+    let mut poll_sem = PollSemaphore::new(sem.clone());
+
+    let permit = sem.acquire().await.unwrap();
+    let mut poll = semaphore_poll(&mut poll_sem);
+    assert!(poll.poll().is_pending());
+    drop(permit);
+
+    assert!(matches!(poll.poll(), Poll::Ready(Some(_))));
+    drop(poll);
+
+    sem.close();
+
+    assert!(semaphore_poll(&mut poll_sem).await.is_none());
+
+    // Check that it is fused.
+    assert!(semaphore_poll(&mut poll_sem).await.is_none());
+    assert!(semaphore_poll(&mut poll_sem).await.is_none());
+}
+
+#[tokio::test]
+async fn can_acquire_many_permits() {
+    let sem = Arc::new(Semaphore::new(4));
+    let mut poll_sem = PollSemaphore::new(sem.clone());
+
+    let permit1 = semaphore_poll(&mut poll_sem).poll();
+    assert!(matches!(permit1, Poll::Ready(Some(_))));
+
+    let permit2 = semaphore_poll_many(&mut poll_sem, 2).poll();
+    assert!(matches!(permit2, Poll::Ready(Some(_))));
+
+    assert_eq!(sem.available_permits(), 1);
+
+    drop(permit2);
+
+    let mut permit4 = semaphore_poll_many(&mut poll_sem, 4);
+    assert!(permit4.poll().is_pending());
+
+    drop(permit1);
+
+    let permit4 = permit4.poll();
+    assert!(matches!(permit4, Poll::Ready(Some(_))));
+    assert_eq!(sem.available_permits(), 0);
+}
+
+#[tokio::test]
+async fn can_poll_different_amounts_of_permits() {
+    let sem = Arc::new(Semaphore::new(4));
+    let mut poll_sem = PollSemaphore::new(sem.clone());
+    assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
+    assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
+
+    let permit = sem.acquire_many(4).await.unwrap();
+    assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
+    assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_pending());
+    drop(permit);
+    assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
+    assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
+}
diff --git a/tests/reusable_box.rs b/tests/reusable_box.rs
new file mode 100644
index 0000000..e9c8a25
--- /dev/null
+++ b/tests/reusable_box.rs
@@ -0,0 +1,85 @@
+use futures::future::FutureExt;
+use std::alloc::Layout;
+use std::future::Future;
+use std::marker::PhantomPinned;
+use std::pin::Pin;
+use std::rc::Rc;
+use std::task::{Context, Poll};
+use tokio_util::sync::ReusableBoxFuture;
+
+#[test]
+// Clippy false positive; it's useful to be able to test the trait impls for any lifetime
+#[allow(clippy::extra_unused_lifetimes)]
+fn traits<'a>() {
+    fn assert_traits<T: Send + Sync + Unpin>() {}
+    // Use a type that is !Unpin
+    assert_traits::<ReusableBoxFuture<'a, PhantomPinned>>();
+    // Use a type that is !Send + !Sync
+    assert_traits::<ReusableBoxFuture<'a, Rc<()>>>();
+}
+
+#[test]
+fn test_different_futures() {
+    let fut = async move { 10 };
+    // Not zero sized!
+    assert_eq!(Layout::for_value(&fut).size(), 1);
+
+    let mut b = ReusableBoxFuture::new(fut);
+
+    assert_eq!(b.get_pin().now_or_never(), Some(10));
+
+    b.try_set(async move { 20 })
+        .unwrap_or_else(|_| panic!("incorrect size"));
+
+    assert_eq!(b.get_pin().now_or_never(), Some(20));
+
+    b.try_set(async move { 30 })
+        .unwrap_or_else(|_| panic!("incorrect size"));
+
+    assert_eq!(b.get_pin().now_or_never(), Some(30));
+}
+
+#[test]
+fn test_different_sizes() {
+    let fut1 = async move { 10 };
+    let val = [0u32; 1000];
+    let fut2 = async move { val[0] };
+    let fut3 = ZeroSizedFuture {};
+
+    assert_eq!(Layout::for_value(&fut1).size(), 1);
+    assert_eq!(Layout::for_value(&fut2).size(), 4004);
+    assert_eq!(Layout::for_value(&fut3).size(), 0);
+
+    let mut b = ReusableBoxFuture::new(fut1);
+    assert_eq!(b.get_pin().now_or_never(), Some(10));
+    b.set(fut2);
+    assert_eq!(b.get_pin().now_or_never(), Some(0));
+    b.set(fut3);
+    assert_eq!(b.get_pin().now_or_never(), Some(5));
+}
+
+struct ZeroSizedFuture {}
+impl Future for ZeroSizedFuture {
+    type Output = u32;
+    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> {
+        Poll::Ready(5)
+    }
+}
+
+#[test]
+fn test_zero_sized() {
+    let fut = ZeroSizedFuture {};
+    // Zero sized!
+ assert_eq!(Layout::for_value(&fut).size(), 0); + + let mut b = ReusableBoxFuture::new(fut); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); + + b.try_set(ZeroSizedFuture {}) + .unwrap_or_else(|_| panic!("incorrect size")); + + assert_eq!(b.get_pin().now_or_never(), Some(5)); + assert_eq!(b.get_pin().now_or_never(), Some(5)); +} diff --git a/tests/spawn_pinned.rs b/tests/spawn_pinned.rs new file mode 100644 index 0000000..9ea8cd2 --- /dev/null +++ b/tests/spawn_pinned.rs @@ -0,0 +1,237 @@ +#![warn(rust_2018_idioms)] +#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads + +use std::rc::Rc; +use std::sync::Arc; +use tokio::sync::Barrier; +use tokio_util::task; + +/// Simple test of running a !Send future via spawn_pinned +#[tokio::test] +async fn can_spawn_not_send_future() { + let pool = task::LocalPoolHandle::new(1); + + let output = pool + .spawn_pinned(|| { + // Rc is !Send + !Sync + let local_data = Rc::new("test"); + + // This future holds an Rc, so it is !Send + async move { local_data.to_string() } + }) + .await + .unwrap(); + + assert_eq!(output, "test"); +} + +/// Dropping the join handle still lets the task execute +#[test] +fn can_drop_future_and_still_get_output() { + let pool = task::LocalPoolHandle::new(1); + let (sender, receiver) = std::sync::mpsc::channel(); + + let _ = pool.spawn_pinned(move || { + // Rc is !Send + !Sync + let local_data = Rc::new("test"); + + // This future holds an Rc, so it is !Send + async move { + let _ = sender.send(local_data.to_string()); + } + }); + + assert_eq!(receiver.recv(), Ok("test".to_string())); +} + +#[test] +#[should_panic(expected = "assertion failed: pool_size > 0")] +fn cannot_create_zero_sized_pool() { + let _pool = task::LocalPoolHandle::new(0); +} + +/// We should be able to spawn multiple futures onto the pool at the same time. +#[tokio::test] +async fn can_spawn_multiple_futures() { + let pool = task::LocalPoolHandle::new(2); + + let join_handle1 = pool.spawn_pinned(|| { + let local_data = Rc::new("test1"); + async move { local_data.to_string() } + }); + let join_handle2 = pool.spawn_pinned(|| { + let local_data = Rc::new("test2"); + async move { local_data.to_string() } + }); + + assert_eq!(join_handle1.await.unwrap(), "test1"); + assert_eq!(join_handle2.await.unwrap(), "test2"); +} + +/// A panic in the spawned task causes the join handle to return an error. +/// But, you can continue to spawn tasks. +#[tokio::test] +async fn task_panic_propagates() { + let pool = task::LocalPoolHandle::new(1); + + let join_handle = pool.spawn_pinned(|| async { + panic!("Test panic"); + }); + + let result = join_handle.await; + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.is_panic()); + let panic_str = error.into_panic().downcast::<&'static str>().unwrap(); + assert_eq!(*panic_str, "Test panic"); + + // Trying again with a "safe" task still works + let join_handle = pool.spawn_pinned(|| async { "test" }); + let result = join_handle.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "test"); +} + +/// A panic during task creation causes the join handle to return an error. +/// But, you can continue to spawn tasks. 
+#[tokio::test]
+async fn callback_panic_does_not_kill_worker() {
+ let pool = task::LocalPoolHandle::new(1);
+
+ let join_handle = pool.spawn_pinned(|| {
+ panic!("Test panic");
+ #[allow(unreachable_code)]
+ async {}
+ });
+
+ let result = join_handle.await;
+ assert!(result.is_err());
+ let error = result.unwrap_err();
+ assert!(error.is_panic());
+ let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
+ assert_eq!(*panic_str, "Test panic");
+
+ // Trying again with a "safe" callback works
+ let join_handle = pool.spawn_pinned(|| async { "test" });
+ let result = join_handle.await;
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "test");
+}
+
+/// Canceling the task via the returned join handle cancels the spawned task
+/// (which has a different, internal join handle).
+#[tokio::test]
+async fn task_cancellation_propagates() {
+ let pool = task::LocalPoolHandle::new(1);
+ let notify_dropped = Arc::new(());
+ let weak_notify_dropped = Arc::downgrade(&notify_dropped);
+
+ let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
+ let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
+ let join_handle = pool.spawn_pinned(|| async move {
+ let _drop_sender = drop_sender;
+ // Move the Arc into the task
+ let _notify_dropped = notify_dropped;
+ let _ = start_sender.send(());
+
+ // Keep the task running until it gets aborted
+ futures::future::pending::<()>().await;
+ });
+
+ // Wait for the task to start
+ let _ = start_receiver.await;
+
+ join_handle.abort();
+
+ // Wait for the inner task to abort, dropping the sender.
+ // The top level join handle aborts quicker than the inner task (the abort
+ // needs to propagate and get processed on the worker thread), so we can't
+ // just await the top level join handle.
+ let _ = drop_receiver.await;
+
+ // Check that the Arc has been dropped. This verifies that the inner task
+ // was canceled as well.
+ assert!(weak_notify_dropped.upgrade().is_none());
+}
+
+/// Tasks should be given to the least burdened worker. When spawning two tasks
+/// on a pool with two empty workers the tasks should be spawned on separate
+/// workers.
+#[tokio::test]
+async fn tasks_are_balanced() {
+ let pool = task::LocalPoolHandle::new(2);
+
+ // Spawn a task so one thread has a task count of 1
+ let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel();
+ let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel();
+ let join_handle1 = pool.spawn_pinned(|| async move {
+ let _ = start_sender1.send(());
+ let _ = end_receiver1.await;
+ std::thread::current().id()
+ });
+
+ // Wait for the first task to start up
+ let _ = start_receiver1.await;
+
+ // This task should be spawned on the other thread
+ let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel();
+ let join_handle2 = pool.spawn_pinned(|| async move {
+ let _ = start_sender2.send(());
+ std::thread::current().id()
+ });
+
+ // Wait for the second task to start up
+ let _ = start_receiver2.await;
+
+ // Allow the first task to end
+ let _ = end_sender1.send(());
+
+ let thread_id1 = join_handle1.await.unwrap();
+ let thread_id2 = join_handle2.await.unwrap();
+
+ // Since the first task was active when the second task spawned, they should
+ // be on separate workers/threads.
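+ // (Worker selection should be least-loaded-first: worker 0 is still busy
+ // with the first task, so the second task must land on the idle worker.)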
+ assert_ne!(thread_id1, thread_id2); +} + +#[tokio::test] +async fn spawn_by_idx() { + let pool = task::LocalPoolHandle::new(3); + let barrier = Arc::new(Barrier::new(4)); + let barrier1 = barrier.clone(); + let barrier2 = barrier.clone(); + let barrier3 = barrier.clone(); + + let handle1 = pool.spawn_pinned_by_idx( + || async move { + barrier1.wait().await; + std::thread::current().id() + }, + 0, + ); + let _ = pool.spawn_pinned_by_idx( + || async move { + barrier2.wait().await; + std::thread::current().id() + }, + 0, + ); + let handle2 = pool.spawn_pinned_by_idx( + || async move { + barrier3.wait().await; + std::thread::current().id() + }, + 1, + ); + + let loads = pool.get_task_loads_for_each_worker(); + barrier.wait().await; + assert_eq!(loads[0], 2); + assert_eq!(loads[1], 1); + assert_eq!(loads[2], 0); + + let thread_id1 = handle1.await.unwrap(); + let thread_id2 = handle2.await.unwrap(); + + assert_ne!(thread_id1, thread_id2); +} diff --git a/tests/sync_cancellation_token.rs b/tests/sync_cancellation_token.rs new file mode 100644 index 0000000..279d74f --- /dev/null +++ b/tests/sync_cancellation_token.rs @@ -0,0 +1,447 @@ +#![warn(rust_2018_idioms)] + +use tokio::pin; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; + +use core::future::Future; +use core::task::{Context, Poll}; +use futures_test::task::new_count_waker; + +#[test] +fn cancel_token() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert!(!token.is_cancelled()); + + let wait_fut = token.cancelled(); + pin!(wait_fut); + + assert_eq!( + Poll::Pending, + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.cancelled(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert!(token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_token_owned() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + assert!(!token.is_cancelled()); + + let wait_fut = token.clone().cancelled_owned(); + pin!(wait_fut); + + assert_eq!( + Poll::Pending, + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + let wait_fut_2 = token.clone().cancelled_owned(); + pin!(wait_fut_2); + + token.cancel(); + assert_eq!(wake_counter, 1); + assert!(token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + wait_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + wait_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_token_owned_drop_test() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let future = token.cancelled_owned(); + pin!(future); + + assert_eq!( + Poll::Pending, + future.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + // let future be dropped while pinned and under pending state to + // find potential memory related bugs. 
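+ // (A correct implementation has to unregister the dropped future from the
+ // token's internal state; failing to do so would leave a dangling
+ // registration behind.)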
+} + +#[test] +fn cancel_child_token_through_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token = token.child_token(); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_grandchild_token_through_parent_if_child_was_dropped() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let intermediate_token = token.child_token(); + let child_token = intermediate_token.child_token(); + drop(intermediate_token); + assert!(!child_token.is_cancelled()); + + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + token.cancel(); + assert_eq!(wake_counter, 2); + assert!(token.is_cancelled()); + assert!(child_token.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn cancel_child_token_without_parent() { + let (waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + + let child_token_1 = token.child_token(); + + let child_fut = child_token_1.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Pending, + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + child_token_1.cancel(); + assert_eq!(wake_counter, 1); + assert!(!token.is_cancelled()); + assert!(child_token_1.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + let child_token_2 = token.child_token(); + let child_fut_2 = child_token_2.cancelled(); + pin!(child_fut_2); + + assert_eq!( + Poll::Pending, + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + + token.cancel(); + assert_eq!(wake_counter, 3); + assert!(token.is_cancelled()); + assert!(child_token_2.is_cancelled()); + + assert_eq!( + Poll::Ready(()), + child_fut_2.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); +} + +#[test] +fn create_child_token_after_parent_was_cancelled() { + for drop_child_first in [true, false].iter().cloned() { + let 
(waker, wake_counter) = new_count_waker(); + let token = CancellationToken::new(); + token.cancel(); + + let child_token = token.child_token(); + assert!(child_token.is_cancelled()); + + { + let child_fut = child_token.cancelled(); + pin!(child_fut); + let parent_fut = token.cancelled(); + pin!(parent_fut); + + assert_eq!( + Poll::Ready(()), + child_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Ready(()), + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + } + + if drop_child_first { + drop(child_token); + drop(token); + } else { + drop(token); + drop(child_token); + } + } +} + +#[test] +fn drop_multiple_child_tokens() { + for drop_first_child_first in &[true, false] { + let token = CancellationToken::new(); + let mut child_tokens = [None, None, None]; + for child in &mut child_tokens { + *child = Some(token.child_token()); + } + + assert!(!token.is_cancelled()); + assert!(!child_tokens[0].as_ref().unwrap().is_cancelled()); + + for i in 0..child_tokens.len() { + if *drop_first_child_first { + child_tokens[i] = None; + } else { + child_tokens[child_tokens.len() - 1 - i] = None; + } + assert!(!token.is_cancelled()); + } + + drop(token); + } +} + +#[test] +fn cancel_only_all_descendants() { + // ARRANGE + let (waker, wake_counter) = new_count_waker(); + + let parent_token = CancellationToken::new(); + let token = parent_token.child_token(); + let sibling_token = parent_token.child_token(); + let child1_token = token.child_token(); + let child2_token = token.child_token(); + let grandchild_token = child1_token.child_token(); + let grandchild2_token = child1_token.child_token(); + let great_grandchild_token = grandchild_token.child_token(); + + assert!(!parent_token.is_cancelled()); + assert!(!token.is_cancelled()); + assert!(!sibling_token.is_cancelled()); + assert!(!child1_token.is_cancelled()); + assert!(!child2_token.is_cancelled()); + assert!(!grandchild_token.is_cancelled()); + assert!(!grandchild2_token.is_cancelled()); + assert!(!great_grandchild_token.is_cancelled()); + + let parent_fut = parent_token.cancelled(); + let fut = token.cancelled(); + let sibling_fut = sibling_token.cancelled(); + let child1_fut = child1_token.cancelled(); + let child2_fut = child2_token.cancelled(); + let grandchild_fut = grandchild_token.cancelled(); + let grandchild2_fut = grandchild2_token.cancelled(); + let great_grandchild_fut = great_grandchild_token.cancelled(); + + pin!(parent_fut); + pin!(fut); + pin!(sibling_fut); + pin!(child1_fut); + pin!(child2_fut); + pin!(grandchild_fut); + pin!(grandchild2_fut); + pin!(great_grandchild_fut); + + assert_eq!( + Poll::Pending, + parent_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + sibling_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child1_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + child2_fut.as_mut().poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + grandchild2_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!( + Poll::Pending, + great_grandchild_fut + .as_mut() + .poll(&mut Context::from_waker(&waker)) + ); + assert_eq!(wake_counter, 0); + + // ACT + token.cancel(); + + // ASSERT + 
assert_eq!(wake_counter, 6);
+ // (Exactly the six futures in the cancelled subtree were woken; the parent
+ // and sibling futures stay pending.)
+ assert!(!parent_token.is_cancelled());
+ assert!(token.is_cancelled());
+ assert!(!sibling_token.is_cancelled());
+ assert!(child1_token.is_cancelled());
+ assert!(child2_token.is_cancelled());
+ assert!(grandchild_token.is_cancelled());
+ assert!(grandchild2_token.is_cancelled());
+ assert!(great_grandchild_token.is_cancelled());
+
+ assert_eq!(
+ Poll::Ready(()),
+ fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ child1_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ child2_fut.as_mut().poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ grandchild_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ grandchild2_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(
+ Poll::Ready(()),
+ great_grandchild_fut
+ .as_mut()
+ .poll(&mut Context::from_waker(&waker))
+ );
+ assert_eq!(wake_counter, 6);
+}
+
+#[test]
+fn drop_parent_before_child_tokens() {
+ let token = CancellationToken::new();
+ let child1 = token.child_token();
+ let child2 = token.child_token();
+
+ drop(token);
+ assert!(!child1.is_cancelled());
+
+ drop(child1);
+ drop(child2);
+}
+
+#[test]
+fn derives_send_sync() {
+ fn assert_send<T: Send>() {}
+ fn assert_sync<T: Sync>() {}
+
+ assert_send::<CancellationToken>();
+ assert_sync::<CancellationToken>();
+
+ assert_send::<WaitForCancellationFuture<'static>>();
+ assert_sync::<WaitForCancellationFuture<'static>>();
+}
diff --git a/tests/task_join_map.rs b/tests/task_join_map.rs
new file mode 100644
index 0000000..cef08b2
--- /dev/null
+++ b/tests/task_join_map.rs
@@ -0,0 +1,275 @@
+#![warn(rust_2018_idioms)]
+#![cfg(all(feature = "rt", tokio_unstable))]
+
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio_util::task::JoinMap;
+
+use futures::future::FutureExt;
+
+fn rt() -> tokio::runtime::Runtime {
+ tokio::runtime::Builder::new_current_thread()
+ .build()
+ .unwrap()
+}
+
+#[tokio::test(start_paused = true)]
+async fn test_with_sleep() {
+ let mut map = JoinMap::new();
+
+ for i in 0..10 {
+ map.spawn(i, async move { i });
+ assert_eq!(map.len(), 1 + i);
+ }
+ map.detach_all();
+ assert_eq!(map.len(), 0);
+
+ assert!(matches!(map.join_next().await, None));
+
+ for i in 0..10 {
+ map.spawn(i, async move {
+ tokio::time::sleep(Duration::from_secs(i as u64)).await;
+ i
+ });
+ assert_eq!(map.len(), 1 + i);
+ }
+
+ let mut seen = [false; 10];
+ while let Some((k, res)) = map.join_next().await {
+ seen[k] = true;
+ assert_eq!(res.expect("task should have completed successfully"), k);
+ }
+
+ for was_seen in &seen {
+ assert!(was_seen);
+ }
+ assert!(matches!(map.join_next().await, None));
+
+ // Do it again.
+ for i in 0..10 {
+ map.spawn(i, async move {
+ tokio::time::sleep(Duration::from_secs(i as u64)).await;
+ i
+ });
+ }
+
+ let mut seen = [false; 10];
+ while let Some((k, res)) = map.join_next().await {
+ seen[k] = true;
+ assert_eq!(res.expect("task should have completed successfully"), k);
+ }
+
+ for was_seen in &seen {
+ assert!(was_seen);
+ }
+ assert!(matches!(map.join_next().await, None));
+}
+
+#[tokio::test]
+async fn test_abort_on_drop() {
+ let mut map = JoinMap::new();
+
+ let mut recvs = Vec::new();
+
+ for i in 0..16 {
+ let (send, recv) = oneshot::channel::<()>();
+ recvs.push(recv);
+
+ map.spawn(i, async {
+ // This task will never complete on its own.
+ futures::future::pending::<()>().await;
+ drop(send);
+ });
+ }
+
+ drop(map);
+
+ for recv in recvs {
+ // The task is aborted soon and we will receive an error.
+ assert!(recv.await.is_err()); + } +} + +#[tokio::test] +async fn alternating() { + let mut map = JoinMap::new(); + + assert_eq!(map.len(), 0); + map.spawn(1, async {}); + assert_eq!(map.len(), 1); + map.spawn(2, async {}); + assert_eq!(map.len(), 2); + + for i in 0..16 { + let (_, res) = map.join_next().await.unwrap(); + assert!(res.is_ok()); + assert_eq!(map.len(), 1); + map.spawn(i, async {}); + assert_eq!(map.len(), 2); + } +} + +#[tokio::test(start_paused = true)] +async fn abort_by_key() { + let mut map = JoinMap::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + }); + } + + for i in 0..16 { + if i % 2 != 0 { + // abort odd-numbered tasks. + map.abort(&i); + } + } + + while let Some((key, res)) = map.join_next().await { + match res { + Ok(()) => { + num_completed += 1; + assert_eq!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + Err(e) => { + num_canceled += 1; + assert!(e.is_cancelled()); + assert_ne!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[tokio::test(start_paused = true)] +async fn abort_by_predicate() { + let mut map = JoinMap::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + }); + } + + // abort odd-numbered tasks. + map.abort_matching(|key| key % 2 != 0); + + while let Some((key, res)) = map.join_next().await { + match res { + Ok(()) => { + num_completed += 1; + assert_eq!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + Err(e) => { + num_canceled += 1; + assert!(e.is_cancelled()); + assert_ne!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[test] +fn runtime_gone() { + let mut map = JoinMap::new(); + { + let rt = rt(); + map.spawn_on("key", async { 1 }, rt.handle()); + drop(rt); + } + + let (key, res) = rt().block_on(map.join_next()).unwrap(); + assert_eq!(key, "key"); + assert!(res.unwrap_err().is_cancelled()); +} + +// This ensures that `join_next` works correctly when the coop budget is +// exhausted. +#[tokio::test(flavor = "current_thread")] +async fn join_map_coop() { + // Large enough to trigger coop. + const TASK_NUM: u32 = 1000; + + static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); + + let mut map = JoinMap::new(); + + for i in 0..TASK_NUM { + map.spawn(i, async move { + SEM.add_permits(1); + i + }); + } + + // Wait for all tasks to complete. + // + // Since this is a `current_thread` runtime, there's no race condition + // between the last permit being added and the task completing. 
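+ // (`acquire_many(TASK_NUM)` resolves only after every spawned task has run
+ // and added its permit, so all results are ready before polling begins.)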
+ let _ = SEM.acquire_many(TASK_NUM).await.unwrap();
+
+ let mut count = 0;
+ let mut coop_count = 0;
+ loop {
+ match map.join_next().now_or_never() {
+ Some(Some((key, Ok(i)))) => assert_eq!(key, i),
+ Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err),
+ None => {
+ coop_count += 1;
+ tokio::task::yield_now().await;
+ continue;
+ }
+ Some(None) => break,
+ }
+
+ count += 1;
+ }
+ assert!(coop_count >= 1);
+ assert_eq!(count, TASK_NUM);
+}
+
+#[tokio::test(start_paused = true)]
+async fn abort_all() {
+ let mut map: JoinMap<usize, ()> = JoinMap::new();
+
+ for i in 0..5 {
+ map.spawn(i, futures::future::pending());
+ }
+ for i in 5..10 {
+ map.spawn(i, async {
+ tokio::time::sleep(Duration::from_secs(1)).await;
+ });
+ }
+
+ // The join map will now have 5 pending tasks and 5 ready tasks.
+ tokio::time::sleep(Duration::from_secs(2)).await;
+
+ map.abort_all();
+ assert_eq!(map.len(), 10);
+
+ let mut count = 0;
+ let mut seen = [false; 10];
+ while let Some((k, res)) = map.join_next().await {
+ seen[k] = true;
+ if let Err(err) = res {
+ assert!(err.is_cancelled());
+ }
+ count += 1;
+ }
+ assert_eq!(count, 10);
+ assert_eq!(map.len(), 0);
+ for was_seen in &seen {
+ assert!(was_seen);
+ }
+}
diff --git a/tests/time_delay_queue.rs b/tests/time_delay_queue.rs
new file mode 100644
index 0000000..9ceae34
--- /dev/null
+++ b/tests/time_delay_queue.rs
@@ -0,0 +1,820 @@
+#![allow(clippy::disallowed_names)]
+#![warn(rust_2018_idioms)]
+#![cfg(feature = "full")]
+
+use tokio::time::{self, sleep, sleep_until, Duration, Instant};
+use tokio_test::{assert_pending, assert_ready, task};
+use tokio_util::time::DelayQueue;
+
+macro_rules! poll {
+ ($queue:ident) => {
+ $queue.enter(|cx, mut queue| queue.poll_expired(cx))
+ };
+}
+
+macro_rules! assert_ready_some {
+ ($e:expr) => {{
+ match assert_ready!($e) {
+ Some(v) => v,
+ None => panic!("None"),
+ }
+ }};
+}
+
+#[tokio::test]
+async fn single_immediate_delay() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let _key = queue.insert_at("foo", Instant::now());
+
+ // Advance time by 1ms to handle the rounding
+ sleep(ms(1)).await;
+
+ assert_ready_some!(poll!(queue));
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none())
+}
+
+#[tokio::test]
+async fn multi_immediate_delays() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let _k = queue.insert_at("1", Instant::now());
+ let _k = queue.insert_at("2", Instant::now());
+ let _k = queue.insert_at("3", Instant::now());
+
+ sleep(ms(1)).await;
+
+ let mut res = vec![];
+
+ while res.len() < 3 {
+ let entry = assert_ready_some!(poll!(queue));
+ res.push(entry.into_inner());
+ }
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+
+ res.sort_unstable();
+
+ assert_eq!("1", res[0]);
+ assert_eq!("2", res[1]);
+ assert_eq!("3", res[2]);
+}
+
+#[tokio::test]
+async fn single_short_delay() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let _key = queue.insert_at("foo", Instant::now() + ms(5));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(1)).await;
+
+ assert!(!queue.is_woken());
+
+ sleep(ms(5)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(*entry.get_ref(), "foo");
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn multi_delay_at_start() {
+ time::pause();
+
+ let long = 262_144 + 9 * 4096;
+ let delays = &[1000, 2, 234, long, 60, 10];
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ // Set up the delays
+ for &i in delays {
+ let _key = queue.insert_at(i, Instant::now() + ms(i));
+ }
+
+ assert_pending!(poll!(queue));
+ assert!(!queue.is_woken());
+
+ let start = Instant::now();
+ for elapsed in 0..1200 {
+ println!("elapsed: {:?}", elapsed);
+ let elapsed = elapsed + 1;
+ tokio::time::sleep_until(start + ms(elapsed)).await;
+
+ if delays.contains(&elapsed) {
+ assert!(queue.is_woken());
+ assert_ready!(poll!(queue));
+ assert_pending!(poll!(queue));
+ } else if queue.is_woken() {
+ let cascade = &[192, 960];
+ assert!(
+ cascade.contains(&elapsed),
+ "elapsed={} dt={:?}",
+ elapsed,
+ Instant::now() - start
+ );
+
+ assert_pending!(poll!(queue));
+ }
+ }
+ println!("finished multi_delay_start");
+}
+
+#[tokio::test]
+async fn insert_in_past_fires_immediately() {
+ println!("running insert_in_past_fires_immediately");
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let now = Instant::now();
+
+ sleep(ms(10)).await;
+
+ queue.insert_at("foo", now);
+
+ assert_ready!(poll!(queue));
+ println!("finished insert_in_past_fires_immediately");
+}
+
+#[tokio::test]
+async fn remove_entry() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let key = queue.insert_at("foo", Instant::now() + ms(5));
+
+ assert_pending!(poll!(queue));
+
+ let entry = queue.remove(&key);
+ assert_eq!(entry.into_inner(), "foo");
+
+ sleep(ms(10)).await;
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none());
+}
+
+#[tokio::test]
+async fn reset_entry() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+ let key = queue.insert_at("foo", now + ms(5));
+
+ assert_pending!(poll!(queue));
+ sleep(ms(1)).await;
+
+ queue.reset_at(&key, now + ms(10));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(7)).await;
+
+ assert!(!queue.is_woken());
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(3)).await;
+
+ assert!(queue.is_woken());
+
+ let entry = assert_ready_some!(poll!(queue));
+ assert_eq!(*entry.get_ref(), "foo");
+
+ let entry = assert_ready!(poll!(queue));
+ assert!(entry.is_none())
+}
+
+// Reproduces tokio-rs/tokio#849.
+#[tokio::test]
+async fn reset_much_later() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+
+ let now = Instant::now();
+ sleep(ms(1)).await;
+
+ let key = queue.insert_at("foo", now + ms(200));
+ assert_pending!(poll!(queue));
+
+ sleep(ms(3)).await;
+
+ queue.reset_at(&key, now + ms(10));
+
+ sleep(ms(20)).await;
+
+ assert!(queue.is_woken());
+}
+
+// Reproduces tokio-rs/tokio#849.
+#[tokio::test]
+async fn reset_twice() {
+ time::pause();
+
+ let mut queue = task::spawn(DelayQueue::new());
+ let now = Instant::now();
+
+ sleep(ms(1)).await;
+
+ let key = queue.insert_at("foo", now + ms(200));
+
+ assert_pending!(poll!(queue));
+
+ sleep(ms(3)).await;
+
+ queue.reset_at(&key, now + ms(50));
+
+ sleep(ms(20)).await;
+
+ queue.reset_at(&key, now + ms(40));
+
+ sleep(ms(20)).await;
+
+ assert!(queue.is_woken());
+}
+
+/// Regression test: Given an entry inserted with a deadline in the past, so
+/// that it is placed directly on the expired queue, reset the entry to a
+/// deadline in the future. Validate that this leaves the entry and queue in an
+/// internally consistent state by running an additional reset on the entry
+/// before polling it to completion.
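+/// (The first reset should move the entry off the expired queue and into the
+/// wheel; the second reset then re-slots an entry that is already in the
+/// wheel.)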
+#[tokio::test] +async fn repeatedly_reset_entry_inserted_as_expired() { + time::pause(); + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + let key = queue.insert_at("foo", now - ms(100)); + + queue.reset_at(&key, now + ms(100)); + queue.reset_at(&key, now + ms(50)); + + assert_pending!(poll!(queue)); + + time::sleep_until(now + ms(60)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[tokio::test] +async fn remove_expired_item() { + time::pause(); + + let mut queue = DelayQueue::new(); + + let now = Instant::now(); + + sleep(ms(10)).await; + + let key = queue.insert_at("foo", now); + + let entry = queue.remove(&key); + assert_eq!(entry.into_inner(), "foo"); +} + +/// Regression test: it should be possible to remove entries which fall in the +/// 0th slot of the internal timer wheel — that is, entries whose expiration +/// (a) falls at the beginning of one of the wheel's hierarchical levels and (b) +/// is equal to the wheel's current elapsed time. +#[tokio::test] +async fn remove_at_timer_wheel_threshold() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key1 = queue.insert_at("foo", now + ms(64)); + let key2 = queue.insert_at("bar", now + ms(64)); + + sleep(ms(80)).await; + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + + match entry { + "foo" => { + let entry = queue.remove(&key2).into_inner(); + assert_eq!(entry, "bar"); + } + "bar" => { + let entry = queue.remove(&key1).into_inner(); + assert_eq!(entry, "foo"); + } + other => panic!("other: {:?}", other), + } +} + +#[tokio::test] +async fn expires_before_last_insert() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(10_000)); + + // Delay should be set to 8.192s here. 
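+ // (10_000 ms falls in the wheel level whose slots are 4096 ms wide, so the
+ // next wake-up is expected at the slot boundary, 8192 ms, after which the
+ // entry cascades down to finer-grained levels.)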
+ assert_pending!(poll!(queue)); + + // Delay should be set to the delay of the new item here + queue.insert_at("bar", now + ms(600)); + + assert_pending!(poll!(queue)); + + sleep(ms(600)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + +#[tokio::test] +async fn multi_reset() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + let two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(300)); + queue.reset_at(&two, now + ms(350)); + queue.reset_at(&one, now + ms(400)); + + sleep(ms(310)).await; + + assert_pending!(poll!(queue)); + + sleep(ms(50)).await; + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "two"); + + assert_pending!(poll!(queue)); + + sleep(ms(50)).await; + + let entry = assert_ready_some!(poll!(queue)); + assert_eq!(*entry.get_ref(), "one"); + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()) +} + +#[tokio::test] +async fn expire_first_key_when_reset_to_expire_earlier() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(100)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "one"); +} + +#[tokio::test] +async fn expire_second_key_when_reset_to_expire_earlier() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("one", now + ms(200)); + let two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&two, now + ms(100)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn reset_first_expiring_item_to_expire_later() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let one = queue.insert_at("one", now + ms(200)); + let _two = queue.insert_at("two", now + ms(250)); + + assert_pending!(poll!(queue)); + + queue.reset_at(&one, now + ms(300)); + sleep(ms(250)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn insert_before_first_after_poll() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let _one = queue.insert_at("one", now + ms(200)); + + assert_pending!(poll!(queue)); + + let _two = queue.insert_at("two", now + ms(100)); + + sleep(ms(99)).await; + + assert_pending!(poll!(queue)); + + sleep(ms(1)).await; + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "two"); +} + +#[tokio::test] +async fn insert_after_ready_poll() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("1", now + ms(100)); + queue.insert_at("2", now + ms(100)); + queue.insert_at("3", now + ms(100)); + + assert_pending!(poll!(queue)); + + sleep(ms(100)).await; + + assert!(queue.is_woken()); + + let mut res = vec![]; + + while res.len() < 3 { + 
let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + queue.insert_at("foo", now + ms(500)); + } + + res.sort_unstable(); + + assert_eq!("1", res[0]); + assert_eq!("2", res[1]); + assert_eq!("3", res[2]); +} + +#[tokio::test] +async fn reset_later_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(100)); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. + queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(119)).await; + assert!(!queue.is_woken()); + + sleep(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn reset_inserted_expired() { + time::pause(); + let mut queue = task::spawn(DelayQueue::new()); + let now = Instant::now(); + + let key = queue.insert_at("foo", now - ms(100)); + + // this causes the panic described in #2473 + queue.reset_at(&key, now + ms(100)); + + assert_eq!(1, queue.len()); + + sleep(ms(200)).await; + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); + + assert_eq!(queue.len(), 0); +} + +#[tokio::test] +async fn reset_earlier_after_slot_starts() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let foo = queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(80)).await; + + assert!(!queue.is_woken()); + + // At this point the queue hasn't been polled, so `elapsed` on the wheel + // for the queue is still at 0 and hence the 1ms resolution slots cover + // [0-64). Resetting the time on the entry to 120 causes it to get put in + // the [64-128) slot. As the queue knows that the first entry is within + // that slot, but doesn't know when, it must wake immediately to advance + // the wheel. 
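+ // (Unlike `reset_later_after_slot_starts` above, the deadline here moves
+ // earlier, from 200 ms down to 120 ms; the wheel still has to wake up to
+ // re-slot the entry.)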
+ queue.reset_at(&foo, now + ms(120)); + assert!(queue.is_woken()); + + assert_pending!(poll!(queue)); + + sleep_until(now + Duration::from_millis(119)).await; + assert!(!queue.is_woken()); + + sleep(ms(1)).await; + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "foo"); +} + +#[tokio::test] +async fn insert_in_past_after_poll_fires_immediately() { + time::pause(); + + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo", now + ms(200)); + + assert_pending!(poll!(queue)); + + sleep(ms(80)).await; + + assert!(!queue.is_woken()); + queue.insert_at("bar", now + ms(40)); + + assert!(queue.is_woken()); + + let entry = assert_ready_some!(poll!(queue)).into_inner(); + assert_eq!(entry, "bar"); +} + +#[tokio::test] +async fn delay_queue_poll_expired_when_empty() { + let mut delay_queue = task::spawn(DelayQueue::new()); + let key = delay_queue.insert(0, std::time::Duration::from_secs(10)); + assert_pending!(poll!(delay_queue)); + + delay_queue.remove(&key); + assert!(assert_ready!(poll!(delay_queue)).is_none()); +} + +#[tokio::test(start_paused = true)] +async fn compact_expire_empty() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + queue.compact(); + + assert_eq!(queue.len(), 0); + assert_eq!(queue.capacity(), 0); +} + +#[tokio::test(start_paused = true)] +async fn compact_remove_empty() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + let key1 = queue.insert_at("foo1", now + ms(10)); + let key2 = queue.insert_at("foo2", now + ms(10)); + + queue.remove(&key1); + queue.remove(&key2); + + queue.compact(); + + assert_eq!(queue.len(), 0); + assert_eq!(queue.capacity(), 0); +} + +#[tokio::test(start_paused = true)] +// Trigger a re-mapping of keys in the slab due to a `compact` call and +// test removal of re-mapped keys +async fn compact_remove_remapped_keys() { + let mut queue = task::spawn(DelayQueue::new()); + + let now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + // should be assigned indices 3 and 4 + let key3 = queue.insert_at("foo3", now + ms(20)); + let key4 = queue.insert_at("foo4", now + ms(20)); + + sleep(ms(10)).await; + + let mut res = vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + // items corresponding to `foo3` and `foo4` will be assigned + // new indices here + queue.compact(); + + queue.insert_at("foo5", now + ms(10)); + + // test removal of re-mapped keys + let expired3 = queue.remove(&key3); + let expired4 = queue.remove(&key4); + + assert_eq!(expired3.into_inner(), "foo3"); + assert_eq!(expired4.into_inner(), "foo4"); + + queue.compact(); + assert_eq!(queue.len(), 1); + assert_eq!(queue.capacity(), 1); +} + +#[tokio::test(start_paused = true)] +async fn compact_change_deadline() { + let mut queue = task::spawn(DelayQueue::new()); + + let mut now = Instant::now(); + + queue.insert_at("foo1", now + ms(10)); + queue.insert_at("foo2", now + ms(10)); + + // should be assigned indices 3 and 4 + queue.insert_at("foo3", now + ms(20)); + let key4 = queue.insert_at("foo4", now + ms(20)); + + sleep(ms(10)).await; + + let mut res = 
vec![]; + while res.len() < 2 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + // items corresponding to `foo3` and `foo4` should be assigned + // new indices + queue.compact(); + + now = Instant::now(); + + queue.insert_at("foo5", now + ms(10)); + let key6 = queue.insert_at("foo6", now + ms(10)); + + queue.reset_at(&key4, now + ms(20)); + queue.reset_at(&key6, now + ms(20)); + + // foo3 and foo5 will expire + sleep(ms(10)).await; + + while res.len() < 4 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + sleep(ms(10)).await; + + while res.len() < 6 { + let entry = assert_ready_some!(poll!(queue)); + res.push(entry.into_inner()); + } + + let entry = assert_ready!(poll!(queue)); + assert!(entry.is_none()); +} + +#[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")] +#[tokio::test(start_paused = true)] +async fn remove_after_compact() { + let now = Instant::now(); + let mut queue = DelayQueue::new(); + + let foo_key = queue.insert_at("foo", now + ms(10)); + queue.insert_at("bar", now + ms(20)); + queue.remove(&foo_key); + queue.compact(); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.remove(&foo_key); + })); + assert!(panic.is_err()); +} + +#[cfg_attr(target_os = "wasi", ignore = "FIXME: Does not seem to work with WASI")] +#[tokio::test(start_paused = true)] +async fn remove_after_compact_poll() { + let now = Instant::now(); + let mut queue = task::spawn(DelayQueue::new()); + + let foo_key = queue.insert_at("foo", now + ms(10)); + queue.insert_at("bar", now + ms(20)); + + sleep(ms(10)).await; + assert_eq!(assert_ready_some!(poll!(queue)).key(), foo_key); + + queue.compact(); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + queue.remove(&foo_key); + })); + assert!(panic.is_err()); +} + +fn ms(n: u64) -> Duration { + Duration::from_millis(n) +} diff --git a/tests/udp.rs b/tests/udp.rs new file mode 100644 index 0000000..1b99806 --- /dev/null +++ b/tests/udp.rs @@ -0,0 +1,133 @@ +#![warn(rust_2018_idioms)] +#![cfg(not(target_os = "wasi"))] // Wasi doesn't support UDP + +use tokio::net::UdpSocket; +use tokio_stream::StreamExt; +use tokio_util::codec::{Decoder, Encoder, LinesCodec}; +use tokio_util::udp::UdpFramed; + +use bytes::{BufMut, BytesMut}; +use futures::future::try_join; +use futures::future::FutureExt; +use futures::sink::SinkExt; +use std::io; +use std::sync::Arc; + +#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))] +#[tokio::test] +async fn send_framed_byte_codec() -> std::io::Result<()> { + let mut a_soc = UdpSocket::bind("127.0.0.1:0").await?; + let mut b_soc = UdpSocket::bind("127.0.0.1:0").await?; + + let a_addr = a_soc.local_addr()?; + let b_addr = b_soc.local_addr()?; + + // test sending & receiving bytes + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b"4567"; + + let send = a.send((msg, b_addr)); + let recv = b.next().map(|e| e.unwrap()); + let (_, received) = try_join(send, recv).await.unwrap(); + + let (data, addr) = received; + assert_eq!(msg, &*data); + assert_eq!(a_addr, addr); + + a_soc = a.into_inner(); + b_soc = b.into_inner(); + } + + #[cfg(not(any(target_os = "macos", target_os = "ios")))] + // test sending & receiving an empty message + { + let mut a = UdpFramed::new(a_soc, ByteCodec); + let mut b = UdpFramed::new(b_soc, ByteCodec); + + let msg = b""; + + let send = a.send((msg, b_addr)); + let 
recv = b.next().map(|e| e.unwrap());
+ let (_, received) = try_join(send, recv).await.unwrap();
+
+ let (data, addr) = received;
+ assert_eq!(msg, &*data);
+ assert_eq!(a_addr, addr);
+ }
+
+ Ok(())
+}
+
+pub struct ByteCodec;
+
+impl Decoder for ByteCodec {
+ type Item = Vec<u8>;
+ type Error = io::Error;
+
+ fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Vec<u8>>, io::Error> {
+ let len = buf.len();
+ Ok(Some(buf.split_to(len).to_vec()))
+ }
+}
+
+impl Encoder<&[u8]> for ByteCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
+ buf.reserve(data.len());
+ buf.put_slice(data);
+ Ok(())
+ }
+}
+
+#[tokio::test]
+async fn send_framed_lines_codec() -> std::io::Result<()> {
+ let a_soc = UdpSocket::bind("127.0.0.1:0").await?;
+ let b_soc = UdpSocket::bind("127.0.0.1:0").await?;
+
+ let a_addr = a_soc.local_addr()?;
+ let b_addr = b_soc.local_addr()?;
+
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, LinesCodec::new());
+
+ let msg = b"1\r\n2\r\n3\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn framed_half() -> std::io::Result<()> {
+ let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
+ let b_soc = a_soc.clone();
+
+ let a_addr = a_soc.local_addr()?;
+ let b_addr = b_soc.local_addr()?;
+
+ let mut a = UdpFramed::new(a_soc, ByteCodec);
+ let mut b = UdpFramed::new(b_soc, LinesCodec::new());
+
+ let msg = b"1\r\n2\r\n3\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ let msg = b"4\r\n5\r\n6\r\n".to_vec();
+ a.send((&msg, b_addr)).await?;
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));
+
+ assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr));
+ assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr));
+
+ Ok(())
+}
-- 
cgit v1.2.3