aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoel Galenson <jgalenson@google.com>2020-10-23 09:39:31 -0700
committerJoel Galenson <jgalenson@google.com>2020-10-23 09:52:09 -0700
commitd5495b03381a3ebe0805db353d198b285b535b5c (patch)
tree778b8524d15fca8b73db0253ee0e1919d0848bb6 /src
parentba45c5bedf31df8562364c61d3dfb5262f10642e (diff)
downloadtokio-d5495b03381a3ebe0805db353d198b285b535b5c.tar.gz
Update to tokio-0.3.1 and add new features
Test: Build Change-Id: I5b5b9b386a21982a019653d0cf0bd3afc505cfac
Diffstat (limited to 'src')
-rw-r--r--src/blocking.rs48
-rw-r--r--src/coop.rs21
-rw-r--r--src/fs/file.rs296
-rw-r--r--src/fs/mod.rs4
-rw-r--r--src/fs/open_options.rs3
-rw-r--r--src/fs/os/unix/dir_builder_ext.rs9
-rw-r--r--src/fs/os/unix/dir_entry_ext.rs44
-rw-r--r--src/fs/os/unix/mod.rs3
-rw-r--r--src/fs/os/unix/open_options_ext.rs12
-rw-r--r--src/fs/os/windows/mod.rs3
-rw-r--r--src/fs/os/windows/open_options_ext.rs214
-rw-r--r--src/fs/read_dir.rs14
-rw-r--r--src/future/block_on.rs15
-rw-r--r--src/future/mod.rs23
-rw-r--r--src/future/pending.rs44
-rw-r--r--src/future/poll_fn.rs2
-rw-r--r--src/future/try_join.rs2
-rw-r--r--src/io/async_read.rs162
-rw-r--r--src/io/async_seek.rs45
-rw-r--r--src/io/async_write.rs22
-rw-r--r--src/io/blocking.rs24
-rw-r--r--src/io/driver/mod.rs351
-rw-r--r--src/io/driver/ready.rs187
-rw-r--r--src/io/driver/scheduled_io.rs501
-rw-r--r--src/io/mod.rs38
-rw-r--r--src/io/poll_evented.rs337
-rw-r--r--src/io/read_buf.rs261
-rw-r--r--src/io/registration.rs286
-rw-r--r--src/io/seek.rs55
-rw-r--r--src/io/split.rs25
-rw-r--r--src/io/stderr.rs7
-rw-r--r--src/io/stdin.rs11
-rw-r--r--src/io/stdio_common.rs220
-rw-r--r--src/io/stdout.rs6
-rw-r--r--src/io/util/async_buf_read_ext.rs2
-rw-r--r--src/io/util/async_read_ext.rs14
-rw-r--r--src/io/util/async_seek_ext.rs108
-rw-r--r--src/io/util/async_write_ext.rs5
-rw-r--r--src/io/util/buf_reader.rs59
-rw-r--r--src/io/util/buf_stream.rs12
-rw-r--r--src/io/util/buf_writer.rs12
-rw-r--r--src/io/util/chain.rs24
-rw-r--r--src/io/util/copy.rs54
-rw-r--r--src/io/util/copy_buf.rs102
-rw-r--r--src/io/util/empty.rs11
-rw-r--r--src/io/util/flush.rs28
-rw-r--r--src/io/util/lines.rs3
-rw-r--r--src/io/util/mem.rs223
-rw-r--r--src/io/util/mod.rs13
-rw-r--r--src/io/util/read.rs34
-rw-r--r--src/io/util/read_buf.rs44
-rw-r--r--src/io/util/read_exact.rs46
-rw-r--r--src/io/util/read_int.rs48
-rw-r--r--src/io/util/read_line.rs86
-rw-r--r--src/io/util/read_to_end.rs153
-rw-r--r--src/io/util/read_to_string.rs96
-rw-r--r--src/io/util/read_until.rs34
-rw-r--r--src/io/util/repeat.rs16
-rw-r--r--src/io/util/shutdown.rs30
-rw-r--r--src/io/util/split.rs3
-rw-r--r--src/io/util/stream_reader.rs184
-rw-r--r--src/io/util/take.rs26
-rw-r--r--src/io/util/write.rs17
-rw-r--r--src/io/util/write_all.rs34
-rw-r--r--src/io/util/write_buf.rs25
-rw-r--r--src/io/util/write_int.rs18
-rw-r--r--src/lib.rs184
-rw-r--r--src/loom/mocked.rs27
-rw-r--r--src/loom/mod.rs2
-rw-r--r--src/loom/std/atomic_ptr.rs8
-rw-r--r--src/loom/std/atomic_u16.rs2
-rw-r--r--src/loom/std/atomic_u32.rs2
-rw-r--r--src/loom/std/atomic_u8.rs2
-rw-r--r--src/loom/std/atomic_usize.rs2
-rw-r--r--src/loom/std/mod.rs19
-rw-r--r--src/loom/std/mutex.rs31
-rw-r--r--src/loom/std/parking_lot.rs20
-rw-r--r--src/loom/std/unsafe_cell.rs2
-rw-r--r--src/macros/cfg.rs250
-rw-r--r--src/macros/mod.rs2
-rw-r--r--src/macros/scoped_tls.rs2
-rw-r--r--src/macros/select.rs30
-rw-r--r--src/macros/support.rs3
-rw-r--r--src/net/addr.rs112
-rw-r--r--src/net/lookup_host.rs6
-rw-r--r--src/net/mod.rs13
-rw-r--r--src/net/tcp/incoming.rs42
-rw-r--r--src/net/tcp/listener.rs192
-rw-r--r--src/net/tcp/mod.rs4
-rw-r--r--src/net/tcp/socket.rs349
-rw-r--r--src/net/tcp/split.rs28
-rw-r--r--src/net/tcp/split_owned.rs43
-rw-r--r--src/net/tcp/stream.rs512
-rw-r--r--src/net/udp/mod.rs4
-rw-r--r--src/net/udp/socket.rs432
-rw-r--r--src/net/udp/split.rs148
-rw-r--r--src/net/unix/datagram.rs242
-rw-r--r--src/net/unix/datagram/mod.rs3
-rw-r--r--src/net/unix/datagram/socket.rs731
-rw-r--r--src/net/unix/incoming.rs42
-rw-r--r--src/net/unix/listener.rs143
-rw-r--r--src/net/unix/mod.rs12
-rw-r--r--src/net/unix/socketaddr.rs31
-rw-r--r--src/net/unix/split.rs36
-rw-r--r--src/net/unix/split_owned.rs182
-rw-r--r--src/net/unix/stream.rs111
-rw-r--r--src/net/unix/ucred.rs16
-rw-r--r--src/park/either.rs9
-rw-r--r--src/park/mod.rs16
-rw-r--r--src/park/thread.rs199
-rw-r--r--src/process/mod.rs318
-rw-r--r--src/process/unix/driver.rs156
-rw-r--r--src/process/unix/mod.rs67
-rw-r--r--src/process/unix/orphan.rs78
-rw-r--r--src/process/unix/reap.rs59
-rw-r--r--src/process/windows.rs36
-rw-r--r--src/runtime/basic_scheduler.rs190
-rw-r--r--src/runtime/blocking/mod.rs21
-rw-r--r--src/runtime/blocking/pool.rs72
-rw-r--r--src/runtime/blocking/shutdown.rs11
-rw-r--r--src/runtime/builder.rs350
-rw-r--r--src/runtime/context.rs50
-rw-r--r--src/runtime/driver.rs205
-rw-r--r--src/runtime/enter.rs210
-rw-r--r--src/runtime/handle.rs385
-rw-r--r--src/runtime/io.rs63
-rw-r--r--src/runtime/mod.rs733
-rw-r--r--src/runtime/park.rs26
-rw-r--r--src/runtime/queue.rs10
-rw-r--r--src/runtime/shell.rs106
-rw-r--r--src/runtime/spawner.rs27
-rw-r--r--src/runtime/task/core.rs2
-rw-r--r--src/runtime/task/error.rs28
-rw-r--r--src/runtime/task/harness.rs8
-rw-r--r--src/runtime/task/join.rs110
-rw-r--r--src/runtime/task/mod.rs34
-rw-r--r--src/runtime/tests/loom_blocking.rs10
-rw-r--r--src/runtime/tests/loom_pool.rs11
-rw-r--r--src/runtime/tests/task.rs4
-rw-r--r--src/runtime/thread_pool/atomic_cell.rs1
-rw-r--r--src/runtime/thread_pool/idle.rs8
-rw-r--r--src/runtime/thread_pool/mod.rs10
-rw-r--r--src/runtime/thread_pool/worker.rs173
-rw-r--r--src/runtime/time.rs59
-rw-r--r--src/signal/registry.rs6
-rw-r--r--src/signal/unix.rs142
-rw-r--r--src/signal/unix/driver.rs207
-rw-r--r--src/signal/windows.rs30
-rw-r--r--src/stream/all.rs34
-rw-r--r--src/stream/any.rs34
-rw-r--r--src/stream/collect.rs141
-rw-r--r--src/stream/fold.rs6
-rw-r--r--src/stream/mod.rs121
-rw-r--r--src/stream/next.rs29
-rw-r--r--src/stream/stream_map.rs56
-rw-r--r--src/stream/throttle.rs (renamed from src/time/throttle.rs)28
-rw-r--r--src/stream/timeout.rs6
-rw-r--r--src/stream/try_next.rs26
-rw-r--r--src/sync/barrier.rs6
-rw-r--r--src/sync/batch_semaphore.rs70
-rw-r--r--src/sync/broadcast.rs478
-rw-r--r--src/sync/cancellation_token.rs861
-rw-r--r--src/sync/mod.rs62
-rw-r--r--src/sync/mpsc/block.rs8
-rw-r--r--src/sync/mpsc/bounded.rs465
-rw-r--r--src/sync/mpsc/chan.rs289
-rw-r--r--src/sync/mpsc/error.rs20
-rw-r--r--src/sync/mpsc/list.rs6
-rw-r--r--src/sync/mpsc/mod.rs39
-rw-r--r--src/sync/mpsc/unbounded.rs97
-rw-r--r--src/sync/mutex.rs62
-rw-r--r--src/sync/notify.rs152
-rw-r--r--src/sync/oneshot.rs4
-rw-r--r--src/sync/rwlock.rs431
-rw-r--r--src/sync/semaphore.rs15
-rw-r--r--src/sync/semaphore_ll.rs1221
-rw-r--r--src/sync/task/atomic_waker.rs5
-rw-r--r--src/sync/tests/loom_broadcast.rs2
-rw-r--r--src/sync/tests/loom_cancellation_token.rs155
-rw-r--r--src/sync/tests/loom_mpsc.rs71
-rw-r--r--src/sync/tests/loom_notify.rs12
-rw-r--r--src/sync/tests/loom_oneshot.rs6
-rw-r--r--src/sync/tests/loom_semaphore_ll.rs192
-rw-r--r--src/sync/tests/loom_watch.rs36
-rw-r--r--src/sync/tests/mod.rs5
-rw-r--r--src/sync/tests/semaphore_ll.rs470
-rw-r--r--src/sync/watch.rs384
-rw-r--r--src/task/blocking.rs141
-rw-r--r--src/task/local.rs22
-rw-r--r--src/task/mod.rs25
-rw-r--r--src/task/spawn.rs6
-rw-r--r--src/task/yield_now.rs2
-rw-r--r--src/time/clock.rs40
-rw-r--r--src/time/delay_queue.rs887
-rw-r--r--src/time/driver/atomic_stack.rs6
-rw-r--r--src/time/driver/entry.rs56
-rw-r--r--src/time/driver/handle.rs61
-rw-r--r--src/time/driver/mod.rs82
-rw-r--r--src/time/driver/registration.rs56
-rw-r--r--src/time/driver/stack.rs121
-rw-r--r--src/time/error.rs80
-rw-r--r--src/time/instant.rs40
-rw-r--r--src/time/interval.rs25
-rw-r--r--src/time/mod.rs36
-rw-r--r--src/time/sleep.rs (renamed from src/time/delay.rs)91
-rw-r--r--src/time/tests/mod.rs12
-rw-r--r--src/time/tests/test_sleep.rs (renamed from src/time/tests/test_delay.rs)120
-rw-r--r--src/time/timeout.rs41
-rw-r--r--src/time/wheel/level.rs174
-rw-r--r--src/time/wheel/mod.rs109
-rw-r--r--src/time/wheel/stack.rs120
-rw-r--r--src/util/bit.rs26
-rw-r--r--src/util/intrusive_double_linked_list.rs788
-rw-r--r--src/util/linked_list.rs171
-rw-r--r--src/util/mod.rs22
-rw-r--r--src/util/slab.rs841
-rw-r--r--src/util/slab/addr.rs154
-rw-r--r--src/util/slab/entry.rs7
-rw-r--r--src/util/slab/generation.rs32
-rw-r--r--src/util/slab/mod.rs107
-rw-r--r--src/util/slab/page.rs187
-rw-r--r--src/util/slab/shard.rs105
-rw-r--r--src/util/slab/slot.rs42
-rw-r--r--src/util/slab/stack.rs58
-rw-r--r--src/util/slab/tests/loom_slab.rs327
-rw-r--r--src/util/slab/tests/loom_stack.rs88
-rw-r--r--src/util/slab/tests/mod.rs2
-rw-r--r--src/util/trace.rs4
228 files changed, 11322 insertions, 13218 deletions
diff --git a/src/blocking.rs b/src/blocking.rs
new file mode 100644
index 0000000..f88b1db
--- /dev/null
+++ b/src/blocking.rs
@@ -0,0 +1,48 @@
+cfg_rt! {
+ pub(crate) use crate::runtime::spawn_blocking;
+ pub(crate) use crate::task::JoinHandle;
+}
+
+cfg_not_rt! {
+ use std::fmt;
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::task::{Context, Poll};
+
+ pub(crate) fn spawn_blocking<F, R>(_f: F) -> JoinHandle<R>
+ where
+ F: FnOnce() -> R + Send + 'static,
+ R: Send + 'static,
+ {
+ assert_send_sync::<JoinHandle<std::cell::Cell<()>>>();
+ panic!("requires the `rt` Tokio feature flag")
+
+ }
+
+ pub(crate) struct JoinHandle<R> {
+ _p: std::marker::PhantomData<R>,
+ }
+
+ unsafe impl<T: Send> Send for JoinHandle<T> {}
+ unsafe impl<T: Send> Sync for JoinHandle<T> {}
+
+ impl<R> Future for JoinHandle<R> {
+ type Output = Result<R, std::io::Error>;
+
+ fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
+ unreachable!()
+ }
+ }
+
+ impl<T> fmt::Debug for JoinHandle<T>
+ where
+ T: fmt::Debug,
+ {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("JoinHandle").finish()
+ }
+ }
+
+ fn assert_send_sync<T: Send + Sync>() {
+ }
+}
diff --git a/src/coop.rs b/src/coop.rs
index 27e969c..980cdf8 100644
--- a/src/coop.rs
+++ b/src/coop.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "full"), allow(dead_code))]
+
//! Opt-in yield points for improved cooperative scheduling.
//!
//! A single call to [`poll`] on a top-level task may potentially do a lot of
@@ -81,7 +83,7 @@ impl Budget {
}
}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
impl Budget {
fn has_remaining(self) -> bool {
self.0.map(|budget| budget > 0).unwrap_or(true)
@@ -96,14 +98,6 @@ pub(crate) fn budget<R>(f: impl FnOnce() -> R) -> R {
with_budget(Budget::initial(), f)
}
-cfg_rt_threaded! {
- /// Set the current task's budget
- #[cfg(feature = "blocking")]
- pub(crate) fn set(budget: Budget) {
- CURRENT.with(|cell| cell.set(budget))
- }
-}
-
#[inline(always)]
fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R {
struct ResetGuard<'a> {
@@ -128,14 +122,19 @@ fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R {
})
}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
+ /// Set the current task's budget
+ pub(crate) fn set(budget: Budget) {
+ CURRENT.with(|cell| cell.set(budget))
+ }
+
#[inline(always)]
pub(crate) fn has_budget_remaining() -> bool {
CURRENT.with(|cell| cell.get().has_remaining())
}
}
-cfg_blocking_impl! {
+cfg_rt! {
/// Forcibly remove the budgeting constraints early.
///
/// Returns the remaining budget
diff --git a/src/fs/file.rs b/src/fs/file.rs
index f3bc985..7c71f48 100644
--- a/src/fs/file.rs
+++ b/src/fs/file.rs
@@ -5,7 +5,8 @@
use self::State::*;
use crate::fs::{asyncify, sys};
use crate::io::blocking::Buf;
-use crate::io::{AsyncRead, AsyncSeek, AsyncWrite};
+use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
+use crate::sync::Mutex;
use std::fmt;
use std::fs::{Metadata, Permissions};
@@ -80,12 +81,18 @@ use std::task::Poll::*;
/// ```
pub struct File {
std: Arc<sys::File>,
+ inner: Mutex<Inner>,
+}
+
+struct Inner {
state: State,
/// Errors from writes/flushes are returned in write/flush calls. If a write
/// error is observed while performing a read, it is saved until the next
/// write / flush call.
last_write_err: Option<io::ErrorKind>,
+
+ pos: u64,
}
#[derive(Debug)]
@@ -197,70 +204,11 @@ impl File {
pub fn from_std(std: sys::File) -> File {
File {
std: Arc::new(std),
- state: State::Idle(Some(Buf::with_capacity(0))),
- last_write_err: None,
- }
- }
-
- /// Seeks to an offset, in bytes, in a stream.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::fs::File;
- /// use tokio::prelude::*;
- ///
- /// use std::io::SeekFrom;
- ///
- /// # async fn dox() -> std::io::Result<()> {
- /// let mut file = File::open("foo.txt").await?;
- /// file.seek(SeekFrom::Start(6)).await?;
- ///
- /// let mut contents = vec![0u8; 10];
- /// file.read_exact(&mut contents).await?;
- /// # Ok(())
- /// # }
- /// ```
- ///
- /// The [`read_exact`] method is defined on the [`AsyncReadExt`] trait.
- ///
- /// [`read_exact`]: fn@crate::io::AsyncReadExt::read_exact
- /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
- pub async fn seek(&mut self, mut pos: SeekFrom) -> io::Result<u64> {
- self.complete_inflight().await;
-
- let mut buf = match self.state {
- Idle(ref mut buf_cell) => buf_cell.take().unwrap(),
- _ => unreachable!(),
- };
-
- // Factor in any unread data from the buf
- if !buf.is_empty() {
- let n = buf.discard_read();
-
- if let SeekFrom::Current(ref mut offset) = pos {
- *offset += n;
- }
- }
-
- let std = self.std.clone();
-
- // Start the operation
- self.state = Busy(sys::run(move || {
- let res = (&*std).seek(pos);
- (Operation::Seek(res), buf)
- }));
-
- let (op, buf) = match self.state {
- Idle(_) => unreachable!(),
- Busy(ref mut rx) => rx.await.unwrap(),
- };
-
- self.state = Idle(Some(buf));
-
- match op {
- Operation::Seek(res) => res,
- _ => unreachable!(),
+ inner: Mutex::new(Inner {
+ state: State::Idle(Some(Buf::with_capacity(0))),
+ last_write_err: None,
+ pos: 0,
+ }),
}
}
@@ -287,8 +235,9 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
- pub async fn sync_all(&mut self) -> io::Result<()> {
- self.complete_inflight().await;
+ pub async fn sync_all(&self) -> io::Result<()> {
+ let mut inner = self.inner.lock().await;
+ inner.complete_inflight().await;
let std = self.std.clone();
asyncify(move || std.sync_all()).await
@@ -321,8 +270,9 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
- pub async fn sync_data(&mut self) -> io::Result<()> {
- self.complete_inflight().await;
+ pub async fn sync_data(&self) -> io::Result<()> {
+ let mut inner = self.inner.lock().await;
+ inner.complete_inflight().await;
let std = self.std.clone();
asyncify(move || std.sync_data()).await
@@ -358,10 +308,11 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
- pub async fn set_len(&mut self, size: u64) -> io::Result<()> {
- self.complete_inflight().await;
+ pub async fn set_len(&self, size: u64) -> io::Result<()> {
+ let mut inner = self.inner.lock().await;
+ inner.complete_inflight().await;
- let mut buf = match self.state {
+ let mut buf = match inner.state {
Idle(ref mut buf_cell) => buf_cell.take().unwrap(),
_ => unreachable!(),
};
@@ -374,7 +325,7 @@ impl File {
let std = self.std.clone();
- self.state = Busy(sys::run(move || {
+ inner.state = Busy(sys::run(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| std.set_len(size))
} else {
@@ -386,15 +337,17 @@ impl File {
(Operation::Seek(res), buf)
}));
- let (op, buf) = match self.state {
+ let (op, buf) = match inner.state {
Idle(_) => unreachable!(),
Busy(ref mut rx) => rx.await?,
};
- self.state = Idle(Some(buf));
+ inner.state = Idle(Some(buf));
match op {
- Operation::Seek(res) => res.map(|_| ()),
+ Operation::Seek(res) => res.map(|pos| {
+ inner.pos = pos;
+ }),
_ => unreachable!(),
}
}
@@ -459,7 +412,7 @@ impl File {
/// # }
/// ```
pub async fn into_std(mut self) -> sys::File {
- self.complete_inflight().await;
+ self.inner.get_mut().complete_inflight().await;
Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed")
}
@@ -526,42 +479,32 @@ impl File {
let std = self.std.clone();
asyncify(move || std.set_permissions(perm)).await
}
-
- async fn complete_inflight(&mut self) {
- use crate::future::poll_fn;
-
- if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await {
- self.last_write_err = Some(e.kind());
- }
- }
}
impl AsyncRead for File {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/fs.rs#L668
- false
- }
-
fn poll_read(
- mut self: Pin<&mut Self>,
+ self: Pin<&mut Self>,
cx: &mut Context<'_>,
- dst: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let me = self.get_mut();
+ let inner = me.inner.get_mut();
+
loop {
- match self.state {
+ match inner.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
if !buf.is_empty() {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
*buf_cell = Some(buf);
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
buf.ensure_capacity_for(dst);
- let std = self.std.clone();
+ let std = me.std.clone();
- self.state = Busy(sys::run(move || {
+ inner.state = Busy(sys::run(move || {
let res = buf.read_from(&mut &*std);
(Operation::Read(res), buf)
}));
@@ -571,29 +514,32 @@ impl AsyncRead for File {
match op {
Operation::Read(Ok(_)) => {
- let n = buf.copy_to(dst);
- self.state = Idle(Some(buf));
- return Ready(Ok(n));
+ buf.copy_to(dst);
+ inner.state = Idle(Some(buf));
+ return Ready(Ok(()));
}
Operation::Read(Err(e)) => {
assert!(buf.is_empty());
- self.state = Idle(Some(buf));
+ inner.state = Idle(Some(buf));
return Ready(Err(e));
}
Operation::Write(Ok(_)) => {
assert!(buf.is_empty());
- self.state = Idle(Some(buf));
+ inner.state = Idle(Some(buf));
continue;
}
Operation::Write(Err(e)) => {
- assert!(self.last_write_err.is_none());
- self.last_write_err = Some(e.kind());
- self.state = Idle(Some(buf));
+ assert!(inner.last_write_err.is_none());
+ inner.last_write_err = Some(e.kind());
+ inner.state = Idle(Some(buf));
}
- Operation::Seek(_) => {
+ Operation::Seek(result) => {
assert!(buf.is_empty());
- self.state = Idle(Some(buf));
+ inner.state = Idle(Some(buf));
+ if let Ok(pos) = result {
+ inner.pos = pos;
+ }
continue;
}
}
@@ -604,13 +550,13 @@ impl AsyncRead for File {
}
impl AsyncSeek for File {
- fn start_seek(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- mut pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
+ fn start_seek(self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> {
+ let me = self.get_mut();
+ let inner = me.inner.get_mut();
+
loop {
- match self.state {
+ match inner.state {
+ Busy(_) => panic!("must wait for poll_complete before calling start_seek"),
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
@@ -623,49 +569,41 @@ impl AsyncSeek for File {
}
}
- let std = self.std.clone();
+ let std = me.std.clone();
- self.state = Busy(sys::run(move || {
+ inner.state = Busy(sys::run(move || {
let res = (&*std).seek(pos);
(Operation::Seek(res), buf)
}));
-
- return Ready(Ok(()));
- }
- Busy(ref mut rx) => {
- let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
- self.state = Idle(Some(buf));
-
- match op {
- Operation::Read(_) => {}
- Operation::Write(Err(e)) => {
- assert!(self.last_write_err.is_none());
- self.last_write_err = Some(e.kind());
- }
- Operation::Write(_) => {}
- Operation::Seek(_) => {}
- }
+ return Ok(());
}
}
}
}
fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
+ let inner = self.inner.get_mut();
+
loop {
- match self.state {
- Idle(_) => panic!("must call start_seek before calling poll_complete"),
+ match inner.state {
+ Idle(_) => return Poll::Ready(Ok(inner.pos)),
Busy(ref mut rx) => {
let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
- self.state = Idle(Some(buf));
+ inner.state = Idle(Some(buf));
match op {
Operation::Read(_) => {}
Operation::Write(Err(e)) => {
- assert!(self.last_write_err.is_none());
- self.last_write_err = Some(e.kind());
+ assert!(inner.last_write_err.is_none());
+ inner.last_write_err = Some(e.kind());
}
Operation::Write(_) => {}
- Operation::Seek(res) => return Ready(res),
+ Operation::Seek(res) => {
+ if let Ok(pos) = res {
+ inner.pos = pos;
+ }
+ return Ready(res);
+ }
}
}
}
@@ -675,16 +613,19 @@ impl AsyncSeek for File {
impl AsyncWrite for File {
fn poll_write(
- mut self: Pin<&mut Self>,
+ self: Pin<&mut Self>,
cx: &mut Context<'_>,
src: &[u8],
) -> Poll<io::Result<usize>> {
- if let Some(e) = self.last_write_err.take() {
+ let me = self.get_mut();
+ let inner = me.inner.get_mut();
+
+ if let Some(e) = inner.last_write_err.take() {
return Ready(Err(e.into()));
}
loop {
- match self.state {
+ match inner.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
@@ -695,9 +636,9 @@ impl AsyncWrite for File {
};
let n = buf.copy_from(src);
- let std = self.std.clone();
+ let std = me.std.clone();
- self.state = Busy(sys::run(move || {
+ inner.state = Busy(sys::run(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
@@ -711,7 +652,7 @@ impl AsyncWrite for File {
}
Busy(ref mut rx) => {
let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
- self.state = Idle(Some(buf));
+ inner.state = Idle(Some(buf));
match op {
Operation::Read(_) => {
@@ -737,27 +678,12 @@ impl AsyncWrite for File {
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
- if let Some(e) = self.last_write_err.take() {
- return Ready(Err(e.into()));
- }
-
- let (op, buf) = match self.state {
- Idle(_) => return Ready(Ok(())),
- Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?,
- };
-
- // The buffer is not used here
- self.state = Idle(Some(buf));
-
- match op {
- Operation::Read(_) => Ready(Ok(())),
- Operation::Write(res) => Ready(res),
- Operation::Seek(_) => Ready(Ok(())),
- }
+ let inner = self.inner.get_mut();
+ inner.poll_flush(cx)
}
- fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
- Poll::Ready(Ok(()))
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ self.poll_flush(cx)
}
}
@@ -782,9 +708,53 @@ impl std::os::unix::io::AsRawFd for File {
}
}
+#[cfg(unix)]
+impl std::os::unix::io::FromRawFd for File {
+ unsafe fn from_raw_fd(fd: std::os::unix::io::RawFd) -> Self {
+ sys::File::from_raw_fd(fd).into()
+ }
+}
+
#[cfg(windows)]
impl std::os::windows::io::AsRawHandle for File {
fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
self.std.as_raw_handle()
}
}
+
+#[cfg(windows)]
+impl std::os::windows::io::FromRawHandle for File {
+ unsafe fn from_raw_handle(handle: std::os::windows::io::RawHandle) -> Self {
+ sys::File::from_raw_handle(handle).into()
+ }
+}
+
+impl Inner {
+ async fn complete_inflight(&mut self) {
+ use crate::future::poll_fn;
+
+ if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await {
+ self.last_write_err = Some(e.kind());
+ }
+ }
+
+ fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ if let Some(e) = self.last_write_err.take() {
+ return Ready(Err(e.into()));
+ }
+
+ let (op, buf) = match self.state {
+ Idle(_) => return Ready(Ok(())),
+ Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?,
+ };
+
+ // The buffer is not used here
+ self.state = Idle(Some(buf));
+
+ match op {
+ Operation::Read(_) => Ready(Ok(())),
+ Operation::Write(res) => Ready(res),
+ Operation::Seek(_) => Ready(Ok(())),
+ }
+ }
+}
diff --git a/src/fs/mod.rs b/src/fs/mod.rs
index a2b062b..d2757a5 100644
--- a/src/fs/mod.rs
+++ b/src/fs/mod.rs
@@ -107,6 +107,6 @@ mod sys {
pub(crate) use std::fs::File;
// TODO: don't rename
- pub(crate) use crate::runtime::spawn_blocking as run;
- pub(crate) use crate::task::JoinHandle as Blocking;
+ pub(crate) use crate::blocking::spawn_blocking as run;
+ pub(crate) use crate::blocking::JoinHandle as Blocking;
}
diff --git a/src/fs/open_options.rs b/src/fs/open_options.rs
index ba3d9a6..acd99a1 100644
--- a/src/fs/open_options.rs
+++ b/src/fs/open_options.rs
@@ -383,8 +383,7 @@ impl OpenOptions {
Ok(File::from_std(std))
}
- /// Returns a mutable reference to the the underlying std::fs::OpenOptions
- #[cfg(unix)]
+ /// Returns a mutable reference to the underlying `std::fs::OpenOptions`
pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions {
&mut self.0
}
diff --git a/src/fs/os/unix/dir_builder_ext.rs b/src/fs/os/unix/dir_builder_ext.rs
index e9a25b9..ccdc552 100644
--- a/src/fs/os/unix/dir_builder_ext.rs
+++ b/src/fs/os/unix/dir_builder_ext.rs
@@ -3,7 +3,7 @@ use crate::fs::dir_builder::DirBuilder;
/// Unix-specific extensions to [`DirBuilder`].
///
/// [`DirBuilder`]: crate::fs::DirBuilder
-pub trait DirBuilderExt {
+pub trait DirBuilderExt: sealed::Sealed {
/// Sets the mode to create new directories with.
///
/// This option defaults to 0o777.
@@ -27,3 +27,10 @@ impl DirBuilderExt for DirBuilder {
self
}
}
+
+impl sealed::Sealed for DirBuilder {}
+
+pub(crate) mod sealed {
+ #[doc(hidden)]
+ pub trait Sealed {}
+}
diff --git a/src/fs/os/unix/dir_entry_ext.rs b/src/fs/os/unix/dir_entry_ext.rs
new file mode 100644
index 0000000..2ac56da
--- /dev/null
+++ b/src/fs/os/unix/dir_entry_ext.rs
@@ -0,0 +1,44 @@
+use crate::fs::DirEntry;
+use std::os::unix::fs::DirEntryExt as _;
+
+/// Unix-specific extension methods for [`fs::DirEntry`].
+///
+/// This mirrors the definition of [`std::os::unix::fs::DirEntryExt`].
+///
+/// [`fs::DirEntry`]: crate::fs::DirEntry
+/// [`std::os::unix::fs::DirEntryExt`]: std::os::unix::fs::DirEntryExt
+pub trait DirEntryExt: sealed::Sealed {
+ /// Returns the underlying `d_ino` field in the contained `dirent`
+ /// structure.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::fs;
+ /// use tokio::fs::os::unix::DirEntryExt;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// let mut entries = fs::read_dir(".").await?;
+ /// while let Some(entry) = entries.next_entry().await? {
+ /// // Here, `entry` is a `DirEntry`.
+ /// println!("{:?}: {}", entry.file_name(), entry.ino());
+ /// }
+ /// # Ok(())
+ /// # }
+ /// ```
+ fn ino(&self) -> u64;
+}
+
+impl DirEntryExt for DirEntry {
+ fn ino(&self) -> u64 {
+ self.as_inner().ino()
+ }
+}
+
+impl sealed::Sealed for DirEntry {}
+
+pub(crate) mod sealed {
+ #[doc(hidden)]
+ pub trait Sealed {}
+}
diff --git a/src/fs/os/unix/mod.rs b/src/fs/os/unix/mod.rs
index 826222e..a0ae751 100644
--- a/src/fs/os/unix/mod.rs
+++ b/src/fs/os/unix/mod.rs
@@ -8,3 +8,6 @@ pub use self::open_options_ext::OpenOptionsExt;
mod dir_builder_ext;
pub use self::dir_builder_ext::DirBuilderExt;
+
+mod dir_entry_ext;
+pub use self::dir_entry_ext::DirEntryExt;
diff --git a/src/fs/os/unix/open_options_ext.rs b/src/fs/os/unix/open_options_ext.rs
index ff89275..6e0fd2b 100644
--- a/src/fs/os/unix/open_options_ext.rs
+++ b/src/fs/os/unix/open_options_ext.rs
@@ -1,14 +1,13 @@
use crate::fs::open_options::OpenOptions;
-use std::os::unix::fs::OpenOptionsExt as StdOpenOptionsExt;
+use std::os::unix::fs::OpenOptionsExt as _;
/// Unix-specific extensions to [`fs::OpenOptions`].
///
/// This mirrors the definition of [`std::os::unix::fs::OpenOptionsExt`].
///
-///
/// [`fs::OpenOptions`]: crate::fs::OpenOptions
/// [`std::os::unix::fs::OpenOptionsExt`]: std::os::unix::fs::OpenOptionsExt
-pub trait OpenOptionsExt {
+pub trait OpenOptionsExt: sealed::Sealed {
/// Sets the mode bits that a new file will be created with.
///
/// If a new file is created as part of an `OpenOptions::open` call then this
@@ -77,3 +76,10 @@ impl OpenOptionsExt for OpenOptions {
self
}
}
+
+impl sealed::Sealed for OpenOptions {}
+
+pub(crate) mod sealed {
+ #[doc(hidden)]
+ pub trait Sealed {}
+}
diff --git a/src/fs/os/windows/mod.rs b/src/fs/os/windows/mod.rs
index 42eb7bd..ab98c13 100644
--- a/src/fs/os/windows/mod.rs
+++ b/src/fs/os/windows/mod.rs
@@ -5,3 +5,6 @@ pub use self::symlink_dir::symlink_dir;
mod symlink_file;
pub use self::symlink_file::symlink_file;
+
+mod open_options_ext;
+pub use self::open_options_ext::OpenOptionsExt;
diff --git a/src/fs/os/windows/open_options_ext.rs b/src/fs/os/windows/open_options_ext.rs
new file mode 100644
index 0000000..ce86fba
--- /dev/null
+++ b/src/fs/os/windows/open_options_ext.rs
@@ -0,0 +1,214 @@
+use crate::fs::open_options::OpenOptions;
+use std::os::windows::fs::OpenOptionsExt as _;
+
+/// Unix-specific extensions to [`fs::OpenOptions`].
+///
+/// This mirrors the definition of [`std::os::windows::fs::OpenOptionsExt`].
+///
+/// [`fs::OpenOptions`]: crate::fs::OpenOptions
+/// [`std::os::windows::fs::OpenOptionsExt`]: std::os::windows::fs::OpenOptionsExt
+pub trait OpenOptionsExt: sealed::Sealed {
+ /// Overrides the `dwDesiredAccess` argument to the call to [`CreateFile`]
+ /// with the specified value.
+ ///
+ /// This will override the `read`, `write`, and `append` flags on the
+ /// `OpenOptions` structure. This method provides fine-grained control over
+ /// the permissions to read, write and append data, attributes (like hidden
+ /// and system), and extended attributes.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::fs::OpenOptions;
+ /// use tokio::fs::os::windows::OpenOptionsExt;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// // Open without read and write permission, for example if you only need
+ /// // to call `stat` on the file
+ /// let file = OpenOptions::new().access_mode(0).open("foo.txt").await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ fn access_mode(&mut self, access: u32) -> &mut Self;
+
+ /// Overrides the `dwShareMode` argument to the call to [`CreateFile`] with
+ /// the specified value.
+ ///
+ /// By default `share_mode` is set to
+ /// `FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE`. This allows
+ /// other processes to read, write, and delete/rename the same file
+ /// while it is open. Removing any of the flags will prevent other
+ /// processes from performing the corresponding operation until the file
+ /// handle is closed.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::fs::OpenOptions;
+ /// use tokio::fs::os::windows::OpenOptionsExt;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// // Do not allow others to read or modify this file while we have it open
+ /// // for writing.
+ /// let file = OpenOptions::new()
+ /// .write(true)
+ /// .share_mode(0)
+ /// .open("foo.txt").await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ fn share_mode(&mut self, val: u32) -> &mut Self;
+
+ /// Sets extra flags for the `dwFileFlags` argument to the call to
+ /// [`CreateFile2`] to the specified value (or combines it with
+ /// `attributes` and `security_qos_flags` to set the `dwFlagsAndAttributes`
+ /// for [`CreateFile`]).
+ ///
+ /// Custom flags can only set flags, not remove flags set by Rust's options.
+ /// This option overwrites any previously set custom flags.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use winapi::um::winbase::FILE_FLAG_DELETE_ON_CLOSE;
+ /// use tokio::fs::OpenOptions;
+ /// use tokio::fs::os::windows::OpenOptionsExt;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// let file = OpenOptions::new()
+ /// .create(true)
+ /// .write(true)
+ /// .custom_flags(FILE_FLAG_DELETE_ON_CLOSE)
+ /// .open("foo.txt").await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2
+ fn custom_flags(&mut self, flags: u32) -> &mut Self;
+
+ /// Sets the `dwFileAttributes` argument to the call to [`CreateFile2`] to
+ /// the specified value (or combines it with `custom_flags` and
+ /// `security_qos_flags` to set the `dwFlagsAndAttributes` for
+ /// [`CreateFile`]).
+ ///
+ /// If a _new_ file is created because it does not yet exist and
+ /// `.create(true)` or `.create_new(true)` are specified, the new file is
+ /// given the attributes declared with `.attributes()`.
+ ///
+ /// If an _existing_ file is opened with `.create(true).truncate(true)`, its
+ /// existing attributes are preserved and combined with the ones declared
+ /// with `.attributes()`.
+ ///
+ /// In all other cases the attributes get ignored.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use winapi::um::winnt::FILE_ATTRIBUTE_HIDDEN;
+ /// use tokio::fs::OpenOptions;
+ /// use tokio::fs::os::windows::OpenOptionsExt;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// let file = OpenOptions::new()
+ /// .write(true)
+ /// .create(true)
+ /// .attributes(FILE_ATTRIBUTE_HIDDEN)
+ /// .open("foo.txt").await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2
+ fn attributes(&mut self, val: u32) -> &mut Self;
+
+ /// Sets the `dwSecurityQosFlags` argument to the call to [`CreateFile2`] to
+ /// the specified value (or combines it with `custom_flags` and `attributes`
+ /// to set the `dwFlagsAndAttributes` for [`CreateFile`]).
+ ///
+ /// By default `security_qos_flags` is not set. It should be specified when
+ /// opening a named pipe, to control to which degree a server process can
+ /// act on behalf of a client process (security impersonation level).
+ ///
+ /// When `security_qos_flags` is not set, a malicious program can gain the
+ /// elevated privileges of a privileged Rust process when it allows opening
+ /// user-specified paths, by tricking it into opening a named pipe. So
+ /// arguably `security_qos_flags` should also be set when opening arbitrary
+ /// paths. However the bits can then conflict with other flags, specifically
+ /// `FILE_FLAG_OPEN_NO_RECALL`.
+ ///
+ /// For information about possible values, see [Impersonation Levels] on the
+ /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set
+ /// automatically when using this method.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use winapi::um::winbase::SECURITY_IDENTIFICATION;
+ /// use tokio::fs::OpenOptions;
+ /// use tokio::fs::os::windows::OpenOptionsExt;
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> std::io::Result<()> {
+ /// let file = OpenOptions::new()
+ /// .write(true)
+ /// .create(true)
+ ///
+ /// // Sets the flag value to `SecurityIdentification`.
+ /// .security_qos_flags(SECURITY_IDENTIFICATION)
+ ///
+ /// .open(r"\\.\pipe\MyPipe").await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ /// [`CreateFile2`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfile2
+ /// [Impersonation Levels]:
+ /// https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level
+ fn security_qos_flags(&mut self, flags: u32) -> &mut Self;
+}
+
+impl OpenOptionsExt for OpenOptions {
+ fn access_mode(&mut self, access: u32) -> &mut OpenOptions {
+ self.as_inner_mut().access_mode(access);
+ self
+ }
+
+ fn share_mode(&mut self, share: u32) -> &mut OpenOptions {
+ self.as_inner_mut().share_mode(share);
+ self
+ }
+
+ fn custom_flags(&mut self, flags: u32) -> &mut OpenOptions {
+ self.as_inner_mut().custom_flags(flags);
+ self
+ }
+
+ fn attributes(&mut self, attributes: u32) -> &mut OpenOptions {
+ self.as_inner_mut().attributes(attributes);
+ self
+ }
+
+ fn security_qos_flags(&mut self, flags: u32) -> &mut OpenOptions {
+ self.as_inner_mut().security_qos_flags(flags);
+ self
+ }
+}
+
+impl sealed::Sealed for OpenOptions {}
+
+pub(crate) mod sealed {
+ #[doc(hidden)]
+ pub trait Sealed {}
+}
diff --git a/src/fs/read_dir.rs b/src/fs/read_dir.rs
index f9b16c6..8ca583b 100644
--- a/src/fs/read_dir.rs
+++ b/src/fs/read_dir.rs
@@ -4,8 +4,6 @@ use std::ffi::OsString;
use std::fs::{FileType, Metadata};
use std::future::Future;
use std::io;
-#[cfg(unix)]
-use std::os::unix::fs::DirEntryExt;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
@@ -55,8 +53,7 @@ impl ReadDir {
poll_fn(|cx| self.poll_next_entry(cx)).await
}
- #[doc(hidden)]
- pub fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> {
+ fn poll_next_entry(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<Option<DirEntry>>> {
loop {
match self.0 {
State::Idle(ref mut std) => {
@@ -234,11 +231,10 @@ impl DirEntry {
let std = self.0.clone();
asyncify(move || std.file_type()).await
}
-}
-#[cfg(unix)]
-impl DirEntryExt for DirEntry {
- fn ino(&self) -> u64 {
- self.0.ino()
+ /// Returns a reference to the underlying `std::fs::DirEntry`
+ #[cfg(unix)]
+ pub(super) fn as_inner(&self) -> &std::fs::DirEntry {
+ &self.0
}
}
diff --git a/src/future/block_on.rs b/src/future/block_on.rs
new file mode 100644
index 0000000..91f9cc0
--- /dev/null
+++ b/src/future/block_on.rs
@@ -0,0 +1,15 @@
+use std::future::Future;
+
+cfg_rt! {
+ pub(crate) fn block_on<F: Future>(f: F) -> F::Output {
+ let mut e = crate::runtime::enter::enter(false);
+ e.block_on(f).unwrap()
+ }
+}
+
+cfg_not_rt! {
+ pub(crate) fn block_on<F: Future>(f: F) -> F::Output {
+ let mut park = crate::park::thread::CachedParkThread::new();
+ park.block_on(f).unwrap()
+ }
+}
diff --git a/src/future/mod.rs b/src/future/mod.rs
index 770753f..f7d93c9 100644
--- a/src/future/mod.rs
+++ b/src/future/mod.rs
@@ -1,15 +1,24 @@
-#![allow(unused_imports, dead_code)]
+#![cfg_attr(not(feature = "macros"), allow(unreachable_pub))]
//! Asynchronous values.
-mod maybe_done;
-pub use maybe_done::{maybe_done, MaybeDone};
+#[cfg(any(feature = "macros", feature = "process"))]
+pub(crate) mod maybe_done;
mod poll_fn;
pub use poll_fn::poll_fn;
-mod ready;
-pub(crate) use ready::{ok, Ready};
+cfg_not_loom! {
+ mod ready;
+ pub(crate) use ready::{ok, Ready};
+}
-mod try_join;
-pub(crate) use try_join::try_join3;
+cfg_process! {
+ mod try_join;
+ pub(crate) use try_join::try_join3;
+}
+
+cfg_sync! {
+ mod block_on;
+ pub(crate) use block_on::block_on;
+}
diff --git a/src/future/pending.rs b/src/future/pending.rs
deleted file mode 100644
index 287e836..0000000
--- a/src/future/pending.rs
+++ /dev/null
@@ -1,44 +0,0 @@
-use sdt::pin::Pin;
-use std::future::Future;
-use std::marker;
-use std::task::{Context, Poll};
-
-/// Future for the [`pending()`] function.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-struct Pending<T> {
- _data: marker::PhantomData<T>,
-}
-
-/// Creates a future which never resolves, representing a computation that never
-/// finishes.
-///
-/// The returned future will forever return [`Poll::Pending`].
-///
-/// # Examples
-///
-/// ```no_run
-/// use tokio::future;
-///
-/// #[tokio::main]
-/// async fn main {
-/// future::pending().await;
-/// unreachable!();
-/// }
-/// ```
-pub async fn pending() -> ! {
- Pending {
- _data: marker::PhantomData,
- }
- .await
-}
-
-impl<T> Future for Pending<T> {
- type Output = !;
-
- fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<T> {
- Poll::Pending
- }
-}
-
-impl<T> Unpin for Pending<T> {}
diff --git a/src/future/poll_fn.rs b/src/future/poll_fn.rs
index 9b3d137..0169bd5 100644
--- a/src/future/poll_fn.rs
+++ b/src/future/poll_fn.rs
@@ -1,3 +1,5 @@
+#![allow(dead_code)]
+
//! Definition of the `PollFn` adapter combinator
use std::fmt;
diff --git a/src/future/try_join.rs b/src/future/try_join.rs
index 5bd80dc..8943f61 100644
--- a/src/future/try_join.rs
+++ b/src/future/try_join.rs
@@ -1,4 +1,4 @@
-use crate::future::{maybe_done, MaybeDone};
+use crate::future::maybe_done::{maybe_done, MaybeDone};
use pin_project_lite::pin_project;
use std::future::Future;
diff --git a/src/io/async_read.rs b/src/io/async_read.rs
index 1aef415..d075443 100644
--- a/src/io/async_read.rs
+++ b/src/io/async_read.rs
@@ -1,6 +1,5 @@
-use bytes::BufMut;
+use super::ReadBuf;
use std::io;
-use std::mem::MaybeUninit;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -16,9 +15,10 @@ use std::task::{Context, Poll};
/// Specifically, this means that the `poll_read` function will return one of
/// the following:
///
-/// * `Poll::Ready(Ok(n))` means that `n` bytes of data was immediately read
-/// and placed into the output buffer, where `n` == 0 implies that EOF has
-/// been reached.
+/// * `Poll::Ready(Ok(()))` means that data was immediately read and placed into
+/// the output buffer. The amount of data read can be determined by the
+/// increase in the length of the slice returned by `ReadBuf::filled`. If the
+/// difference is 0, EOF has been reached.
///
/// * `Poll::Pending` means that no data was read into the buffer
/// provided. The I/O object is not currently readable but may become readable
@@ -41,110 +41,29 @@ use std::task::{Context, Poll};
/// [`Read::read`]: std::io::Read::read
/// [`AsyncReadExt`]: crate::io::AsyncReadExt
pub trait AsyncRead {
- /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns
- /// `true` if the supplied buffer was zeroed out.
- ///
- /// While it would be highly unusual, implementations of [`io::Read`] are
- /// able to read data from the buffer passed as an argument. Because of
- /// this, the buffer passed to [`io::Read`] must be initialized memory. In
- /// situations where large numbers of buffers are used, constantly having to
- /// zero out buffers can be expensive.
- ///
- /// This function does any necessary work to prepare an uninitialized buffer
- /// to be safe to pass to `read`. If `read` guarantees to never attempt to
- /// read data out of the supplied buffer, then `prepare_uninitialized_buffer`
- /// doesn't need to do any work.
- ///
- /// If this function returns `true`, then the memory has been zeroed out.
- /// This allows implementations of `AsyncRead` which are composed of
- /// multiple subimplementations to efficiently implement
- /// `prepare_uninitialized_buffer`.
- ///
- /// This function isn't actually `unsafe` to call but `unsafe` to implement.
- /// The implementer must ensure that either the whole `buf` has been zeroed
- /// or `poll_read_buf()` overwrites the buffer without reading it and returns
- /// correct value.
- ///
- /// This function is called from [`poll_read_buf`].
- ///
- /// # Safety
- ///
- /// Implementations that return `false` must never read from data slices
- /// that they did not write to.
- ///
- /// [`io::Read`]: std::io::Read
- /// [`poll_read_buf`]: method@Self::poll_read_buf
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- for x in buf {
- *x = MaybeUninit::new(0);
- }
-
- true
- }
-
/// Attempts to read from the `AsyncRead` into `buf`.
///
- /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
+ /// On success, returns `Poll::Ready(Ok(()))` and fills `buf` with data
+ /// read. If no data was read (`buf.filled().is_empty()`) it implies that
+ /// EOF has been reached.
///
- /// If no data is available for reading, the method returns
- /// `Poll::Pending` and arranges for the current task (via
- /// `cx.waker()`) to receive a notification when the object becomes
- /// readable or is closed.
+ /// If no data is available for reading, the method returns `Poll::Pending`
+ /// and arranges for the current task (via `cx.waker()`) to receive a
+ /// notification when the object becomes readable or is closed.
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>>;
-
- /// Pulls some bytes from this source into the specified `BufMut`, returning
- /// how many bytes were read.
- ///
- /// The `buf` provided will have bytes read into it and the internal cursor
- /// will be advanced if any bytes were read. Note that this method typically
- /// will not reallocate the buffer provided.
- fn poll_read_buf<B: BufMut>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>>
- where
- Self: Sized,
- {
- if !buf.has_remaining_mut() {
- return Poll::Ready(Ok(0));
- }
-
- unsafe {
- let n = {
- let b = buf.bytes_mut();
-
- self.prepare_uninitialized_buffer(b);
-
- // Convert to `&mut [u8]`
- let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]);
-
- let n = ready!(self.poll_read(cx, b))?;
- assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold");
- n
- };
-
- buf.advance_mut(n);
- Poll::Ready(Ok(n))
- }
- }
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>>;
}
macro_rules! deref_async_read {
() => {
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
@@ -163,43 +82,50 @@ where
P: DerefMut + Unpin,
P::Target: AsyncRead,
{
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}
impl AsyncRead for &[u8] {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let amt = std::cmp::min(self.len(), buf.remaining());
+ let (a, b) = self.split_at(amt);
+ buf.put_slice(a);
+ *self = b;
+ Poll::Ready(Ok(()))
}
}
impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let pos = self.position();
+ let slice: &[u8] = (*self).get_ref().as_ref();
+
+ // The position could technically be out of bounds, so don't panic...
+ if pos > slice.len() as u64 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let start = pos as usize;
+ let amt = std::cmp::min(slice.len() - start, buf.remaining());
+ // Add won't overflow because of pos check above.
+ let end = start + amt;
+ buf.put_slice(&slice[start..end]);
+ self.set_position(end as u64);
+
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/async_seek.rs b/src/io/async_seek.rs
index 32ed0a2..bd7a992 100644
--- a/src/io/async_seek.rs
+++ b/src/io/async_seek.rs
@@ -23,36 +23,33 @@ pub trait AsyncSeek {
///
/// If this function returns successfully, then the job has been submitted.
/// To find out when it completes, call `poll_complete`.
- fn start_seek(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- position: SeekFrom,
- ) -> Poll<io::Result<()>>;
+ ///
+ /// # Errors
+ ///
+ /// This function can return [`io::ErrorKind::Other`] in case there is
+ /// another seek in progress. To avoid this, it is advisable that any call
+ /// to `start_seek` is preceded by a call to `poll_complete` to ensure all
+ /// pending seeks have completed.
+ fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()>;
/// Waits for a seek operation to complete.
///
/// If the seek operation completed successfully,
/// this method returns the new position from the start of the stream.
- /// That position can be used later with [`SeekFrom::Start`].
+ /// That position can be used later with [`SeekFrom::Start`]. Repeatedly
+ /// calling this function without calling `start_seek` might return the
+ /// same result.
///
/// # Errors
///
/// Seeking to a negative offset is considered an error.
- ///
- /// # Panics
- ///
- /// Calling this method without calling `start_seek` first is an error.
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>;
}
macro_rules! deref_async_seek {
() => {
- fn start_seek(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
- Pin::new(&mut **self).start_seek(cx, pos)
+ fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
+ Pin::new(&mut **self).start_seek(pos)
}
fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
@@ -74,12 +71,8 @@ where
P: DerefMut + Unpin,
P::Target: AsyncSeek,
{
- fn start_seek(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
- self.get_mut().as_mut().start_seek(cx, pos)
+ fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
+ self.get_mut().as_mut().start_seek(pos)
}
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
@@ -88,12 +81,8 @@ where
}
impl<T: AsRef<[u8]> + Unpin> AsyncSeek for io::Cursor<T> {
- fn start_seek(
- mut self: Pin<&mut Self>,
- _: &mut Context<'_>,
- pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
- Poll::Ready(io::Seek::seek(&mut *self, pos).map(drop))
+ fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
+ io::Seek::seek(&mut *self, pos).map(drop)
}
fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> {
Poll::Ready(Ok(self.get_mut().position()))
diff --git a/src/io/async_write.rs b/src/io/async_write.rs
index ecf7575..66ba4bf 100644
--- a/src/io/async_write.rs
+++ b/src/io/async_write.rs
@@ -1,4 +1,3 @@
-use bytes::Buf;
use std::io;
use std::ops::DerefMut;
use std::pin::Pin;
@@ -128,27 +127,6 @@ pub trait AsyncWrite {
/// This function will panic if not called within the context of a future's
/// task.
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>;
-
- /// Writes a `Buf` into this value, returning how many bytes were written.
- ///
- /// Note that this method will advance the `buf` provided automatically by
- /// the number of bytes written.
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<Result<usize, io::Error>>
- where
- Self: Sized,
- {
- if !buf.has_remaining() {
- return Poll::Ready(Ok(0));
- }
-
- let n = ready!(self.poll_write(cx, buf.bytes()))?;
- buf.advance(n);
- Poll::Ready(Ok(n))
- }
}
macro_rules! deref_async_write {
diff --git a/src/io/blocking.rs b/src/io/blocking.rs
index 2491039..430801e 100644
--- a/src/io/blocking.rs
+++ b/src/io/blocking.rs
@@ -1,5 +1,5 @@
use crate::io::sys;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::cmp;
use std::future::Future;
@@ -53,17 +53,17 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- dst: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
if !buf.is_empty() {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
*buf_cell = Some(buf);
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
buf.ensure_capacity_for(dst);
@@ -80,9 +80,9 @@ where
match res {
Ok(_) => {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
self.state = Idle(Some(buf));
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
Err(e) => {
assert!(buf.is_empty());
@@ -203,9 +203,9 @@ impl Buf {
self.buf.len() - self.pos
}
- pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize {
- let n = cmp::min(self.len(), dst.len());
- dst[..n].copy_from_slice(&self.bytes()[..n]);
+ pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
+ let n = cmp::min(self.len(), dst.remaining());
+ dst.put_slice(&self.bytes()[..n]);
self.pos += n;
if self.pos == self.buf.len() {
@@ -229,10 +229,10 @@ impl Buf {
&self.buf[self.pos..]
}
- pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) {
+ pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
assert!(self.is_empty());
- let len = cmp::min(bytes.len(), MAX_BUF);
+ let len = cmp::min(bytes.remaining(), MAX_BUF);
if self.buf.len() < len {
self.buf.reserve(len - self.buf.len());
diff --git a/src/io/driver/mod.rs b/src/io/driver/mod.rs
index dbfb6e1..cd82b26 100644
--- a/src/io/driver/mod.rs
+++ b/src/io/driver/mod.rs
@@ -1,30 +1,38 @@
-pub(crate) mod platform;
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
+mod ready;
+use ready::Ready;
mod scheduled_io;
pub(crate) use scheduled_io::ScheduledIo; // pub(crate) for tests
-use crate::loom::sync::atomic::AtomicUsize;
use crate::park::{Park, Unpark};
-use crate::runtime::context;
-use crate::util::slab::{Address, Slab};
+use crate::util::bit;
+use crate::util::slab::{self, Slab};
-use mio::event::Evented;
use std::fmt;
use std::io;
-use std::sync::atomic::Ordering::SeqCst;
use std::sync::{Arc, Weak};
-use std::task::Waker;
use std::time::Duration;
/// I/O driver, backed by Mio
pub(crate) struct Driver {
+ /// Tracks the number of times `turn` is called. It is safe for this to wrap
+ /// as it is mostly used to determine when to call `compact()`
+ tick: u8,
+
/// Reuse the `mio::Events` value across calls to poll.
- events: mio::Events,
+ events: Option<mio::Events>,
+
+ /// Primary slab handle containing the state for each resource registered
+ /// with this driver.
+ resources: Slab<ScheduledIo>,
+
+ /// The system event queue
+ poll: mio::Poll,
/// State shared between the reactor and the handles.
inner: Arc<Inner>,
-
- _wakeup_registration: mio::Registration,
}
/// A reference to an I/O driver
@@ -33,18 +41,20 @@ pub(crate) struct Handle {
inner: Weak<Inner>,
}
-pub(super) struct Inner {
- /// The underlying system event queue.
- io: mio::Poll,
+pub(crate) struct ReadyEvent {
+ tick: u8,
+ ready: Ready,
+}
- /// Dispatch slabs for I/O and futures events
- pub(super) io_dispatch: Slab<ScheduledIo>,
+pub(super) struct Inner {
+ /// Registers I/O resources
+ registry: mio::Registry,
- /// The number of sources in `io_dispatch`.
- n_sources: AtomicUsize,
+ /// Allocates `ScheduledIo` handles when creating new resources.
+ pub(super) io_dispatch: slab::Allocator<ScheduledIo>,
/// Used to wake up the reactor from a call to `turn`
- wakeup: mio::SetReadiness,
+ waker: mio::Waker,
}
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
@@ -53,7 +63,24 @@ pub(super) enum Direction {
Write,
}
-const TOKEN_WAKEUP: mio::Token = mio::Token(Address::NULL);
+enum Tick {
+ Set(u8),
+ Clear(u8),
+}
+
+// TODO: Don't use a fake token. Instead, reserve a slot entry for the wakeup
+// token.
+const TOKEN_WAKEUP: mio::Token = mio::Token(1 << 31);
+
+const ADDRESS: bit::Pack = bit::Pack::least_significant(24);
+
+// Packs the generation value in the `readiness` field.
+//
+// The generation prevents a race condition where a slab slot is reused for a
+// new socket while the I/O driver is about to apply a readiness event. The
+// generaton value is checked when setting new readiness. If the generation do
+// not match, then the readiness event is discarded.
+const GENERATION: bit::Pack = ADDRESS.then(7);
fn _assert_kinds() {
fn _assert<T: Send + Sync>() {}
@@ -67,24 +94,22 @@ impl Driver {
/// Creates a new event loop, returning any error that happened during the
/// creation.
pub(crate) fn new() -> io::Result<Driver> {
- let io = mio::Poll::new()?;
- let wakeup_pair = mio::Registration::new2();
+ let poll = mio::Poll::new()?;
+ let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?;
+ let registry = poll.registry().try_clone()?;
- io.register(
- &wakeup_pair.0,
- TOKEN_WAKEUP,
- mio::Ready::readable(),
- mio::PollOpt::level(),
- )?;
+ let slab = Slab::new();
+ let allocator = slab.allocator();
Ok(Driver {
- events: mio::Events::with_capacity(1024),
- _wakeup_registration: wakeup_pair.0,
+ tick: 0,
+ events: Some(mio::Events::with_capacity(1024)),
+ resources: slab,
+ poll,
inner: Arc::new(Inner {
- io,
- io_dispatch: Slab::new(),
- n_sources: AtomicUsize::new(0),
- wakeup: wakeup_pair.1,
+ registry,
+ io_dispatch: allocator,
+ waker,
}),
})
}
@@ -102,65 +127,66 @@ impl Driver {
}
fn turn(&mut self, max_wait: Option<Duration>) -> io::Result<()> {
+ // How often to call `compact()` on the resource slab
+ const COMPACT_INTERVAL: u8 = 255;
+
+ self.tick = self.tick.wrapping_add(1);
+
+ if self.tick == COMPACT_INTERVAL {
+ self.resources.compact();
+ }
+
+ let mut events = self.events.take().expect("i/o driver event store missing");
+
// Block waiting for an event to happen, peeling out how many events
// happened.
- match self.inner.io.poll(&mut self.events, max_wait) {
+ match self.poll.poll(&mut events, max_wait) {
Ok(_) => {}
+ Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
// Process all the events that came in, dispatching appropriately
-
- for event in self.events.iter() {
+ for event in events.iter() {
let token = event.token();
- if token == TOKEN_WAKEUP {
- self.inner
- .wakeup
- .set_readiness(mio::Ready::empty())
- .unwrap();
- } else {
- self.dispatch(token, event.readiness());
+ if token != TOKEN_WAKEUP {
+ self.dispatch(token, Ready::from_mio(event));
}
}
+ self.events = Some(events);
+
Ok(())
}
- fn dispatch(&self, token: mio::Token, ready: mio::Ready) {
- let mut rd = None;
- let mut wr = None;
+ fn dispatch(&mut self, token: mio::Token, ready: Ready) {
+ let addr = slab::Address::from_usize(ADDRESS.unpack(token.0));
- let address = Address::from_usize(token.0);
-
- let io = match self.inner.io_dispatch.get(address) {
+ let io = match self.resources.get(addr) {
Some(io) => io,
None => return,
};
- if io
- .set_readiness(address, |curr| curr | ready.as_usize())
- .is_err()
- {
+ let res = io.set_readiness(Some(token.0), Tick::Set(self.tick), |curr| curr | ready);
+
+ if res.is_err() {
// token no longer valid!
return;
}
- if ready.is_writable() || platform::is_hup(ready) || platform::is_error(ready) {
- wr = io.writer.take_waker();
- }
-
- if !(ready & (!mio::Ready::writable())).is_empty() {
- rd = io.reader.take_waker();
- }
-
- if let Some(w) = rd {
- w.wake();
- }
+ io.wake(ready);
+ }
+}
- if let Some(w) = wr {
- w.wake();
- }
+impl Drop for Driver {
+ fn drop(&mut self) {
+ self.resources.for_each(|io| {
+ // If a task is waiting on the I/O resource, notify it. The task
+ // will then attempt to use the I/O resource and fail due to the
+ // driver being shutdown.
+ io.wake(Ready::ALL);
+ })
}
}
@@ -181,6 +207,8 @@ impl Park for Driver {
self.turn(Some(duration))?;
Ok(())
}
+
+ fn shutdown(&mut self) {}
}
impl fmt::Debug for Driver {
@@ -191,17 +219,36 @@ impl fmt::Debug for Driver {
// ===== impl Handle =====
-impl Handle {
- /// Returns a handle to the current reactor
- ///
- /// # Panics
- ///
- /// This function panics if there is no current reactor set.
- pub(super) fn current() -> Self {
- context::io_handle()
- .expect("there is no reactor running, must be called from the context of Tokio runtime")
+cfg_rt! {
+ impl Handle {
+ /// Returns a handle to the current reactor
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current reactor set and `rt` feature
+ /// flag is not enabled.
+ pub(super) fn current() -> Self {
+ crate::runtime::context::io_handle()
+ .expect("there is no reactor running, must be called from the context of Tokio runtime")
+ }
+ }
+}
+
+cfg_not_rt! {
+ impl Handle {
+ /// Returns a handle to the current reactor
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current reactor set, or if the `rt`
+ /// feature flag is not enabled.
+ pub(super) fn current() -> Self {
+ panic!("there is no reactor running, must be called from the context of Tokio runtime with `rt` enabled.")
+ }
}
+}
+impl Handle {
/// Forces a reactor blocked in a call to `turn` to wakeup, or otherwise
/// makes the next call to `turn` return immediately.
///
@@ -213,7 +260,7 @@ impl Handle {
/// return immediately.
fn wakeup(&self) {
if let Some(inner) = self.inner() {
- inner.wakeup.set_readiness(mio::Ready::readable()).unwrap();
+ inner.waker.wake().expect("failed to wake I/O driver");
}
}
@@ -242,159 +289,35 @@ impl Inner {
/// The registration token is returned.
pub(super) fn add_source(
&self,
- source: &dyn Evented,
- ready: mio::Ready,
- ) -> io::Result<Address> {
- let address = self.io_dispatch.alloc().ok_or_else(|| {
+ source: &mut impl mio::event::Source,
+ interest: mio::Interest,
+ ) -> io::Result<slab::Ref<ScheduledIo>> {
+ let (address, shared) = self.io_dispatch.allocate().ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"reactor at max registered I/O resources",
)
})?;
- self.n_sources.fetch_add(1, SeqCst);
+ let token = GENERATION.pack(shared.generation(), ADDRESS.pack(address.as_usize(), 0));
- self.io.register(
- source,
- mio::Token(address.to_usize()),
- ready,
- mio::PollOpt::edge(),
- )?;
+ self.registry
+ .register(source, mio::Token(token), interest)?;
- Ok(address)
+ Ok(shared)
}
/// Deregisters an I/O resource from the reactor.
- pub(super) fn deregister_source(&self, source: &dyn Evented) -> io::Result<()> {
- self.io.deregister(source)
- }
-
- pub(super) fn drop_source(&self, address: Address) {
- self.io_dispatch.remove(address);
- self.n_sources.fetch_sub(1, SeqCst);
- }
-
- /// Registers interest in the I/O resource associated with `token`.
- pub(super) fn register(&self, token: Address, dir: Direction, w: Waker) {
- let sched = self
- .io_dispatch
- .get(token)
- .unwrap_or_else(|| panic!("IO resource for token {:?} does not exist!", token));
-
- let waker = match dir {
- Direction::Read => &sched.reader,
- Direction::Write => &sched.writer,
- };
-
- waker.register(w);
+ pub(super) fn deregister_source(&self, source: &mut impl mio::event::Source) -> io::Result<()> {
+ self.registry.deregister(source)
}
}
impl Direction {
- pub(super) fn mask(self) -> mio::Ready {
+ pub(super) fn mask(self) -> Ready {
match self {
- Direction::Read => {
- // Everything except writable is signaled through read.
- mio::Ready::all() - mio::Ready::writable()
- }
- Direction::Write => mio::Ready::writable() | platform::hup() | platform::error(),
- }
- }
-}
-
-#[cfg(all(test, loom))]
-mod tests {
- use super::*;
- use loom::thread;
-
- // No-op `Evented` impl just so we can have something to pass to `add_source`.
- struct NotEvented;
-
- impl Evented for NotEvented {
- fn register(
- &self,
- _: &mio::Poll,
- _: mio::Token,
- _: mio::Ready,
- _: mio::PollOpt,
- ) -> io::Result<()> {
- Ok(())
- }
-
- fn reregister(
- &self,
- _: &mio::Poll,
- _: mio::Token,
- _: mio::Ready,
- _: mio::PollOpt,
- ) -> io::Result<()> {
- Ok(())
- }
-
- fn deregister(&self, _: &mio::Poll) -> io::Result<()> {
- Ok(())
+ Direction::Read => Ready::READABLE | Ready::READ_CLOSED,
+ Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED,
}
}
-
- #[test]
- fn tokens_unique_when_dropped() {
- loom::model(|| {
- let reactor = Driver::new().unwrap();
- let inner = reactor.inner;
- let inner2 = inner.clone();
-
- let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- let thread = thread::spawn(move || {
- inner2.drop_source(token_1);
- });
-
- let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- thread.join().unwrap();
-
- assert!(token_1 != token_2);
- })
- }
-
- #[test]
- fn tokens_unique_when_dropped_on_full_page() {
- loom::model(|| {
- let reactor = Driver::new().unwrap();
- let inner = reactor.inner;
- let inner2 = inner.clone();
- // add sources to fill up the first page so that the dropped index
- // may be reused.
- for _ in 0..31 {
- inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- }
-
- let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- let thread = thread::spawn(move || {
- inner2.drop_source(token_1);
- });
-
- let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- thread.join().unwrap();
-
- assert!(token_1 != token_2);
- })
- }
-
- #[test]
- fn tokens_unique_concurrent_add() {
- loom::model(|| {
- let reactor = Driver::new().unwrap();
- let inner = reactor.inner;
- let inner2 = inner.clone();
-
- let thread = thread::spawn(move || {
- let token_2 = inner2.add_source(&NotEvented, mio::Ready::all()).unwrap();
- token_2
- });
-
- let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- let token_2 = thread.join().unwrap();
-
- assert!(token_1 != token_2);
- })
- }
}
diff --git a/src/io/driver/ready.rs b/src/io/driver/ready.rs
new file mode 100644
index 0000000..8b556e9
--- /dev/null
+++ b/src/io/driver/ready.rs
@@ -0,0 +1,187 @@
+use std::fmt;
+use std::ops;
+
+const READABLE: usize = 0b0_01;
+const WRITABLE: usize = 0b0_10;
+const READ_CLOSED: usize = 0b0_0100;
+const WRITE_CLOSED: usize = 0b0_1000;
+
+/// A set of readiness event kinds.
+///
+/// `Ready` is set of operation descriptors indicating which kind of an
+/// operation is ready to be performed.
+///
+/// This struct only represents portable event kinds. Portable events are
+/// events that can be raised on any platform while guaranteeing no false
+/// positives.
+#[derive(Clone, Copy, PartialEq, PartialOrd)]
+pub(crate) struct Ready(usize);
+
+impl Ready {
+ /// Returns the empty `Ready` set.
+ pub(crate) const EMPTY: Ready = Ready(0);
+
+ /// Returns a `Ready` representing readable readiness.
+ pub(crate) const READABLE: Ready = Ready(READABLE);
+
+ /// Returns a `Ready` representing writable readiness.
+ pub(crate) const WRITABLE: Ready = Ready(WRITABLE);
+
+ /// Returns a `Ready` representing read closed readiness.
+ pub(crate) const READ_CLOSED: Ready = Ready(READ_CLOSED);
+
+ /// Returns a `Ready` representing write closed readiness.
+ pub(crate) const WRITE_CLOSED: Ready = Ready(WRITE_CLOSED);
+
+ /// Returns a `Ready` representing readiness for all operations.
+ pub(crate) const ALL: Ready = Ready(READABLE | WRITABLE | READ_CLOSED | WRITE_CLOSED);
+
+ pub(crate) fn from_mio(event: &mio::event::Event) -> Ready {
+ let mut ready = Ready::EMPTY;
+
+ if event.is_readable() {
+ ready |= Ready::READABLE;
+ }
+
+ if event.is_writable() {
+ ready |= Ready::WRITABLE;
+ }
+
+ if event.is_read_closed() {
+ ready |= Ready::READ_CLOSED;
+ }
+
+ if event.is_write_closed() {
+ ready |= Ready::WRITE_CLOSED;
+ }
+
+ ready
+ }
+
+ /// Returns true if `Ready` is the empty set
+ pub(crate) fn is_empty(self) -> bool {
+ self == Ready::EMPTY
+ }
+
+ /// Returns true if the value includes readable readiness
+ pub(crate) fn is_readable(self) -> bool {
+ self.contains(Ready::READABLE) || self.is_read_closed()
+ }
+
+ /// Returns true if the value includes writable readiness
+ pub(crate) fn is_writable(self) -> bool {
+ self.contains(Ready::WRITABLE) || self.is_write_closed()
+ }
+
+ /// Returns true if the value includes read closed readiness
+ pub(crate) fn is_read_closed(self) -> bool {
+ self.contains(Ready::READ_CLOSED)
+ }
+
+ /// Returns true if the value includes write closed readiness
+ pub(crate) fn is_write_closed(self) -> bool {
+ self.contains(Ready::WRITE_CLOSED)
+ }
+
+ /// Returns true if `self` is a superset of `other`.
+ ///
+ /// `other` may represent more than one readiness operations, in which case
+ /// the function only returns true if `self` contains all readiness
+ /// specified in `other`.
+ pub(crate) fn contains<T: Into<Self>>(self, other: T) -> bool {
+ let other = other.into();
+ (self & other) == other
+ }
+
+ /// Create a `Ready` instance using the given `usize` representation.
+ ///
+ /// The `usize` representation must have been obtained from a call to
+ /// `Readiness::as_usize`.
+ ///
+ /// This function is mainly provided to allow the caller to get a
+ /// readiness value from an `AtomicUsize`.
+ pub(crate) fn from_usize(val: usize) -> Ready {
+ Ready(val & Ready::ALL.as_usize())
+ }
+
+ /// Returns a `usize` representation of the `Ready` value.
+ ///
+ /// This function is mainly provided to allow the caller to store a
+ /// readiness value in an `AtomicUsize`.
+ pub(crate) fn as_usize(self) -> usize {
+ self.0
+ }
+}
+
+cfg_io_readiness! {
+ impl Ready {
+ pub(crate) fn from_interest(interest: mio::Interest) -> Ready {
+ let mut ready = Ready::EMPTY;
+
+ if interest.is_readable() {
+ ready |= Ready::READABLE;
+ ready |= Ready::READ_CLOSED;
+ }
+
+ if interest.is_writable() {
+ ready |= Ready::WRITABLE;
+ ready |= Ready::WRITE_CLOSED;
+ }
+
+ ready
+ }
+
+ pub(crate) fn intersection(self, interest: mio::Interest) -> Ready {
+ Ready(self.0 & Ready::from_interest(interest).0)
+ }
+
+ pub(crate) fn satisfies(self, interest: mio::Interest) -> bool {
+ self.0 & Ready::from_interest(interest).0 != 0
+ }
+ }
+}
+
+impl<T: Into<Ready>> ops::BitOr<T> for Ready {
+ type Output = Ready;
+
+ #[inline]
+ fn bitor(self, other: T) -> Ready {
+ Ready(self.0 | other.into().0)
+ }
+}
+
+impl<T: Into<Ready>> ops::BitOrAssign<T> for Ready {
+ #[inline]
+ fn bitor_assign(&mut self, other: T) {
+ self.0 |= other.into().0;
+ }
+}
+
+impl<T: Into<Ready>> ops::BitAnd<T> for Ready {
+ type Output = Ready;
+
+ #[inline]
+ fn bitand(self, other: T) -> Ready {
+ Ready(self.0 & other.into().0)
+ }
+}
+
+impl<T: Into<Ready>> ops::Sub<T> for Ready {
+ type Output = Ready;
+
+ #[inline]
+ fn sub(self, other: T) -> Ready {
+ Ready(self.0 & !other.into().0)
+ }
+}
+
+impl fmt::Debug for Ready {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Ready")
+ .field("is_readable", &self.is_readable())
+ .field("is_writable", &self.is_writable())
+ .field("is_read_closed", &self.is_read_closed())
+ .field("is_write_closed", &self.is_write_closed())
+ .finish()
+ }
+}
diff --git a/src/io/driver/scheduled_io.rs b/src/io/driver/scheduled_io.rs
index 7f6446e..b1354a0 100644
--- a/src/io/driver/scheduled_io.rs
+++ b/src/io/driver/scheduled_io.rs
@@ -1,47 +1,109 @@
-use crate::loom::future::AtomicWaker;
+use super::{Direction, Ready, ReadyEvent, Tick};
use crate::loom::sync::atomic::AtomicUsize;
+use crate::loom::sync::Mutex;
use crate::util::bit;
-use crate::util::slab::{Address, Entry, Generation};
+use crate::util::slab::Entry;
-use std::sync::atomic::Ordering::{AcqRel, Acquire, SeqCst};
+use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
+use std::task::{Context, Poll, Waker};
+cfg_io_readiness! {
+ use crate::util::linked_list::{self, LinkedList};
+
+ use std::cell::UnsafeCell;
+ use std::future::Future;
+ use std::marker::PhantomPinned;
+ use std::pin::Pin;
+ use std::ptr::NonNull;
+}
+
+/// Stored in the I/O driver resource slab.
#[derive(Debug)]
pub(crate) struct ScheduledIo {
+ /// Packs the resource's readiness with the resource's generation.
readiness: AtomicUsize,
- pub(crate) reader: AtomicWaker,
- pub(crate) writer: AtomicWaker,
+
+ waiters: Mutex<Waiters>,
}
-const PACK: bit::Pack = bit::Pack::most_significant(Generation::WIDTH);
+cfg_io_readiness! {
+ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
+}
-impl Entry for ScheduledIo {
- fn generation(&self) -> Generation {
- unpack_generation(self.readiness.load(SeqCst))
+#[derive(Debug, Default)]
+struct Waiters {
+ #[cfg(feature = "net")]
+ /// List of all current waiters
+ list: WaitList,
+
+ /// Waker used for AsyncRead
+ reader: Option<Waker>,
+
+ /// Waker used for AsyncWrite
+ writer: Option<Waker>,
+}
+
+cfg_io_readiness! {
+ #[derive(Debug)]
+ struct Waiter {
+ pointers: linked_list::Pointers<Waiter>,
+
+ /// The waker for this task
+ waker: Option<Waker>,
+
+ /// The interest this waiter is waiting on
+ interest: mio::Interest,
+
+ is_ready: bool,
+
+ /// Should never be `!Unpin`
+ _p: PhantomPinned,
}
- fn reset(&self, generation: Generation) -> bool {
- let mut current = self.readiness.load(Acquire);
+ /// Future returned by `readiness()`
+ struct Readiness<'a> {
+ scheduled_io: &'a ScheduledIo,
- loop {
- if unpack_generation(current) != generation {
- return false;
- }
+ state: State,
- let next = PACK.pack(generation.next().to_usize(), 0);
+ /// Entry in the waiter `LinkedList`.
+ waiter: UnsafeCell<Waiter>,
+ }
- match self
- .readiness
- .compare_exchange(current, next, AcqRel, Acquire)
- {
- Ok(_) => break,
- Err(actual) => current = actual,
- }
- }
+ enum State {
+ Init,
+ Waiting,
+ Done,
+ }
+}
+
+// The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness.
+//
+// | reserved | generation | driver tick | readinesss |
+// |----------+------------+--------------+------------|
+// | 1 bit | 7 bits + 8 bits + 16 bits |
+
+const READINESS: bit::Pack = bit::Pack::least_significant(16);
+
+const TICK: bit::Pack = READINESS.then(8);
- drop(self.reader.take_waker());
- drop(self.writer.take_waker());
+const GENERATION: bit::Pack = TICK.then(7);
- true
+#[test]
+fn test_generations_assert_same() {
+ assert_eq!(super::GENERATION, GENERATION);
+}
+
+// ===== impl ScheduledIo =====
+
+impl Entry for ScheduledIo {
+ fn reset(&self) {
+ let state = self.readiness.load(Acquire);
+
+ let generation = GENERATION.unpack(state);
+ let next = GENERATION.pack_lossy(generation + 1, 0);
+
+ self.readiness.store(next, Release);
}
}
@@ -49,31 +111,14 @@ impl Default for ScheduledIo {
fn default() -> ScheduledIo {
ScheduledIo {
readiness: AtomicUsize::new(0),
- reader: AtomicWaker::new(),
- writer: AtomicWaker::new(),
+ waiters: Mutex::new(Default::default()),
}
}
}
impl ScheduledIo {
- #[cfg(all(test, loom))]
- /// Returns the current readiness value of this `ScheduledIo`, if the
- /// provided `token` is still a valid access.
- ///
- /// # Returns
- ///
- /// If the given token's generation no longer matches the `ScheduledIo`'s
- /// generation, then the corresponding IO resource has been removed and
- /// replaced with a new resource. In that case, this method returns `None`.
- /// Otherwise, this returns the current readiness.
- pub(crate) fn get_readiness(&self, address: Address) -> Option<usize> {
- let ready = self.readiness.load(Acquire);
-
- if unpack_generation(ready) != address.generation() {
- return None;
- }
-
- Some(ready & !PACK.mask())
+ pub(crate) fn generation(&self) -> usize {
+ GENERATION.unpack(self.readiness.load(Acquire))
}
/// Sets the readiness on this `ScheduledIo` by invoking the given closure on
@@ -81,6 +126,8 @@ impl ScheduledIo {
///
/// # Arguments
/// - `token`: the token for this `ScheduledIo`.
+ /// - `tick`: whether setting the tick or trying to clear readiness for a
+ /// specific tick.
/// - `f`: a closure returning a new readiness value given the previous
/// readiness.
///
@@ -90,52 +137,354 @@ impl ScheduledIo {
/// generation, then the corresponding IO resource has been removed and
/// replaced with a new resource. In that case, this method returns `Err`.
/// Otherwise, this returns the previous readiness.
- pub(crate) fn set_readiness(
+ pub(super) fn set_readiness(
&self,
- address: Address,
- f: impl Fn(usize) -> usize,
- ) -> Result<usize, ()> {
- let generation = address.generation();
-
+ token: Option<usize>,
+ tick: Tick,
+ f: impl Fn(Ready) -> Ready,
+ ) -> Result<(), ()> {
let mut current = self.readiness.load(Acquire);
loop {
- // Check that the generation for this access is still the current
- // one.
- if unpack_generation(current) != generation {
- return Err(());
+ let current_generation = GENERATION.unpack(current);
+
+ if let Some(token) = token {
+ // Check that the generation for this access is still the
+ // current one.
+ if GENERATION.unpack(token) != current_generation {
+ return Err(());
+ }
}
- // Mask out the generation bits so that the modifying function
- // doesn't see them.
- let current_readiness = current & mio::Ready::all().as_usize();
+
+ // Mask out the tick/generation bits so that the modifying
+ // function doesn't see them.
+ let current_readiness = Ready::from_usize(current);
let new = f(current_readiness);
- debug_assert!(
- new <= !PACK.max_value(),
- "new readiness value would overwrite generation bits!"
- );
-
- match self.readiness.compare_exchange(
- current,
- PACK.pack(generation.to_usize(), new),
- AcqRel,
- Acquire,
- ) {
- Ok(_) => return Ok(current),
+ let packed = match tick {
+ Tick::Set(t) => TICK.pack(t as usize, new.as_usize()),
+ Tick::Clear(t) => {
+ if TICK.unpack(current) as u8 != t {
+ // Trying to clear readiness with an old event!
+ return Err(());
+ }
+
+ TICK.pack(t as usize, new.as_usize())
+ }
+ };
+
+ let next = GENERATION.pack(current_generation, packed);
+
+ match self
+ .readiness
+ .compare_exchange(current, next, AcqRel, Acquire)
+ {
+ Ok(_) => return Ok(()),
// we lost the race, retry!
Err(actual) => current = actual,
}
}
}
+
+ /// Notifies all pending waiters that have registered interest in `ready`.
+ ///
+ /// There may be many waiters to notify. Waking the pending task **must** be
+ /// done from outside of the lock otherwise there is a potential for a
+ /// deadlock.
+ ///
+ /// A stack array of wakers is created and filled with wakers to notify, the
+ /// lock is released, and the wakers are notified. Because there may be more
+ /// than 32 wakers to notify, if the stack array fills up, the lock is
+ /// released, the array is cleared, and the iteration continues.
+ pub(super) fn wake(&self, ready: Ready) {
+ const NUM_WAKERS: usize = 32;
+
+ let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default();
+ let mut curr = 0;
+
+ let mut waiters = self.waiters.lock();
+
+ // check for AsyncRead slot
+ if ready.is_readable() {
+ if let Some(waker) = waiters.reader.take() {
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+
+ // check for AsyncWrite slot
+ if ready.is_writable() {
+ if let Some(waker) = waiters.writer.take() {
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+
+ #[cfg(feature = "net")]
+ 'outer: loop {
+ let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest));
+
+ while curr < NUM_WAKERS {
+ match iter.next() {
+ Some(waiter) => {
+ let waiter = unsafe { &mut *waiter.as_ptr() };
+
+ if let Some(waker) = waiter.waker.take() {
+ waiter.is_ready = true;
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+ None => {
+ break 'outer;
+ }
+ }
+ }
+
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr) {
+ waker.take().unwrap().wake();
+ }
+
+ curr = 0;
+
+ // Acquire the lock again.
+ waiters = self.waiters.lock();
+ }
+
+ // Release the lock before notifying
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr) {
+ waker.take().unwrap().wake();
+ }
+ }
+
+ /// Poll version of checking readiness for a certain direction.
+ ///
+ /// These are to support `AsyncRead` and `AsyncWrite` polling methods,
+ /// which cannot use the `async fn` version. This uses reserved reader
+ /// and writer slots.
+ pub(in crate::io) fn poll_readiness(
+ &self,
+ cx: &mut Context<'_>,
+ direction: Direction,
+ ) -> Poll<ReadyEvent> {
+ let curr = self.readiness.load(Acquire);
+
+ let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr));
+
+ if ready.is_empty() {
+ // Update the task info
+ let mut waiters = self.waiters.lock();
+ let slot = match direction {
+ Direction::Read => &mut waiters.reader,
+ Direction::Write => &mut waiters.writer,
+ };
+ *slot = Some(cx.waker().clone());
+
+ // Try again, in case the readiness was changed while we were
+ // taking the waiters lock
+ let curr = self.readiness.load(Acquire);
+ let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr));
+ if ready.is_empty() {
+ Poll::Pending
+ } else {
+ Poll::Ready(ReadyEvent {
+ tick: TICK.unpack(curr) as u8,
+ ready,
+ })
+ }
+ } else {
+ Poll::Ready(ReadyEvent {
+ tick: TICK.unpack(curr) as u8,
+ ready,
+ })
+ }
+ }
+
+ pub(crate) fn clear_readiness(&self, event: ReadyEvent) {
+ // This consumes the current readiness state **except** for closed
+ // states. Closed states are excluded because they are final states.
+ let mask_no_closed = event.ready - Ready::READ_CLOSED - Ready::WRITE_CLOSED;
+
+ // result isn't important
+ let _ = self.set_readiness(None, Tick::Clear(event.tick), |curr| curr - mask_no_closed);
+ }
}
impl Drop for ScheduledIo {
fn drop(&mut self) {
- self.writer.wake();
- self.reader.wake();
+ self.wake(Ready::ALL);
}
}
-fn unpack_generation(src: usize) -> Generation {
- Generation::new(PACK.unpack(src))
+unsafe impl Send for ScheduledIo {}
+unsafe impl Sync for ScheduledIo {}
+
+cfg_io_readiness! {
+ impl ScheduledIo {
+ /// An async version of `poll_readiness` which uses a linked list of wakers
+ pub(crate) async fn readiness(&self, interest: mio::Interest) -> ReadyEvent {
+ self.readiness_fut(interest).await
+ }
+
+ // This is in a separate function so that the borrow checker doesn't think
+ // we are borrowing the `UnsafeCell` possibly over await boundaries.
+ //
+ // Go figure.
+ fn readiness_fut(&self, interest: mio::Interest) -> Readiness<'_> {
+ Readiness {
+ scheduled_io: self,
+ state: State::Init,
+ waiter: UnsafeCell::new(Waiter {
+ pointers: linked_list::Pointers::new(),
+ waker: None,
+ is_ready: false,
+ interest,
+ _p: PhantomPinned,
+ }),
+ }
+ }
+ }
+
+ unsafe impl linked_list::Link for Waiter {
+ type Handle = NonNull<Waiter>;
+ type Target = Waiter;
+
+ fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
+ *handle
+ }
+
+ unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
+ ptr
+ }
+
+ unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
+ NonNull::from(&mut target.as_mut().pointers)
+ }
+ }
+
+ // ===== impl Readiness =====
+
+ impl Future for Readiness<'_> {
+ type Output = ReadyEvent;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ use std::sync::atomic::Ordering::SeqCst;
+
+ let (scheduled_io, state, waiter) = unsafe {
+ let me = self.get_unchecked_mut();
+ (&me.scheduled_io, &mut me.state, &me.waiter)
+ };
+
+ loop {
+ match *state {
+ State::Init => {
+ // Optimistically check existing readiness
+ let curr = scheduled_io.readiness.load(SeqCst);
+ let ready = Ready::from_usize(READINESS.unpack(curr));
+
+ // Safety: `waiter.interest` never changes
+ let interest = unsafe { (*waiter.get()).interest };
+ let ready = ready.intersection(interest);
+
+ if !ready.is_empty() {
+ // Currently ready!
+ let tick = TICK.unpack(curr) as u8;
+ *state = State::Done;
+ return Poll::Ready(ReadyEvent { ready, tick });
+ }
+
+ // Wasn't ready, take the lock (and check again while locked).
+ let mut waiters = scheduled_io.waiters.lock();
+
+ let curr = scheduled_io.readiness.load(SeqCst);
+ let ready = Ready::from_usize(READINESS.unpack(curr));
+ let ready = ready.intersection(interest);
+
+ if !ready.is_empty() {
+ // Currently ready!
+ let tick = TICK.unpack(curr) as u8;
+ *state = State::Done;
+ return Poll::Ready(ReadyEvent { ready, tick });
+ }
+
+ // Not ready even after locked, insert into list...
+
+ // Safety: called while locked
+ unsafe {
+ (*waiter.get()).waker = Some(cx.waker().clone());
+ }
+
+ // Insert the waiter into the linked list
+ //
+ // safety: pointers from `UnsafeCell` are never null.
+ waiters
+ .list
+ .push_front(unsafe { NonNull::new_unchecked(waiter.get()) });
+ *state = State::Waiting;
+ }
+ State::Waiting => {
+ // Currently in the "Waiting" state, implying the caller has
+ // a waiter stored in the waiter list (guarded by
+ // `notify.waiters`). In order to access the waker fields,
+ // we must hold the lock.
+
+ let waiters = scheduled_io.waiters.lock();
+
+ // Safety: called while locked
+ let w = unsafe { &mut *waiter.get() };
+
+ if w.is_ready {
+ // Our waker has been notified.
+ *state = State::Done;
+ } else {
+ // Update the waker, if necessary.
+ if !w.waker.as_ref().unwrap().will_wake(cx.waker()) {
+ w.waker = Some(cx.waker().clone());
+ }
+
+ return Poll::Pending;
+ }
+
+ // Explicit drop of the lock to indicate the scope that the
+ // lock is held. Because holding the lock is required to
+ // ensure safe access to fields not held within the lock, it
+ // is helpful to visualize the scope of the critical
+ // section.
+ drop(waiters);
+ }
+ State::Done => {
+ let tick = TICK.unpack(scheduled_io.readiness.load(Acquire)) as u8;
+
+ // Safety: State::Done means it is no longer shared
+ let w = unsafe { &mut *waiter.get() };
+
+ return Poll::Ready(ReadyEvent {
+ tick,
+ ready: Ready::from_interest(w.interest),
+ });
+ }
+ }
+ }
+ }
+ }
+
+ impl Drop for Readiness<'_> {
+ fn drop(&mut self) {
+ let mut waiters = self.scheduled_io.waiters.lock();
+
+ // Safety: `waiter` is only ever stored in `waiters`
+ unsafe {
+ waiters
+ .list
+ .remove(NonNull::new_unchecked(self.waiter.get()))
+ };
+ }
+ }
+
+ unsafe impl Send for Readiness<'_> {}
+ unsafe impl Sync for Readiness<'_> {}
}
diff --git a/src/io/mod.rs b/src/io/mod.rs
index 7b00556..9191bbc 100644
--- a/src/io/mod.rs
+++ b/src/io/mod.rs
@@ -162,8 +162,8 @@
//!
//! # `std` re-exports
//!
-//! Additionally, [`Error`], [`ErrorKind`], and [`Result`] are re-exported
-//! from `std::io` for ease of use.
+//! Additionally, [`Error`], [`ErrorKind`], [`Result`], and [`SeekFrom`] are
+//! re-exported from `std::io` for ease of use.
//!
//! [`AsyncRead`]: trait@AsyncRead
//! [`AsyncWrite`]: trait@AsyncWrite
@@ -176,6 +176,7 @@
//! [`ErrorKind`]: enum@ErrorKind
//! [`Result`]: type@Result
//! [`Read`]: std::io::Read
+//! [`SeekFrom`]: enum@SeekFrom
//! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html
//! [`Stream`]: crate::stream::Stream
//! [`Write`]: std::io::Write
@@ -187,7 +188,6 @@ mod async_buf_read;
pub use self::async_buf_read::AsyncBufRead;
mod async_read;
-
pub use self::async_read::AsyncRead;
mod async_seek;
@@ -196,17 +196,27 @@ pub use self::async_seek::AsyncSeek;
mod async_write;
pub use self::async_write::AsyncWrite;
+mod read_buf;
+pub use self::read_buf::ReadBuf;
+
+// Re-export some types from `std::io` so that users don't have to deal
+// with conflicts when `use`ing `tokio::io` and `std::io`.
+#[doc(no_inline)]
+pub use std::io::{Error, ErrorKind, Result, SeekFrom};
+
cfg_io_driver! {
pub(crate) mod driver;
mod poll_evented;
- pub use poll_evented::PollEvented;
+ #[cfg(not(loom))]
+ pub(crate) use poll_evented::PollEvented;
mod registration;
- pub use registration::Registration;
}
cfg_io_std! {
+ mod stdio_common;
+
mod stderr;
pub use stderr::{stderr, Stderr};
@@ -222,21 +232,11 @@ cfg_io_util! {
pub use split::{split, ReadHalf, WriteHalf};
pub(crate) mod seek;
- pub use self::seek::Seek;
-
pub(crate) mod util;
pub use util::{
- copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
- BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take,
+ copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
+ BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take,
};
-
- cfg_stream! {
- pub use util::{stream_reader, StreamReader};
- }
-
- // Re-export io::Error so that users don't have to deal with conflicts when
- // `use`ing `tokio::io` and `std::io`.
- pub use std::io::{Error, ErrorKind, Result};
}
cfg_not_io_util! {
@@ -249,7 +249,7 @@ cfg_io_blocking! {
/// Types in this module can be mocked out in tests.
mod sys {
// TODO: don't rename
- pub(crate) use crate::runtime::spawn_blocking as run;
- pub(crate) use crate::task::JoinHandle as Blocking;
+ pub(crate) use crate::blocking::spawn_blocking as run;
+ pub(crate) use crate::blocking::JoinHandle as Blocking;
}
}
diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs
index 5295bd7..66a2634 100644
--- a/src/io/poll_evented.rs
+++ b/src/io/poll_evented.rs
@@ -1,13 +1,12 @@
-use crate::io::driver::platform;
-use crate::io::{AsyncRead, AsyncWrite, Registration};
+use crate::io::driver::{Direction, Handle, ReadyEvent};
+use crate::io::registration::Registration;
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
-use mio::event::Evented;
+use mio::event::Source;
use std::fmt;
use std::io::{self, Read, Write};
use std::marker::Unpin;
use std::pin::Pin;
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering::Relaxed;
use std::task::{Context, Poll};
cfg_io_driver! {
@@ -53,37 +52,6 @@ cfg_io_driver! {
/// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and
/// [`clear_read_ready`].
///
- /// ```rust
- /// use tokio::io::PollEvented;
- ///
- /// use futures::ready;
- /// use mio::Ready;
- /// use mio::net::{TcpStream, TcpListener};
- /// use std::io;
- /// use std::task::{Context, Poll};
- ///
- /// struct MyListener {
- /// poll_evented: PollEvented<TcpListener>,
- /// }
- ///
- /// impl MyListener {
- /// pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<TcpStream, io::Error>> {
- /// let ready = Ready::readable();
- ///
- /// ready!(self.poll_evented.poll_read_ready(cx, ready))?;
- ///
- /// match self.poll_evented.get_ref().accept() {
- /// Ok((socket, _)) => Poll::Ready(Ok(socket)),
- /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- /// self.poll_evented.clear_read_ready(cx, ready)?;
- /// Poll::Pending
- /// }
- /// Err(e) => Poll::Ready(Err(e)),
- /// }
- /// }
- /// }
- /// ```
- ///
/// ## Platform-specific events
///
/// `PollEvented` also allows receiving platform-specific `mio::Ready` events.
@@ -101,70 +69,15 @@ cfg_io_driver! {
/// [`clear_write_ready`]: method@Self::clear_write_ready
/// [`poll_read_ready`]: method@Self::poll_read_ready
/// [`poll_write_ready`]: method@Self::poll_write_ready
- pub struct PollEvented<E: Evented> {
+ pub(crate) struct PollEvented<E: Source> {
io: Option<E>,
- inner: Inner,
+ registration: Registration,
}
}
-struct Inner {
- registration: Registration,
-
- /// Currently visible read readiness
- read_readiness: AtomicUsize,
-
- /// Currently visible write readiness
- write_readiness: AtomicUsize,
-}
-
// ===== impl PollEvented =====
-macro_rules! poll_ready {
- ($me:expr, $mask:expr, $cache:ident, $take:ident, $poll:expr) => {{
- // Load cached & encoded readiness.
- let mut cached = $me.inner.$cache.load(Relaxed);
- let mask = $mask | platform::hup() | platform::error();
-
- // See if the current readiness matches any bits.
- let mut ret = mio::Ready::from_usize(cached) & $mask;
-
- if ret.is_empty() {
- // Readiness does not match, consume the registration's readiness
- // stream. This happens in a loop to ensure that the stream gets
- // drained.
- loop {
- let ready = match $poll? {
- Poll::Ready(v) => v,
- Poll::Pending => return Poll::Pending,
- };
- cached |= ready.as_usize();
-
- // Update the cache store
- $me.inner.$cache.store(cached, Relaxed);
-
- ret |= ready & mask;
-
- if !ret.is_empty() {
- return Poll::Ready(Ok(ret));
- }
- }
- } else {
- // Check what's new with the registration stream. This will not
- // request to be notified
- if let Some(ready) = $me.inner.registration.$take()? {
- cached |= ready.as_usize();
- $me.inner.$cache.store(cached, Relaxed);
- }
-
- Poll::Ready(Ok(mio::Ready::from_usize(cached)))
- }
- }};
-}
-
-impl<E> PollEvented<E>
-where
- E: Evented,
-{
+impl<E: Source> PollEvented<E> {
/// Creates a new `PollEvented` associated with the default reactor.
///
/// # Panics
@@ -173,71 +86,57 @@ where
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new(io: E) -> io::Result<Self> {
- PollEvented::new_with_ready(io, mio::Ready::all())
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ #[cfg_attr(feature = "signal", allow(unused))]
+ pub(crate) fn new(io: E) -> io::Result<Self> {
+ PollEvented::new_with_interest(io, mio::Interest::READABLE | mio::Interest::WRITABLE)
}
- /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Ready`
- /// state. `new_with_ready` should be used over `new` when you need control over the readiness
+ /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Interest`
+ /// state. `new_with_interest` should be used over `new` when you need control over the readiness
/// state, such as when a file descriptor only allows reads. This does not add `hup` or `error`
/// so if you are interested in those states, you will need to add them to the readiness state
/// passed to this function.
///
- /// An example to listen to read only
- ///
- /// ```rust
- /// ##[cfg(unix)]
- /// mio::Ready::from_usize(
- /// mio::Ready::readable().as_usize()
- /// | mio::unix::UnixReady::error().as_usize()
- /// | mio::unix::UnixReady::hup().as_usize()
- /// );
- /// ```
- ///
/// # Panics
///
/// This function panics if thread-local runtime is not set.
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new_with_ready(io: E, ready: mio::Ready) -> io::Result<Self> {
- let registration = Registration::new_with_ready(&io, ready)?;
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ #[cfg_attr(feature = "signal", allow(unused))]
+ pub(crate) fn new_with_interest(io: E, interest: mio::Interest) -> io::Result<Self> {
+ Self::new_with_interest_and_handle(io, interest, Handle::current())
+ }
+
+ pub(crate) fn new_with_interest_and_handle(
+ mut io: E,
+ interest: mio::Interest,
+ handle: Handle,
+ ) -> io::Result<Self> {
+ let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?;
Ok(Self {
io: Some(io),
- inner: Inner {
- registration,
- read_readiness: AtomicUsize::new(0),
- write_readiness: AtomicUsize::new(0),
- },
+ registration,
})
}
/// Returns a shared reference to the underlying I/O object this readiness
/// stream is wrapping.
- pub fn get_ref(&self) -> &E {
+ #[cfg(any(feature = "net", feature = "process", feature = "signal"))]
+ pub(crate) fn get_ref(&self) -> &E {
self.io.as_ref().unwrap()
}
/// Returns a mutable reference to the underlying I/O object this readiness
/// stream is wrapping.
- pub fn get_mut(&mut self) -> &mut E {
+ pub(crate) fn get_mut(&mut self) -> &mut E {
self.io.as_mut().unwrap()
}
- /// Consumes self, returning the inner I/O object
- ///
- /// This function will deregister the I/O resource from the reactor before
- /// returning. If the deregistration operation fails, an error is returned.
- ///
- /// Note that deregistering does not guarantee that the I/O resource can be
- /// registered with a different reactor. Some I/O resource types can only be
- /// associated with a single reactor instance for their lifetime.
- pub fn into_inner(mut self) -> io::Result<E> {
- let io = self.io.take().unwrap();
- self.inner.registration.deregister(&io)?;
- Ok(io)
+ pub(crate) fn clear_readiness(&self, event: ReadyEvent) {
+ self.registration.clear_readiness(event);
}
/// Checks the I/O resource's read readiness state.
@@ -266,51 +165,8 @@ where
///
/// This method may not be called concurrently. It takes `&self` to allow
/// calling it concurrently with `poll_write_ready`.
- pub fn poll_read_ready(
- &self,
- cx: &mut Context<'_>,
- mask: mio::Ready,
- ) -> Poll<io::Result<mio::Ready>> {
- assert!(!mask.is_writable(), "cannot poll for write readiness");
- poll_ready!(
- self,
- mask,
- read_readiness,
- take_read_ready,
- self.inner.registration.poll_read_ready(cx)
- )
- }
-
- /// Clears the I/O resource's read readiness state and registers the current
- /// task to be notified once a read readiness event is received.
- ///
- /// After calling this function, `poll_read_ready` will return
- /// `Poll::Pending` until a new read readiness event has been received.
- ///
- /// The `mask` argument specifies the readiness bits to clear. This may not
- /// include `writable` or `hup`.
- ///
- /// # Panics
- ///
- /// This function panics if:
- ///
- /// * `ready` includes writable or HUP
- /// * called from outside of a task context.
- pub fn clear_read_ready(&self, cx: &mut Context<'_>, ready: mio::Ready) -> io::Result<()> {
- // Cannot clear write readiness
- assert!(!ready.is_writable(), "cannot clear write readiness");
- assert!(!platform::is_hup(ready), "cannot clear HUP readiness");
-
- self.inner
- .read_readiness
- .fetch_and(!ready.as_usize(), Relaxed);
-
- if self.poll_read_ready(cx, ready)?.is_ready() {
- // Notify the current task
- cx.waker().wake_by_ref();
- }
-
- Ok(())
+ pub(crate) fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> {
+ self.registration.poll_readiness(cx, Direction::Read)
}
/// Checks the I/O resource's write readiness state.
@@ -337,100 +193,95 @@ where
///
/// This method may not be called concurrently. It takes `&self` to allow
/// calling it concurrently with `poll_read_ready`.
- pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
- poll_ready!(
- self,
- mio::Ready::writable(),
- write_readiness,
- take_write_ready,
- self.inner.registration.poll_write_ready(cx)
- )
+ pub(crate) fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> {
+ self.registration.poll_readiness(cx, Direction::Write)
}
+}
- /// Resets the I/O resource's write readiness state and registers the current
- /// task to be notified once a write readiness event is received.
- ///
- /// This only clears writable readiness. HUP (on platforms that support HUP)
- /// cannot be cleared as it is a final state.
- ///
- /// After calling this function, `poll_write_ready(Ready::writable())` will
- /// return `NotReady` until a new write readiness event has been received.
- ///
- /// # Panics
- ///
- /// This function will panic if called from outside of a task context.
- pub fn clear_write_ready(&self, cx: &mut Context<'_>) -> io::Result<()> {
- let ready = mio::Ready::writable();
+cfg_io_readiness! {
+ impl<E: Source> PollEvented<E> {
+ pub(crate) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> {
+ self.registration.readiness(interest).await
+ }
- self.inner
- .write_readiness
- .fetch_and(!ready.as_usize(), Relaxed);
+ pub(crate) async fn async_io<F, R>(&self, interest: mio::Interest, mut op: F) -> io::Result<R>
+ where
+ F: FnMut(&E) -> io::Result<R>,
+ {
+ loop {
+ let event = self.readiness(interest).await?;
- if self.poll_write_ready(cx)?.is_ready() {
- // Notify the current task
- cx.waker().wake_by_ref();
+ match op(self.get_ref()) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.clear_readiness(event);
+ }
+ x => return x,
+ }
+ }
}
-
- Ok(())
}
}
// ===== Read / Write impls =====
-impl<E> AsyncRead for PollEvented<E>
-where
- E: Evented + Read + Unpin,
-{
+impl<E: Source + Read + Unpin> AsyncRead for PollEvented<E> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.poll_read_ready(cx, mio::Ready::readable()))?;
-
- let r = (*self).get_mut().read(buf);
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ loop {
+ let ev = ready!(self.poll_read_ready(cx))?;
+
+ // We can't assume the `Read` won't look at the read buffer,
+ // so we have to force initialization here.
+ let r = (*self).get_mut().read(buf.initialize_unfilled());
+
+ if is_wouldblock(&r) {
+ self.clear_readiness(ev);
+ continue;
+ }
- if is_wouldblock(&r) {
- self.clear_read_ready(cx, mio::Ready::readable())?;
- return Poll::Pending;
+ return Poll::Ready(r.map(|n| {
+ buf.advance(n);
+ }));
}
-
- Poll::Ready(r)
}
}
-impl<E> AsyncWrite for PollEvented<E>
-where
- E: Evented + Write + Unpin,
-{
+impl<E: Source + Write + Unpin> AsyncWrite for PollEvented<E> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
- ready!(self.poll_write_ready(cx))?;
+ loop {
+ let ev = ready!(self.poll_write_ready(cx))?;
- let r = (*self).get_mut().write(buf);
+ let r = (*self).get_mut().write(buf);
- if is_wouldblock(&r) {
- self.clear_write_ready(cx)?;
- return Poll::Pending;
- }
+ if is_wouldblock(&r) {
+ self.clear_readiness(ev);
+ continue;
+ }
- Poll::Ready(r)
+ return Poll::Ready(r);
+ }
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- ready!(self.poll_write_ready(cx))?;
+ loop {
+ let ev = ready!(self.poll_write_ready(cx))?;
- let r = (*self).get_mut().flush();
+ let r = (*self).get_mut().flush();
- if is_wouldblock(&r) {
- self.clear_write_ready(cx)?;
- return Poll::Pending;
- }
+ if is_wouldblock(&r) {
+ self.clear_readiness(ev);
+ continue;
+ }
- Poll::Ready(r)
+ return Poll::Ready(r);
+ }
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@@ -445,17 +296,17 @@ fn is_wouldblock<T>(r: &io::Result<T>) -> bool {
}
}
-impl<E: Evented + fmt::Debug> fmt::Debug for PollEvented<E> {
+impl<E: Source + fmt::Debug> fmt::Debug for PollEvented<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PollEvented").field("io", &self.io).finish()
}
}
-impl<E: Evented> Drop for PollEvented<E> {
+impl<E: Source> Drop for PollEvented<E> {
fn drop(&mut self) {
- if let Some(io) = self.io.take() {
+ if let Some(mut io) = self.io.take() {
// Ignore errors
- let _ = self.inner.registration.deregister(&io);
+ let _ = self.registration.deregister(&mut io);
}
}
}
diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs
new file mode 100644
index 0000000..b64d95c
--- /dev/null
+++ b/src/io/read_buf.rs
@@ -0,0 +1,261 @@
+// This lint claims ugly casting is somehow safer than transmute, but there's
+// no evidence that is the case. Shush.
+#![allow(clippy::transmute_ptr_to_ptr)]
+
+use std::fmt;
+use std::mem::{self, MaybeUninit};
+
+/// A wrapper around a byte buffer that is incrementally filled and initialized.
+///
+/// This type is a sort of "double cursor". It tracks three regions in the
+/// buffer: a region at the beginning of the buffer that has been logically
+/// filled with data, a region that has been initialized at some point but not
+/// yet logically filled, and a region at the end that is fully uninitialized.
+/// The filled region is guaranteed to be a subset of the initialized region.
+///
+/// In summary, the contents of the buffer can be visualized as:
+///
+/// ```not_rust
+/// [ capacity ]
+/// [ filled | unfilled ]
+/// [ initialized | uninitialized ]
+/// ```
+pub struct ReadBuf<'a> {
+ buf: &'a mut [MaybeUninit<u8>],
+ filled: usize,
+ initialized: usize,
+}
+
+impl<'a> ReadBuf<'a> {
+ /// Creates a new `ReadBuf` from a fully initialized buffer.
+ #[inline]
+ pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> {
+ let initialized = buf.len();
+ let buf = unsafe { mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buf) };
+ ReadBuf {
+ buf,
+ filled: 0,
+ initialized,
+ }
+ }
+
+ /// Creates a new `ReadBuf` from a fully uninitialized buffer.
+ ///
+ /// Use `assume_init` if part of the buffer is known to be already inintialized.
+ #[inline]
+ pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> {
+ ReadBuf {
+ buf,
+ filled: 0,
+ initialized: 0,
+ }
+ }
+
+ /// Returns the total capacity of the buffer.
+ #[inline]
+ pub fn capacity(&self) -> usize {
+ self.buf.len()
+ }
+
+ /// Returns a shared reference to the filled portion of the buffer.
+ #[inline]
+ pub fn filled(&self) -> &[u8] {
+ let slice = &self.buf[..self.filled];
+ // safety: filled describes how far into the buffer that the
+ // user has filled with bytes, so it's been initialized.
+ // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable.
+ unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the filled portion of the buffer.
+ #[inline]
+ pub fn filled_mut(&mut self) -> &mut [u8] {
+ let slice = &mut self.buf[..self.filled];
+ // safety: filled describes how far into the buffer that the
+ // user has filled with bytes, so it's been initialized.
+ // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns a new `ReadBuf` comprised of the unfilled section up to `n`.
+ #[inline]
+ pub fn take(&mut self, n: usize) -> ReadBuf<'_> {
+ let max = std::cmp::min(self.remaining(), n);
+ // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`.
+ unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) }
+ }
+
+ /// Returns a shared reference to the initialized portion of the buffer.
+ ///
+ /// This includes the filled portion.
+ #[inline]
+ pub fn initialized(&self) -> &[u8] {
+ let slice = &self.buf[..self.initialized];
+ // safety: initialized describes how far into the buffer that the
+ // user has at some point initialized with bytes.
+ // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable.
+ unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the initialized portion of the buffer.
+ ///
+ /// This includes the filled portion.
+ #[inline]
+ pub fn initialized_mut(&mut self) -> &mut [u8] {
+ let slice = &mut self.buf[..self.initialized];
+ // safety: initialized describes how far into the buffer that the
+ // user has at some point initialized with bytes.
+ // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the unfilled part of the buffer without ensuring that it has been fully
+ /// initialized.
+ ///
+ /// # Safety
+ ///
+ /// The caller must not de-initialize portions of the buffer that have already been initialized.
+ #[inline]
+ pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit<u8>] {
+ &mut self.buf[self.filled..]
+ }
+
+ /// Returns a mutable reference to the unfilled part of the buffer, ensuring it is fully initialized.
+ ///
+ /// Since `ReadBuf` tracks the region of the buffer that has been initialized, this is effectively "free" after
+ /// the first use.
+ #[inline]
+ pub fn initialize_unfilled(&mut self) -> &mut [u8] {
+ self.initialize_unfilled_to(self.remaining())
+ }
+
+ /// Returns a mutable reference to the first `n` bytes of the unfilled part of the buffer, ensuring it is
+ /// fully initialized.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `self.remaining()` is less than `n`.
+ #[inline]
+ pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] {
+ assert!(self.remaining() >= n, "n overflows remaining");
+
+ // This can't overflow, otherwise the assert above would have failed.
+ let end = self.filled + n;
+
+ if self.initialized < end {
+ unsafe {
+ self.buf[self.initialized..end]
+ .as_mut_ptr()
+ .write_bytes(0, end - self.initialized);
+ }
+ self.initialized = end;
+ }
+
+ let slice = &mut self.buf[self.filled..end];
+ // safety: just above, we checked that the end of the buf has
+ // been initialized to some value.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns the number of bytes at the end of the slice that have not yet been filled.
+ #[inline]
+ pub fn remaining(&self) -> usize {
+ self.capacity() - self.filled
+ }
+
+ /// Clears the buffer, resetting the filled region to empty.
+ ///
+ /// The number of initialized bytes is not changed, and the contents of the buffer are not modified.
+ #[inline]
+ pub fn clear(&mut self) {
+ self.filled = 0;
+ }
+
+ /// Advances the size of the filled region of the buffer.
+ ///
+ /// The number of initialized bytes is not changed.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the filled region of the buffer would become larger than the initialized region.
+ #[inline]
+ pub fn advance(&mut self, n: usize) {
+ let new = self.filled.checked_add(n).expect("filled overflow");
+ self.set_filled(new);
+ }
+
+ /// Sets the size of the filled region of the buffer.
+ ///
+ /// The number of initialized bytes is not changed.
+ ///
+ /// Note that this can be used to *shrink* the filled region of the buffer in addition to growing it (for
+ /// example, by a `AsyncRead` implementation that compresses data in-place).
+ ///
+ /// # Panics
+ ///
+ /// Panics if the filled region of the buffer would become larger than the intialized region.
+ #[inline]
+ pub fn set_filled(&mut self, n: usize) {
+ assert!(
+ n <= self.initialized,
+ "filled must not become larger than initialized"
+ );
+ self.filled = n;
+ }
+
+ /// Asserts that the first `n` unfilled bytes of the buffer are initialized.
+ ///
+ /// `ReadBuf` assumes that bytes are never de-initialized, so this method does nothing when called with fewer
+ /// bytes than are already known to be initialized.
+ ///
+ /// # Safety
+ ///
+ /// The caller must ensure that `n` unfilled bytes of the buffer have already been initialized.
+ #[inline]
+ pub unsafe fn assume_init(&mut self, n: usize) {
+ let new = self.filled + n;
+ if new > self.initialized {
+ self.initialized = new;
+ }
+ }
+
+ /// Appends data to the buffer, advancing the written position and possibly also the initialized position.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `self.remaining()` is less than `buf.len()`.
+ #[inline]
+ pub fn put_slice(&mut self, buf: &[u8]) {
+ assert!(
+ self.remaining() >= buf.len(),
+ "buf.len() must fit in remaining()"
+ );
+
+ let amt = buf.len();
+ // Cannot overflow, asserted above
+ let end = self.filled + amt;
+
+ // Safety: the length is asserted above
+ unsafe {
+ self.buf[self.filled..end]
+ .as_mut_ptr()
+ .cast::<u8>()
+ .copy_from_nonoverlapping(buf.as_ptr(), amt);
+ }
+
+ if self.initialized < end {
+ self.initialized = end;
+ }
+ self.filled = end;
+ }
+}
+
+impl fmt::Debug for ReadBuf<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ReadBuf")
+ .field("filled", &self.filled)
+ .field("initialized", &self.initialized)
+ .field("capacity", &self.capacity())
+ .finish()
+ }
+}
diff --git a/src/io/registration.rs b/src/io/registration.rs
index 77fe6db..ce6cffd 100644
--- a/src/io/registration.rs
+++ b/src/io/registration.rs
@@ -1,7 +1,7 @@
-use crate::io::driver::{platform, Direction, Handle};
-use crate::util::slab::Address;
+use crate::io::driver::{Direction, Handle, ReadyEvent, ScheduledIo};
+use crate::util::slab;
-use mio::{self, Evented};
+use mio::event::Source;
use std::io;
use std::task::{Context, Poll};
@@ -38,74 +38,38 @@ cfg_io_driver! {
/// [`poll_read_ready`]: method@Self::poll_read_ready`
/// [`poll_write_ready`]: method@Self::poll_write_ready`
#[derive(Debug)]
- pub struct Registration {
+ pub(crate) struct Registration {
+ /// Handle to the associated driver.
handle: Handle,
- address: Address,
+
+ /// Reference to state stored by the driver.
+ shared: slab::Ref<ScheduledIo>,
}
}
+unsafe impl Send for Registration {}
+unsafe impl Sync for Registration {}
+
// ===== impl Registration =====
impl Registration {
- /// Registers the I/O resource with the default reactor.
- ///
- /// # Return
- ///
- /// - `Ok` if the registration happened successfully
- /// - `Err` if an error was encountered during registration
- ///
- ///
- /// # Panics
- ///
- /// This function panics if thread-local runtime is not set.
- ///
- /// The runtime is usually set implicitly when this function is called
- /// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new<T>(io: &T) -> io::Result<Registration>
- where
- T: Evented,
- {
- Registration::new_with_ready(io, mio::Ready::all())
- }
-
- /// Registers the I/O resource with the default reactor, for a specific `mio::Ready` state.
- /// `new_with_ready` should be used over `new` when you need control over the readiness state,
+ /// Registers the I/O resource with the default reactor, for a specific `mio::Interest`.
+ /// `new_with_interest` should be used over `new` when you need control over the readiness state,
/// such as when a file descriptor only allows reads. This does not add `hup` or `error` so if
/// you are interested in those states, you will need to add them to the readiness state passed
/// to this function.
///
- /// An example to listen to read only
- ///
- /// ```rust
- /// ##[cfg(unix)]
- /// mio::Ready::from_usize(
- /// mio::Ready::readable().as_usize()
- /// | mio::unix::UnixReady::error().as_usize()
- /// | mio::unix::UnixReady::hup().as_usize()
- /// );
- /// ```
- ///
/// # Return
///
/// - `Ok` if the registration happened successfully
/// - `Err` if an error was encountered during registration
- ///
- ///
- /// # Panics
- ///
- /// This function panics if thread-local runtime is not set.
- ///
- /// The runtime is usually set implicitly when this function is called
- /// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new_with_ready<T>(io: &T, ready: mio::Ready) -> io::Result<Registration>
- where
- T: Evented,
- {
- let handle = Handle::current();
- let address = if let Some(inner) = handle.inner() {
- inner.add_source(io, ready)?
+ pub(crate) fn new_with_interest_and_handle(
+ io: &mut impl Source,
+ interest: mio::Interest,
+ handle: Handle,
+ ) -> io::Result<Registration> {
+ let shared = if let Some(inner) = handle.inner() {
+ inner.add_source(io, interest)?
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
@@ -113,7 +77,7 @@ impl Registration {
));
};
- Ok(Registration { handle, address })
+ Ok(Registration { handle, shared })
}
/// Deregisters the I/O resource from the reactor it is associated with.
@@ -132,10 +96,7 @@ impl Registration {
/// no longer result in notifications getting sent for this registration.
///
/// `Err` is returned if an error is encountered.
- pub fn deregister<T>(&mut self, io: &T) -> io::Result<()>
- where
- T: Evented,
- {
+ pub(super) fn deregister(&mut self, io: &mut impl Source) -> io::Result<()> {
let inner = match self.handle.inner() {
Some(inner) => inner,
None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")),
@@ -143,198 +104,47 @@ impl Registration {
inner.deregister_source(io)
}
- /// Polls for events on the I/O resource's read readiness stream.
- ///
- /// If the I/O resource receives a new read readiness event since the last
- /// call to `poll_read_ready`, it is returned. If it has not, the current
- /// task is notified once a new event is received.
- ///
- /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned,
- /// the function will always return `Ready(HUP)`. This should be treated as
- /// the end of the readiness stream.
- ///
- /// # Return value
- ///
- /// There are several possible return values:
- ///
- /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received
- /// a new readiness event. The readiness value is included.
- ///
- /// * `Poll::Pending` means that no new readiness events have been received
- /// since the last call to `poll_read_ready`.
- ///
- /// * `Poll::Ready(Err(err))` means that the registration has encountered an
- /// error. This could represent a permanent internal error for example.
- ///
- /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered
- ///
- /// # Panics
- ///
- /// This function will panic if called from outside of a task context.
- pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| {
- coop.made_progress();
- e
- })?;
- match v {
- Some(v) => {
- coop.made_progress();
- Poll::Ready(Ok(v))
- }
- None => Poll::Pending,
- }
- }
-
- /// Consume any pending read readiness event.
- ///
- /// This function is identical to [`poll_read_ready`] **except** that it
- /// will not notify the current task when a new event is received. As such,
- /// it is safe to call this function from outside of a task context.
- ///
- /// [`poll_read_ready`]: method@Self::poll_read_ready
- pub fn take_read_ready(&self) -> io::Result<Option<mio::Ready>> {
- self.poll_ready(Direction::Read, None)
- }
-
- /// Polls for events on the I/O resource's write readiness stream.
- ///
- /// If the I/O resource receives a new write readiness event since the last
- /// call to `poll_write_ready`, it is returned. If it has not, the current
- /// task is notified once a new event is received.
- ///
- /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned,
- /// the function will always return `Ready(HUP)`. This should be treated as
- /// the end of the readiness stream.
- ///
- /// # Return value
- ///
- /// There are several possible return values:
- ///
- /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received
- /// a new readiness event. The readiness value is included.
- ///
- /// * `Poll::Pending` means that no new readiness events have been received
- /// since the last call to `poll_write_ready`.
- ///
- /// * `Poll::Ready(Err(err))` means that the registration has encountered an
- /// error. This could represent a permanent internal error for example.
- ///
- /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered
- ///
- /// # Panics
- ///
- /// This function will panic if called from outside of a task context.
- pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| {
- coop.made_progress();
- e
- })?;
- match v {
- Some(v) => {
- coop.made_progress();
- Poll::Ready(Ok(v))
- }
- None => Poll::Pending,
- }
- }
-
- /// Consumes any pending write readiness event.
- ///
- /// This function is identical to [`poll_write_ready`] **except** that it
- /// will not notify the current task when a new event is received. As such,
- /// it is safe to call this function from outside of a task context.
- ///
- /// [`poll_write_ready`]: method@Self::poll_write_ready
- pub fn take_write_ready(&self) -> io::Result<Option<mio::Ready>> {
- self.poll_ready(Direction::Write, None)
+ pub(super) fn clear_readiness(&self, event: ReadyEvent) {
+ self.shared.clear_readiness(event);
}
/// Polls for events on the I/O resource's `direction` readiness stream.
///
/// If called with a task context, notify the task when a new event is
/// received.
- fn poll_ready(
+ pub(super) fn poll_readiness(
&self,
+ cx: &mut Context<'_>,
direction: Direction,
- cx: Option<&mut Context<'_>>,
- ) -> io::Result<Option<mio::Ready>> {
- let inner = match self.handle.inner() {
- Some(inner) => inner,
- None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")),
- };
-
- // If the task should be notified about new events, ensure that it has
- // been registered
- if let Some(ref cx) = cx {
- inner.register(self.address, direction, cx.waker().clone())
+ ) -> Poll<io::Result<ReadyEvent>> {
+ if self.handle.inner().is_none() {
+ return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone")));
}
- let mask = direction.mask();
- let mask_no_hup = (mask - platform::hup() - platform::error()).as_usize();
-
- let sched = inner.io_dispatch.get(self.address).unwrap();
+ // Keep track of task budget
+ let coop = ready!(crate::coop::poll_proceed(cx));
+ let ev = ready!(self.shared.poll_readiness(cx, direction));
+ coop.made_progress();
+ Poll::Ready(Ok(ev))
+ }
+}
- // This consumes the current readiness state **except** for HUP and
- // error. HUP and error are excluded because a) they are final states
- // and never transitition out and b) both the read AND the write
- // directions need to be able to obvserve these states.
- //
- // # Platform-specific behavior
- //
- // HUP and error readiness are platform-specific. On epoll platforms,
- // HUP has specific conditions that must be met by both peers of a
- // connection in order to be triggered.
- //
- // On epoll platforms, `EPOLLERR` is signaled through
- // `UnixReady::error()` and is important to be observable by both read
- // AND write. A specific case that `EPOLLERR` occurs is when the read
- // end of a pipe is closed. When this occurs, a peer blocked by
- // writing to the pipe should be notified.
- let curr_ready = sched
- .set_readiness(self.address, |curr| curr & (!mask_no_hup))
- .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address));
+cfg_io_readiness! {
+ impl Registration {
+ pub(super) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> {
+ use std::future::Future;
+ use std::pin::Pin;
- let mut ready = mask & mio::Ready::from_usize(curr_ready);
+ let fut = self.shared.readiness(interest);
+ pin!(fut);
- if ready.is_empty() {
- if let Some(cx) = cx {
- // Update the task info
- match direction {
- Direction::Read => sched.reader.register_by_ref(cx.waker()),
- Direction::Write => sched.writer.register_by_ref(cx.waker()),
+ crate::future::poll_fn(|cx| {
+ if self.handle.inner().is_none() {
+ return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone")));
}
- // Try again
- let curr_ready = sched
- .set_readiness(self.address, |curr| curr & (!mask_no_hup))
- .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address));
- ready = mask & mio::Ready::from_usize(curr_ready);
- }
- }
-
- if ready.is_empty() {
- Ok(None)
- } else {
- Ok(Some(ready))
+ Pin::new(&mut fut).poll(cx).map(Ok)
+ }).await
}
}
}
-
-unsafe impl Send for Registration {}
-unsafe impl Sync for Registration {}
-
-impl Drop for Registration {
- fn drop(&mut self) {
- let inner = match self.handle.inner() {
- Some(inner) => inner,
- None => return,
- };
- inner.drop_source(self.address);
- }
-}
diff --git a/src/io/seek.rs b/src/io/seek.rs
index e3b5bf6..e64205d 100644
--- a/src/io/seek.rs
+++ b/src/io/seek.rs
@@ -1,15 +1,23 @@
use crate::io::AsyncSeek;
+
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io::{self, SeekFrom};
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-/// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct Seek<'a, S: ?Sized> {
- seek: &'a mut S,
- pos: Option<SeekFrom>,
+pin_project! {
+ /// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method.
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct Seek<'a, S: ?Sized> {
+ seek: &'a mut S,
+ pos: Option<SeekFrom>,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
pub(crate) fn seek<S>(seek: &mut S, pos: SeekFrom) -> Seek<'_, S>
@@ -19,6 +27,7 @@ where
Seek {
seek,
pos: Some(pos),
+ _pin: PhantomPinned,
}
}
@@ -28,29 +37,21 @@ where
{
type Output = io::Result<u64>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
match me.pos {
- Some(pos) => match Pin::new(&mut me.seek).start_seek(cx, pos) {
- Poll::Ready(Ok(())) => {
- me.pos = None;
- Pin::new(&mut me.seek).poll_complete(cx)
+ Some(pos) => {
+ // ensure no seek in progress
+ ready!(Pin::new(&mut *me.seek).poll_complete(cx))?;
+ match Pin::new(&mut *me.seek).start_seek(*pos) {
+ Ok(()) => {
+ *me.pos = None;
+ Pin::new(&mut *me.seek).poll_complete(cx)
+ }
+ Err(e) => Poll::Ready(Err(e)),
}
- Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
- Poll::Pending => Poll::Pending,
- },
- None => Pin::new(&mut me.seek).poll_complete(cx),
+ }
+ None => Pin::new(&mut *me.seek).poll_complete(cx),
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Seek<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/split.rs b/src/io/split.rs
index 134b937..fd3273e 100644
--- a/src/io/split.rs
+++ b/src/io/split.rs
@@ -4,9 +4,8 @@
//! To restore this read/write object from its `split::ReadHalf` and
//! `split::WriteHalf` use `unsplit`.
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
-use bytes::{Buf, BufMut};
use std::cell::UnsafeCell;
use std::fmt;
use std::io;
@@ -102,20 +101,11 @@ impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
let mut inner = ready!(self.inner.poll_lock(cx));
inner.stream_pin().poll_read(cx, buf)
}
-
- fn poll_read_buf<B: BufMut>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- let mut inner = ready!(self.inner.poll_lock(cx));
- inner.stream_pin().poll_read_buf(cx, buf)
- }
}
impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
@@ -137,15 +127,6 @@ impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
let mut inner = ready!(self.inner.poll_lock(cx));
inner.stream_pin().poll_shutdown(cx)
}
-
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<Result<usize, io::Error>> {
- let mut inner = ready!(self.inner.poll_lock(cx));
- inner.stream_pin().poll_write_buf(cx, buf)
- }
}
impl<T> Inner<T> {
diff --git a/src/io/stderr.rs b/src/io/stderr.rs
index 99607dc..2f624fb 100644
--- a/src/io/stderr.rs
+++ b/src/io/stderr.rs
@@ -1,4 +1,5 @@
use crate::io::blocking::Blocking;
+use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows;
use crate::io::AsyncWrite;
use std::io;
@@ -35,7 +36,7 @@ cfg_io_std! {
/// ```
#[derive(Debug)]
pub struct Stderr {
- std: Blocking<std::io::Stderr>,
+ std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stderr>>,
}
/// Constructs a new handle to the standard error of the current process.
@@ -59,7 +60,7 @@ cfg_io_std! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut stderr = io::stdout();
+ /// let mut stderr = io::stderr();
/// stderr.write_all(b"Print some error here.").await?;
/// Ok(())
/// }
@@ -67,7 +68,7 @@ cfg_io_std! {
pub fn stderr() -> Stderr {
let std = io::stderr();
Stderr {
- std: Blocking::new(std),
+ std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)),
}
}
}
diff --git a/src/io/stdin.rs b/src/io/stdin.rs
index 325b875..c9578f1 100644
--- a/src/io/stdin.rs
+++ b/src/io/stdin.rs
@@ -1,5 +1,5 @@
use crate::io::blocking::Blocking;
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
use std::io;
use std::pin::Pin;
@@ -63,16 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin {
}
impl AsyncRead for Stdin {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97
- false
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut self.std).poll_read(cx, buf)
}
}
diff --git a/src/io/stdio_common.rs b/src/io/stdio_common.rs
new file mode 100644
index 0000000..d21c842
--- /dev/null
+++ b/src/io/stdio_common.rs
@@ -0,0 +1,220 @@
+//! Contains utilities for stdout and stderr.
+use crate::io::AsyncWrite;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+/// # Windows
+/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest,
+/// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF.
+/// That's why, wrapped writer will always receive well-formed utf-8 bytes.
+/// # Other platforms
+/// passes data to `inner` as is
+#[derive(Debug)]
+pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
+ inner: W,
+}
+
+impl<W> SplitByUtf8BoundaryIfWindows<W> {
+ pub(crate) fn new(inner: W) -> Self {
+ Self { inner }
+ }
+}
+
+// this constant is defined by Unicode standard.
+const MAX_BYTES_PER_CHAR: usize = 4;
+
+// Subject for tweaking here
+const MAGIC_CONST: usize = 8;
+
+impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
+where
+ W: AsyncWrite + Unpin,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ mut buf: &[u8],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ // just a closure to avoid repetitive code
+ let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
+
+ // 1. Only windows stdio can suffer from non-utf8.
+ // We also check for `test` so that we can write some tests
+ // for further code. Since `AsyncWrite` can always shrink
+ // buffer at its discretion, excessive (i.e. in tests) shrinking
+ // does not break correctness.
+ // 2. If buffer is small, it will not be shrinked.
+ // That's why, it's "textness" will not change, so we don't have
+ // to fixup it.
+ if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF
+ {
+ return call_inner(buf);
+ }
+
+ buf = &buf[..crate::io::blocking::MAX_BUF];
+
+ // Now there are two possibilites.
+ // If caller gave is binary buffer, we **should not** shrink it
+ // anymore, because excessive shrinking hits performance.
+ // If caller gave as binary buffer, we **must** additionaly
+ // shrink it to strip incomplete char at the end of buffer.
+ // that's why check we will perform now is allowed to have
+ // false-positive.
+
+ // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes.
+ // if they are (possibly incomplete) utf8, then we can be quite sure
+ // that input buffer was utf8.
+
+ let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
+ Ok(_) => true,
+ Err(err) => {
+ let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
+ incomplete_bytes < MAX_BYTES_PER_CHAR
+ }
+ };
+
+ if have_to_fix_up {
+ // We must pop several bytes at the end which form incomplete
+ // character. To achieve it, we exploit UTF8 encoding:
+ // for any code point, all bytes except first start with 0b10 prefix.
+ // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details
+ let trailing_incomplete_char_size = buf
+ .iter()
+ .rev()
+ .take(MAX_BYTES_PER_CHAR)
+ .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
+ .unwrap_or(0)
+ + 1;
+ buf = &buf[..buf.len() - trailing_incomplete_char_size];
+ }
+
+ call_inner(buf)
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ Pin::new(&mut self.inner).poll_flush(cx)
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ Pin::new(&mut self.inner).poll_shutdown(cx)
+ }
+}
+
+#[cfg(test)]
+#[cfg(not(loom))]
+mod tests {
+ use crate::io::AsyncWriteExt;
+ use std::io;
+ use std::pin::Pin;
+ use std::task::Context;
+ use std::task::Poll;
+
+ const MAX_BUF: usize = 16 * 1024;
+
+ struct TextMockWriter;
+
+ impl crate::io::AsyncWrite for TextMockWriter {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ assert!(buf.len() <= MAX_BUF);
+ assert!(std::str::from_utf8(buf).is_ok());
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ struct LoggingMockWriter {
+ write_history: Vec<usize>,
+ }
+
+ impl LoggingMockWriter {
+ fn new() -> Self {
+ LoggingMockWriter {
+ write_history: Vec::new(),
+ }
+ }
+ }
+
+ impl crate::io::AsyncWrite for LoggingMockWriter {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ assert!(buf.len() <= MAX_BUF);
+ self.write_history.push(buf.len());
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ #[test]
+ fn test_splitter() {
+ let data = str::repeat("█", MAX_BUF);
+ let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
+ let fut = async move {
+ wr.write_all(data.as_bytes()).await.unwrap();
+ };
+ crate::runtime::Builder::new_current_thread()
+ .build()
+ .unwrap()
+ .block_on(fut);
+ }
+
+ #[test]
+ fn test_pseudo_text() {
+ // In this test we write a piece of binary data, whose beginning is
+ // text though. We then validate that even in this corner case buffer
+ // was not shrinked too much.
+ let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
+ let mut data: Vec<u8> = str::repeat("a", checked_count).into();
+ data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1));
+ let mut writer = LoggingMockWriter::new();
+ let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
+ crate::runtime::Builder::new_current_thread()
+ .build()
+ .unwrap()
+ .block_on(async {
+ splitter.write_all(&data).await.unwrap();
+ });
+ // Check that at most two writes were performed
+ assert!(writer.write_history.len() <= 2);
+ // Check that all has been written
+ assert_eq!(
+ writer.write_history.iter().copied().sum::<usize>(),
+ data.len()
+ );
+ // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked
+ // from the buffer: one because it was outside of MAX_BUF boundary, and
+ // up to one "utf8 code point".
+ assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
+ }
+}
diff --git a/src/io/stdout.rs b/src/io/stdout.rs
index 5377993..a08ed01 100644
--- a/src/io/stdout.rs
+++ b/src/io/stdout.rs
@@ -1,6 +1,6 @@
use crate::io::blocking::Blocking;
+use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows;
use crate::io::AsyncWrite;
-
use std::io;
use std::pin::Pin;
use std::task::Context;
@@ -35,7 +35,7 @@ cfg_io_std! {
/// ```
#[derive(Debug)]
pub struct Stdout {
- std: Blocking<std::io::Stdout>,
+ std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stdout>>,
}
/// Constructs a new handle to the standard output of the current process.
@@ -67,7 +67,7 @@ cfg_io_std! {
pub fn stdout() -> Stdout {
let std = io::stdout();
Stdout {
- std: Blocking::new(std),
+ std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)),
}
}
}
diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs
index 1bfab90..9e87f2f 100644
--- a/src/io/util/async_buf_read_ext.rs
+++ b/src/io/util/async_buf_read_ext.rs
@@ -14,7 +14,7 @@ cfg_io_util! {
/// Equivalent to:
///
/// ```ignore
- /// async fn read_until(&mut self, buf: &mut Vec<u8>) -> io::Result<usize>;
+ /// async fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize>;
/// ```
///
/// This function will read bytes from the underlying stream until the
diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs
index e848a5d..0ab66c2 100644
--- a/src/io/util/async_read_ext.rs
+++ b/src/io/util/async_read_ext.rs
@@ -986,10 +986,12 @@ cfg_io_util! {
///
/// All bytes read from this source will be appended to the specified
/// buffer `buf`. This function will continuously call [`read()`] to
- /// append more data to `buf` until [`read()`][read] returns `Ok(0)`.
+ /// append more data to `buf` until [`read()`] returns `Ok(0)`.
///
/// If successful, the total number of bytes read is returned.
///
+ /// [`read()`]: AsyncReadExt::read
+ ///
/// # Errors
///
/// If a read error is encountered then the `read_to_end` operation
@@ -1018,7 +1020,7 @@ cfg_io_util! {
/// (See also the [`tokio::fs::read`] convenience function for reading from a
/// file.)
///
- /// [`tokio::fs::read`]: crate::fs::read::read
+ /// [`tokio::fs::read`]: fn@crate::fs::read
fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self>
where
Self: Unpin,
@@ -1065,7 +1067,7 @@ cfg_io_util! {
/// (See also the [`crate::fs::read_to_string`] convenience function for
/// reading from a file.)
///
- /// [`crate::fs::read_to_string`]: crate::fs::read_to_string::read_to_string
+ /// [`crate::fs::read_to_string`]: fn@crate::fs::read_to_string
fn read_to_string<'a>(&'a mut self, dst: &'a mut String) -> ReadToString<'a, Self>
where
Self: Unpin,
@@ -1078,7 +1080,11 @@ cfg_io_util! {
/// This function returns a new instance of `AsyncRead` which will read
/// at most `limit` bytes, after which it will always return EOF
/// (`Ok(0)`). Any read errors will not count towards the number of
- /// bytes read and future calls to [`read()`][read] may succeed.
+ /// bytes read and future calls to [`read()`] may succeed.
+ ///
+ /// [`read()`]: fn@crate::io::AsyncReadExt::read
+ ///
+ /// [read]: AsyncReadExt::read
///
/// # Examples
///
diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs
index c7a0f72..351900b 100644
--- a/src/io/util/async_seek_ext.rs
+++ b/src/io/util/async_seek_ext.rs
@@ -2,65 +2,73 @@ use crate::io::seek::{seek, Seek};
use crate::io::AsyncSeek;
use std::io::SeekFrom;
-/// An extension trait which adds utility methods to [`AsyncSeek`] types.
-///
-/// As a convenience, this trait may be imported using the [`prelude`]:
-///
-/// # Examples
-///
-/// ```
-/// use std::io::{Cursor, SeekFrom};
-/// use tokio::prelude::*;
-///
-/// #[tokio::main]
-/// async fn main() -> io::Result<()> {
-/// let mut cursor = Cursor::new(b"abcdefg");
-///
-/// // the `seek` method is defined by this trait
-/// cursor.seek(SeekFrom::Start(3)).await?;
-///
-/// let mut buf = [0; 1];
-/// let n = cursor.read(&mut buf).await?;
-/// assert_eq!(n, 1);
-/// assert_eq!(buf, [b'd']);
-///
-/// Ok(())
-/// }
-/// ```
-///
-/// See [module][crate::io] documentation for more details.
-///
-/// [`AsyncSeek`]: AsyncSeek
-/// [`prelude`]: crate::prelude
-pub trait AsyncSeekExt: AsyncSeek {
- /// Creates a future which will seek an IO object, and then yield the
- /// new position in the object and the object itself.
+cfg_io_util! {
+ /// An extension trait which adds utility methods to [`AsyncSeek`] types.
///
- /// In the case of an error the buffer and the object will be discarded, with
- /// the error yielded.
+ /// As a convenience, this trait may be imported using the [`prelude`]:
///
/// # Examples
///
- /// ```no_run
- /// use tokio::fs::File;
+ /// ```
+ /// use std::io::{Cursor, SeekFrom};
/// use tokio::prelude::*;
///
- /// use std::io::SeekFrom;
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let mut cursor = Cursor::new(b"abcdefg");
+ ///
+ /// // the `seek` method is defined by this trait
+ /// cursor.seek(SeekFrom::Start(3)).await?;
///
- /// # async fn dox() -> std::io::Result<()> {
- /// let mut file = File::open("foo.txt").await?;
- /// file.seek(SeekFrom::Start(6)).await?;
+ /// let mut buf = [0; 1];
+ /// let n = cursor.read(&mut buf).await?;
+ /// assert_eq!(n, 1);
+ /// assert_eq!(buf, [b'd']);
///
- /// let mut contents = vec![0u8; 10];
- /// file.read_exact(&mut contents).await?;
- /// # Ok(())
- /// # }
+ /// Ok(())
+ /// }
/// ```
- fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self>
- where
- Self: Unpin,
- {
- seek(self, pos)
+ ///
+ /// See [module][crate::io] documentation for more details.
+ ///
+ /// [`AsyncSeek`]: AsyncSeek
+ /// [`prelude`]: crate::prelude
+ pub trait AsyncSeekExt: AsyncSeek {
+ /// Creates a future which will seek an IO object, and then yield the
+ /// new position in the object and the object itself.
+ ///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64>;
+ /// ```
+ ///
+ /// In the case of an error the buffer and the object will be discarded, with
+ /// the error yielded.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::fs::File;
+ /// use tokio::prelude::*;
+ ///
+ /// use std::io::SeekFrom;
+ ///
+ /// # async fn dox() -> std::io::Result<()> {
+ /// let mut file = File::open("foo.txt").await?;
+ /// file.seek(SeekFrom::Start(6)).await?;
+ ///
+ /// let mut contents = vec![0u8; 10];
+ /// file.read_exact(&mut contents).await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self>
+ where
+ Self: Unpin,
+ {
+ seek(self, pos)
+ }
}
}
diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs
index fa41097..e6ef5b2 100644
--- a/src/io/util/async_write_ext.rs
+++ b/src/io/util/async_write_ext.rs
@@ -119,6 +119,7 @@ cfg_io_util! {
write(self, src)
}
+
/// Writes a buffer into this writer, advancing the buffer's internal
/// cursor.
///
@@ -134,7 +135,7 @@ cfg_io_util! {
/// internal cursor is advanced by the number of bytes written. A
/// subsequent call to `write_buf` using the **same** `buf` value will
/// resume from the point that the first call to `write_buf` completed.
- /// A call to `write` represents *at most one* attempt to write to any
+ /// A call to `write_buf` represents *at most one* attempt to write to any
/// wrapped object.
///
/// # Return
@@ -976,6 +977,8 @@ cfg_io_util! {
/// no longer attempt to write to the stream. For example, the
/// `TcpStream` implementation will issue a `shutdown(Write)` sys call.
///
+ /// [`flush`]: fn@crate::io::AsyncWriteExt::flush
+ ///
/// # Examples
///
/// ```no_run
diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs
index a1c5990..271f61b 100644
--- a/src/io/util/buf_reader.rs
+++ b/src/io/util/buf_reader.rs
@@ -1,10 +1,8 @@
use crate::io::util::DEFAULT_BUF_SIZE;
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
-use bytes::Buf;
use pin_project_lite::pin_project;
-use std::io::{self, Read};
-use std::mem::MaybeUninit;
+use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{cmp, fmt};
@@ -44,21 +42,12 @@ impl<R: AsyncRead> BufReader<R> {
/// Creates a new `BufReader` with the specified buffer capacity.
pub fn with_capacity(capacity: usize, inner: R) -> Self {
- unsafe {
- let mut buffer = Vec::with_capacity(capacity);
- buffer.set_len(capacity);
-
- {
- // Convert to MaybeUninit
- let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit<u8>]);
- inner.prepare_uninitialized_buffer(b);
- }
- Self {
- inner,
- buf: buffer.into_boxed_slice(),
- pos: 0,
- cap: 0,
- }
+ let buffer = vec![0; capacity];
+ Self {
+ inner,
+ buf: buffer.into_boxed_slice(),
+ pos: 0,
+ cap: 0,
}
}
@@ -110,25 +99,21 @@ impl<R: AsyncRead> AsyncRead for BufReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
// If we don't have any buffered data and we're doing a massive read
// (larger than our internal buffer), bypass our internal buffer
// entirely.
- if self.pos == self.cap && buf.len() >= self.buf.len() {
+ if self.pos == self.cap && buf.remaining() >= self.buf.len() {
let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf));
self.discard_buffer();
return Poll::Ready(res);
}
- let mut rem = ready!(self.as_mut().poll_fill_buf(cx))?;
- let nread = rem.read(buf)?;
- self.consume(nread);
- Poll::Ready(Ok(nread))
- }
-
- // we can't skip unconditionally because of the large buffer case in read.
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.inner.prepare_uninitialized_buffer(buf)
+ let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
+ let amt = std::cmp::min(rem.len(), buf.remaining());
+ buf.put_slice(&rem[..amt]);
+ self.consume(amt);
+ Poll::Ready(Ok(()))
}
}
@@ -142,7 +127,9 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
// to tell the compiler that the pos..cap slice is always valid.
if *me.pos >= *me.cap {
debug_assert!(*me.pos == *me.cap);
- *me.cap = ready!(me.inner.poll_read(cx, me.buf))?;
+ let mut buf = ReadBuf::new(me.buf);
+ ready!(me.inner.poll_read(cx, &mut buf))?;
+ *me.cap = buf.filled().len();
*me.pos = 0;
}
Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
@@ -163,14 +150,6 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
self.get_pin_mut().poll_write(cx, buf)
}
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- self.get_pin_mut().poll_write_buf(cx, buf)
- }
-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_flush(cx)
}
diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs
index a56a451..cc857e2 100644
--- a/src/io/util/buf_stream.rs
+++ b/src/io/util/buf_stream.rs
@@ -1,9 +1,8 @@
use crate::io::util::{BufReader, BufWriter};
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use pin_project_lite::pin_project;
use std::io;
-use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -137,15 +136,10 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
-
- // we can't skip unconditionally because of the large buffer case in read.
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.inner.prepare_uninitialized_buffer(buf)
- }
}
impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> {
diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs
index efd053e..5e3d4b7 100644
--- a/src/io/util/buf_writer.rs
+++ b/src/io/util/buf_writer.rs
@@ -1,10 +1,9 @@
use crate::io::util::DEFAULT_BUF_SIZE;
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use pin_project_lite::pin_project;
use std::fmt;
use std::io::{self, Write};
-use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -147,15 +146,10 @@ impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_read(cx, buf)
}
-
- // we can't skip unconditionally because of the large buffer case in read.
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.get_ref().prepare_uninitialized_buffer(buf)
- }
}
impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
diff --git a/src/io/util/chain.rs b/src/io/util/chain.rs
index 8ba9194..84f37fc 100644
--- a/src/io/util/chain.rs
+++ b/src/io/util/chain.rs
@@ -1,4 +1,4 @@
-use crate::io::{AsyncBufRead, AsyncRead};
+use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use pin_project_lite::pin_project;
use std::fmt;
@@ -84,26 +84,20 @@ where
T: AsyncRead,
U: AsyncRead,
{
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- if self.first.prepare_uninitialized_buffer(buf) {
- return true;
- }
- if self.second.prepare_uninitialized_buffer(buf) {
- return true;
- }
- false
- }
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
let me = self.project();
if !*me.done_first {
- match ready!(me.first.poll_read(cx, buf)?) {
- 0 if !buf.is_empty() => *me.done_first = true,
- n => return Poll::Ready(Ok(n)),
+ let rem = buf.remaining();
+ ready!(me.first.poll_read(cx, buf))?;
+ if buf.remaining() == rem {
+ *me.done_first = true;
+ } else {
+ return Poll::Ready(Ok(()));
}
}
me.second.poll_read(cx, buf)
diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs
index 7bfe296..c5981cf 100644
--- a/src/io/util/copy.rs
+++ b/src/io/util/copy.rs
@@ -1,30 +1,25 @@
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
- /// A future that asynchronously copies the entire contents of a reader into a
- /// writer.
- ///
- /// This struct is generally created by calling [`copy`][copy]. Please
- /// see the documentation of `copy()` for more details.
- ///
- /// [copy]: copy()
- #[derive(Debug)]
- #[must_use = "futures do nothing unless you `.await` or poll them"]
- pub struct Copy<'a, R: ?Sized, W: ?Sized> {
- reader: &'a mut R,
- read_done: bool,
- writer: &'a mut W,
- pos: usize,
- cap: usize,
- amt: u64,
- buf: Box<[u8]>,
- }
+/// A future that asynchronously copies the entire contents of a reader into a
+/// writer.
+#[derive(Debug)]
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+struct Copy<'a, R: ?Sized, W: ?Sized> {
+ reader: &'a mut R,
+ read_done: bool,
+ writer: &'a mut W,
+ pos: usize,
+ cap: usize,
+ amt: u64,
+ buf: Box<[u8]>,
+}
+cfg_io_util! {
/// Asynchronously copies the entire contents of a reader into a writer.
///
/// This function returns a future that will continuously read data from
@@ -58,7 +53,7 @@ cfg_io_util! {
/// # Ok(())
/// # }
/// ```
- pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W>
+ pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
@@ -71,7 +66,7 @@ cfg_io_util! {
pos: 0,
cap: 0,
buf: vec![0; 2048].into_boxed_slice(),
- }
+ }.await
}
}
@@ -88,7 +83,9 @@ where
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
- let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?;
+ let mut buf = ReadBuf::new(&mut me.buf);
+ ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?;
+ let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
@@ -122,14 +119,3 @@ where
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Copy<'_, PhantomPinned, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/copy_buf.rs b/src/io/util/copy_buf.rs
new file mode 100644
index 0000000..6831580
--- /dev/null
+++ b/src/io/util/copy_buf.rs
@@ -0,0 +1,102 @@
+use crate::io::{AsyncBufRead, AsyncWrite};
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+cfg_io_util! {
+ /// A future that asynchronously copies the entire contents of a reader into a
+ /// writer.
+ ///
+ /// This struct is generally created by calling [`copy_buf`][copy_buf]. Please
+ /// see the documentation of `copy_buf()` for more details.
+ ///
+ /// [copy_buf]: copy_buf()
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ struct CopyBuf<'a, R: ?Sized, W: ?Sized> {
+ reader: &'a mut R,
+ writer: &'a mut W,
+ amt: u64,
+ }
+
+ /// Asynchronously copies the entire contents of a reader into a writer.
+ ///
+ /// This function returns a future that will continuously read data from
+ /// `reader` and then write it into `writer` in a streaming fashion until
+ /// `reader` returns EOF.
+ ///
+ /// On success, the total number of bytes that were copied from `reader` to
+ /// `writer` is returned.
+ ///
+ ///
+ /// # Errors
+ ///
+ /// The returned future will finish with an error will return an error
+ /// immediately if any call to `poll_fill_buf` or `poll_write` returns an
+ /// error.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::io;
+ ///
+ /// # async fn dox() -> std::io::Result<()> {
+ /// let mut reader: &[u8] = b"hello";
+ /// let mut writer: Vec<u8> = vec![];
+ ///
+ /// io::copy_buf(&mut reader, &mut writer).await?;
+ ///
+ /// assert_eq!(b"hello", &writer[..]);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
+ where
+ R: AsyncBufRead + Unpin + ?Sized,
+ W: AsyncWrite + Unpin + ?Sized,
+ {
+ CopyBuf {
+ reader,
+ writer,
+ amt: 0,
+ }.await
+ }
+}
+
+impl<R, W> Future for CopyBuf<'_, R, W>
+where
+ R: AsyncBufRead + Unpin + ?Sized,
+ W: AsyncWrite + Unpin + ?Sized,
+{
+ type Output = io::Result<u64>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ loop {
+ let me = &mut *self;
+ let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?;
+ if buffer.is_empty() {
+ ready!(Pin::new(&mut self.writer).poll_flush(cx))?;
+ return Poll::Ready(Ok(self.amt));
+ }
+
+ let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?;
+ if i == 0 {
+ return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
+ }
+ self.amt += i as u64;
+ Pin::new(&mut *self.reader).consume(i);
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn assert_unpin() {
+ use std::marker::PhantomPinned;
+ crate::is_unpin::<CopyBuf<'_, PhantomPinned, PhantomPinned>>();
+ }
+}
diff --git a/src/io/util/empty.rs b/src/io/util/empty.rs
index 576058d..f964d18 100644
--- a/src/io/util/empty.rs
+++ b/src/io/util/empty.rs
@@ -1,4 +1,4 @@
-use crate::io::{AsyncBufRead, AsyncRead};
+use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use std::fmt;
use std::io;
@@ -47,16 +47,13 @@ cfg_io_util! {
}
impl AsyncRead for Empty {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- false
- }
#[inline]
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
- _: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(Ok(0))
+ _: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/util/flush.rs b/src/io/util/flush.rs
index 534a516..88d60b8 100644
--- a/src/io/util/flush.rs
+++ b/src/io/util/flush.rs
@@ -1,18 +1,24 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future used to fully flush an I/O object.
///
/// Created by the [`AsyncWriteExt::flush`][flush] function.
/// [flush]: crate::io::AsyncWriteExt::flush
#[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Flush<'a, A: ?Sized> {
a: &'a mut A,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -21,7 +27,10 @@ pub(super) fn flush<A>(a: &mut A) -> Flush<'_, A>
where
A: AsyncWrite + Unpin + ?Sized,
{
- Flush { a }
+ Flush {
+ a,
+ _pin: PhantomPinned,
+ }
}
impl<A> Future for Flush<'_, A>
@@ -30,19 +39,8 @@ where
{
type Output = io::Result<()>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
Pin::new(&mut *me.a).poll_flush(cx)
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Flush<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs
index ee27400..b41f04a 100644
--- a/src/io/util/lines.rs
+++ b/src/io/util/lines.rs
@@ -83,8 +83,7 @@ impl<R> Lines<R>
where
R: AsyncBufRead,
{
- #[doc(hidden)]
- pub fn poll_next_line(
+ fn poll_next_line(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<String>>> {
diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs
new file mode 100644
index 0000000..e91a932
--- /dev/null
+++ b/src/io/util/mem.rs
@@ -0,0 +1,223 @@
+//! In-process memory IO types.
+
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
+use crate::loom::sync::Mutex;
+
+use bytes::{Buf, BytesMut};
+use std::{
+ pin::Pin,
+ sync::Arc,
+ task::{self, Poll, Waker},
+};
+
+/// A bidirectional pipe to read and write bytes in memory.
+///
+/// A pair of `DuplexStream`s are created together, and they act as a "channel"
+/// that can be used as in-memory IO types. Writing to one of the pairs will
+/// allow that data to be read from the other, and vice versa.
+///
+/// # Example
+///
+/// ```
+/// # async fn ex() -> std::io::Result<()> {
+/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
+/// let (mut client, mut server) = tokio::io::duplex(64);
+///
+/// client.write_all(b"ping").await?;
+///
+/// let mut buf = [0u8; 4];
+/// server.read_exact(&mut buf).await?;
+/// assert_eq!(&buf, b"ping");
+///
+/// server.write_all(b"pong").await?;
+///
+/// client.read_exact(&mut buf).await?;
+/// assert_eq!(&buf, b"pong");
+/// # Ok(())
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct DuplexStream {
+ read: Arc<Mutex<Pipe>>,
+ write: Arc<Mutex<Pipe>>,
+}
+
+/// A unidirectional IO over a piece of memory.
+///
+/// Data can be written to the pipe, and reading will return that data.
+#[derive(Debug)]
+struct Pipe {
+ /// The buffer storing the bytes written, also read from.
+ ///
+ /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
+ /// functionality already. Additionally, it can try to copy data in the
+ /// same buffer if there read index has advanced far enough.
+ buffer: BytesMut,
+ /// Determines if the write side has been closed.
+ is_closed: bool,
+ /// The maximum amount of bytes that can be written before returning
+ /// `Poll::Pending`.
+ max_buf_size: usize,
+ /// If the `read` side has been polled and is pending, this is the waker
+ /// for that parked task.
+ read_waker: Option<Waker>,
+ /// If the `write` side has filled the `max_buf_size` and returned
+ /// `Poll::Pending`, this is the waker for that parked task.
+ write_waker: Option<Waker>,
+}
+
+// ===== impl DuplexStream =====
+
+/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
+///
+/// The `max_buf_size` argument is the maximum amount of bytes that can be
+/// written to a side before the write returns `Poll::Pending`.
+pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
+ let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
+ let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
+
+ (
+ DuplexStream {
+ read: one.clone(),
+ write: two.clone(),
+ },
+ DuplexStream {
+ read: two,
+ write: one,
+ },
+ )
+}
+
+impl AsyncRead for DuplexStream {
+ // Previous rustc required this `self` to be `mut`, even though newer
+ // versions recognize it isn't needed to call `lock()`. So for
+ // compatibility, we include the `mut` and `allow` the lint.
+ //
+ // See https://github.com/rust-lang/rust/issues/73592
+ #[allow(unused_mut)]
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
+ }
+}
+
+impl AsyncWrite for DuplexStream {
+ #[allow(unused_mut)]
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
+ }
+
+ #[allow(unused_mut)]
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.write.lock()).poll_flush(cx)
+ }
+
+ #[allow(unused_mut)]
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
+ }
+}
+
+impl Drop for DuplexStream {
+ fn drop(&mut self) {
+ // notify the other side of the closure
+ self.write.lock().close();
+ }
+}
+
+// ===== impl Pipe =====
+
+impl Pipe {
+ fn new(max_buf_size: usize) -> Self {
+ Pipe {
+ buffer: BytesMut::new(),
+ is_closed: false,
+ max_buf_size,
+ read_waker: None,
+ write_waker: None,
+ }
+ }
+
+ fn close(&mut self) {
+ self.is_closed = true;
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ }
+}
+
+impl AsyncRead for Pipe {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ if self.buffer.has_remaining() {
+ let max = self.buffer.remaining().min(buf.remaining());
+ buf.put_slice(&self.buffer[..max]);
+ self.buffer.advance(max);
+ if max > 0 {
+ // The passed `buf` might have been empty, don't wake up if
+ // no bytes have been moved.
+ if let Some(waker) = self.write_waker.take() {
+ waker.wake();
+ }
+ }
+ Poll::Ready(Ok(()))
+ } else if self.is_closed {
+ Poll::Ready(Ok(()))
+ } else {
+ self.read_waker = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+}
+
+impl AsyncWrite for Pipe {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ if self.is_closed {
+ return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
+ }
+ let avail = self.max_buf_size - self.buffer.len();
+ if avail == 0 {
+ self.write_waker = Some(cx.waker().clone());
+ return Poll::Pending;
+ }
+
+ let len = buf.len().min(avail);
+ self.buffer.extend_from_slice(&buf[..len]);
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ Poll::Ready(Ok(len))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ _: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ self.close();
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs
index c4754ab..e75ea03 100644
--- a/src/io/util/mod.rs
+++ b/src/io/util/mod.rs
@@ -25,7 +25,10 @@ cfg_io_util! {
mod chain;
mod copy;
- pub use copy::{copy, Copy};
+ pub use copy::copy;
+
+ mod copy_buf;
+ pub use copy_buf::copy_buf;
mod empty;
pub use empty::{empty, Empty};
@@ -35,6 +38,9 @@ cfg_io_util! {
mod lines;
pub use lines::Lines;
+ mod mem;
+ pub use mem::{duplex, DuplexStream};
+
mod read;
mod read_buf;
mod read_exact;
@@ -60,11 +66,6 @@ cfg_io_util! {
mod split;
pub use split::Split;
- cfg_stream! {
- mod stream_reader;
- pub use stream_reader::{stream_reader, StreamReader};
- }
-
mod take;
pub use take::Take;
diff --git a/src/io/util/read.rs b/src/io/util/read.rs
index a8ca370..edc9d5a 100644
--- a/src/io/util/read.rs
+++ b/src/io/util/read.rs
@@ -1,7 +1,9 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -15,10 +17,14 @@ pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R>
where
R: AsyncRead + Unpin + ?Sized,
{
- Read { reader, buf }
+ Read {
+ reader,
+ buf,
+ _pin: PhantomPinned,
+ }
}
-cfg_io_util! {
+pin_project! {
/// A future which can be used to easily read available number of bytes to fill
/// a buffer.
///
@@ -28,6 +34,9 @@ cfg_io_util! {
pub struct Read<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut [u8],
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -37,19 +46,10 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
- Pin::new(&mut *me.reader).poll_read(cx, me.buf)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Read<'_, PhantomPinned>>();
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let me = self.project();
+ let mut buf = ReadBuf::new(*me.buf);
+ ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?;
+ Poll::Ready(Ok(buf.filled().len()))
}
}
diff --git a/src/io/util/read_buf.rs b/src/io/util/read_buf.rs
index 6ee3d24..696deef 100644
--- a/src/io/util/read_buf.rs
+++ b/src/io/util/read_buf.rs
@@ -1,8 +1,10 @@
use crate::io::AsyncRead;
use bytes::BufMut;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -11,16 +13,22 @@ where
R: AsyncRead + Unpin,
B: BufMut,
{
- ReadBuf { reader, buf }
+ ReadBuf {
+ reader,
+ buf,
+ _pin: PhantomPinned,
+ }
}
-cfg_io_util! {
+pin_project! {
/// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf).
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadBuf<'a, R, B> {
reader: &'a mut R,
buf: &'a mut B,
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -31,8 +39,34 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
- Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf)
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ use crate::io::ReadBuf;
+ use std::mem::MaybeUninit;
+
+ let me = self.project();
+
+ if !me.buf.has_remaining_mut() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = {
+ let dst = me.buf.bytes_mut();
+ let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut buf = ReadBuf::uninit(dst);
+ let ptr = buf.filled().as_ptr();
+ ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?);
+
+ // Ensure the pointer does not change from under us
+ assert_eq!(ptr, buf.filled().as_ptr());
+ buf.filled().len()
+ };
+
+ // Safety: This is guaranteed to be the number of initialized (and read)
+ // bytes due to the invariants provided by `ReadBuf::filled`.
+ unsafe {
+ me.buf.advance_mut(n);
+ }
+
+ Poll::Ready(Ok(n))
}
}
diff --git a/src/io/util/read_exact.rs b/src/io/util/read_exact.rs
index 86b8412..1e8150e 100644
--- a/src/io/util/read_exact.rs
+++ b/src/io/util/read_exact.rs
@@ -1,7 +1,9 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -17,12 +19,12 @@ where
{
ReadExact {
reader,
- buf,
- pos: 0,
+ buf: ReadBuf::new(buf),
+ _pin: PhantomPinned,
}
}
-cfg_io_util! {
+pin_project! {
/// Creates a future which will read exactly enough bytes to fill `buf`,
/// returning an error if EOF is hit sooner.
///
@@ -31,8 +33,10 @@ cfg_io_util! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadExact<'a, A: ?Sized> {
reader: &'a mut A,
- buf: &'a mut [u8],
- pos: usize,
+ buf: ReadBuf<'a>,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -46,32 +50,20 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let mut me = self.project();
+
loop {
// if our buffer is empty, then we need to read some data to continue.
- if self.pos < self.buf.len() {
- let me = &mut *self;
- let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf[me.pos..]))?;
- me.pos += n;
- if n == 0 {
+ let rem = me.buf.remaining();
+ if rem != 0 {
+ ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?;
+ if me.buf.remaining() == rem {
return Err(eof()).into();
}
- }
-
- if self.pos >= self.buf.len() {
- return Poll::Ready(Ok(self.pos));
+ } else {
+ return Poll::Ready(Ok(me.buf.capacity()));
}
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadExact<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs
index 9d37dc7..5b9fb7b 100644
--- a/src/io/util/read_int.rs
+++ b/src/io/util/read_int.rs
@@ -1,10 +1,11 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
use bytes::Buf;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::ErrorKind::UnexpectedEof;
+use std::marker::PhantomPinned;
use std::mem::size_of;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -16,11 +17,15 @@ macro_rules! reader {
($name:ident, $ty:ty, $reader:ident, $bytes:expr) => {
pin_project! {
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<R> {
#[pin]
src: R,
buf: [u8; $bytes],
read: u8,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -30,6 +35,7 @@ macro_rules! reader {
src,
buf: [0; $bytes],
read: 0,
+ _pin: PhantomPinned,
}
}
}
@@ -48,17 +54,19 @@ macro_rules! reader {
}
while *me.read < $bytes as u8 {
- *me.read += match me
- .src
- .as_mut()
- .poll_read(cx, &mut me.buf[*me.read as usize..])
- {
+ let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]);
+
+ *me.read += match me.src.as_mut().poll_read(cx, &mut buf) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
- Poll::Ready(Ok(0)) => {
- return Poll::Ready(Err(UnexpectedEof.into()));
+ Poll::Ready(Ok(())) => {
+ let n = buf.filled().len();
+ if n == 0 {
+ return Poll::Ready(Err(UnexpectedEof.into()));
+ }
+
+ n as u8
}
- Poll::Ready(Ok(n)) => n as u8,
};
}
@@ -75,15 +83,22 @@ macro_rules! reader8 {
pin_project! {
/// Future returned from `read_u8`
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<R> {
#[pin]
reader: R,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
impl<R> $name<R> {
pub(crate) fn new(reader: R) -> $name<R> {
- $name { reader }
+ $name {
+ reader,
+ _pin: PhantomPinned,
+ }
}
}
@@ -97,12 +112,17 @@ macro_rules! reader8 {
let me = self.project();
let mut buf = [0; 1];
- match me.reader.poll_read(cx, &mut buf[..]) {
+ let mut buf = ReadBuf::new(&mut buf);
+ match me.reader.poll_read(cx, &mut buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
- Poll::Ready(Ok(0)) => Poll::Ready(Err(UnexpectedEof.into())),
- Poll::Ready(Ok(1)) => Poll::Ready(Ok(buf[0] as $ty)),
- Poll::Ready(Ok(_)) => unreachable!(),
+ Poll::Ready(Ok(())) => {
+ if buf.filled().len() == 0 {
+ return Poll::Ready(Err(UnexpectedEof.into()));
+ }
+
+ Poll::Ready(Ok(buf.filled()[0] as $ty))
+ }
}
}
}
diff --git a/src/io/util/read_line.rs b/src/io/util/read_line.rs
index d625a76..d38ffaf 100644
--- a/src/io/util/read_line.rs
+++ b/src/io/util/read_line.rs
@@ -1,26 +1,32 @@
use crate::io::util::read_until::read_until_internal;
use crate::io::AsyncBufRead;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
+use std::string::FromUtf8Error;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// Future for the [`read_line`](crate::io::AsyncBufReadExt::read_line) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadLine<'a, R: ?Sized> {
reader: &'a mut R,
- /// This is the buffer we were provided. It will be replaced with an empty string
- /// while reading to postpone utf-8 handling until after reading.
+ // This is the buffer we were provided. It will be replaced with an empty string
+ // while reading to postpone utf-8 handling until after reading.
output: &'a mut String,
- /// The actual allocation of the string is moved into a vector instead.
+ // The actual allocation of the string is moved into this vector instead.
buf: Vec<u8>,
- /// The number of bytes appended to buf. This can be less than buf.len() if
- /// the buffer was not empty when the operation was started.
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -33,6 +39,7 @@ where
buf: mem::replace(string, String::new()).into_bytes(),
output: string,
read: 0,
+ _pin: PhantomPinned,
}
}
@@ -42,31 +49,33 @@ fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_re
*output = String::from_utf8(vector).expect("The original data must be valid utf-8.");
}
-pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
- reader: Pin<&mut R>,
- cx: &mut Context<'_>,
+/// This handles the various failure cases and puts the string back into `output`.
+///
+/// The `truncate_on_io_error` bool is necessary because `read_to_string` and `read_line`
+/// disagree on what should happen when an IO error occurs.
+pub(super) fn finish_string_read(
+ io_res: io::Result<usize>,
+ utf8_res: Result<String, FromUtf8Error>,
+ read: usize,
output: &mut String,
- buf: &mut Vec<u8>,
- read: &mut usize,
+ truncate_on_io_error: bool,
) -> Poll<io::Result<usize>> {
- let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
- let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
-
- // At this point both buf and output are empty. The allocation is in utf8_res.
-
- debug_assert!(buf.is_empty());
match (io_res, utf8_res) {
(Ok(num_bytes), Ok(string)) => {
- debug_assert_eq!(*read, 0);
+ debug_assert_eq!(read, 0);
*output = string;
Poll::Ready(Ok(num_bytes))
}
(Err(io_err), Ok(string)) => {
*output = string;
+ if truncate_on_io_error {
+ let original_len = output.len() - read;
+ output.truncate(original_len);
+ }
Poll::Ready(Err(io_err))
}
(Ok(num_bytes), Err(utf8_err)) => {
- debug_assert_eq!(*read, 0);
+ debug_assert_eq!(read, 0);
put_back_original_data(output, utf8_err.into_bytes(), num_bytes);
Poll::Ready(Err(io::Error::new(
@@ -75,35 +84,36 @@ pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
)))
}
(Err(io_err), Err(utf8_err)) => {
- put_back_original_data(output, utf8_err.into_bytes(), *read);
+ put_back_original_data(output, utf8_err.into_bytes(), read);
Poll::Ready(Err(io_err))
}
}
}
-impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
- type Output = io::Result<usize>;
+pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
+ reader: Pin<&mut R>,
+ cx: &mut Context<'_>,
+ output: &mut String,
+ buf: &mut Vec<u8>,
+ read: &mut usize,
+) -> Poll<io::Result<usize>> {
+ let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
+ let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let Self {
- reader,
- output,
- buf,
- read,
- } = &mut *self;
+ // At this point both buf and output are empty. The allocation is in utf8_res.
- read_line_internal(Pin::new(reader), cx, output, buf, read)
- }
+ debug_assert!(buf.is_empty());
+ debug_assert!(output.is_empty());
+ finish_string_read(io_res, utf8_res, *read, output, false)
}
-#[cfg(test)]
-mod tests {
- use super::*;
+impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
+ type Output = io::Result<usize>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadLine<'_, PhantomPinned>>();
+ read_line_internal(Pin::new(*me.reader), cx, me.output, me.buf, me.read)
}
}
diff --git a/src/io/util/read_to_end.rs b/src/io/util/read_to_end.rs
index a2cd99b..a974625 100644
--- a/src/io/util/read_to_end.rs
+++ b/src/io/util/read_to_end.rs
@@ -1,92 +1,105 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
-use std::mem::MaybeUninit;
+use std::marker::PhantomPinned;
+use std::mem::{self, MaybeUninit};
use std::pin::Pin;
use std::task::{Context, Poll};
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
-pub struct ReadToEnd<'a, R: ?Sized> {
- reader: &'a mut R,
- buf: &'a mut Vec<u8>,
- start_len: usize,
+pin_project! {
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct ReadToEnd<'a, R: ?Sized> {
+ reader: &'a mut R,
+ buf: &'a mut Vec<u8>,
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
+ read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
-pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, R>
+pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) -> ReadToEnd<'a, R>
where
R: AsyncRead + Unpin + ?Sized,
{
- let start_len = buf.len();
ReadToEnd {
reader,
- buf,
- start_len,
+ buf: buffer,
+ read: 0,
+ _pin: PhantomPinned,
}
}
-struct Guard<'a> {
- buf: &'a mut Vec<u8>,
- len: usize,
-}
-
-impl Drop for Guard<'_> {
- fn drop(&mut self) {
- unsafe {
- self.buf.set_len(self.len);
+pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>(
+ buf: &mut Vec<u8>,
+ mut reader: Pin<&mut R>,
+ num_read: &mut usize,
+ cx: &mut Context<'_>,
+) -> Poll<io::Result<usize>> {
+ loop {
+ // safety: The caller promised to prepare the buffer.
+ let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx));
+ match ret {
+ Err(err) => return Poll::Ready(Err(err)),
+ Ok(0) => return Poll::Ready(Ok(mem::replace(num_read, 0))),
+ Ok(num) => {
+ *num_read += num;
+ }
}
}
}
-// This uses an adaptive system to extend the vector when it fills. We want to
-// avoid paying to allocate and zero a huge chunk of memory if the reader only
-// has 4 bytes while still making large reads if the reader does have a ton
-// of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every
-// time is 4,500 times (!) slower than this if the reader has a very small
-// amount of data to return.
-//
-// Because we're extending the buffer with uninitialized data for trusted
-// readers, we need to make sure to truncate that if any of this panics.
-pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>(
- mut rd: Pin<&mut R>,
- cx: &mut Context<'_>,
+/// Tries to read from the provided AsyncRead.
+///
+/// The length of the buffer is increased by the number of bytes read.
+fn poll_read_to_end<R: AsyncRead + ?Sized>(
buf: &mut Vec<u8>,
- start_len: usize,
+ read: Pin<&mut R>,
+ cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
- let mut g = Guard {
- len: buf.len(),
- buf,
- };
- let ret;
- loop {
- if g.len == g.buf.len() {
- unsafe {
- g.buf.reserve(32);
- let capacity = g.buf.capacity();
- g.buf.set_len(capacity);
+ // This uses an adaptive system to extend the vector when it fills. We want to
+ // avoid paying to allocate and zero a huge chunk of memory if the reader only
+ // has 4 bytes while still making large reads if the reader does have a ton
+ // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every
+ // time is 4,500 times (!) slower than this if the reader has a very small
+ // amount of data to return.
+ reserve(buf, 32);
- let b = &mut *(&mut g.buf[g.len..] as *mut [u8] as *mut [MaybeUninit<u8>]);
+ let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf));
- rd.prepare_uninitialized_buffer(b);
- }
- }
+ ready!(read.poll_read(cx, &mut unused_capacity))?;
- match ready!(rd.as_mut().poll_read(cx, &mut g.buf[g.len..])) {
- Ok(0) => {
- ret = Poll::Ready(Ok(g.len - start_len));
- break;
- }
- Ok(n) => g.len += n,
- Err(e) => {
- ret = Poll::Ready(Err(e));
- break;
- }
- }
+ let n = unused_capacity.filled().len();
+ let new_len = buf.len() + n;
+
+ // This should no longer even be possible in safe Rust. An implementor
+ // would need to have unsafely *replaced* the buffer inside `ReadBuf`,
+ // which... yolo?
+ assert!(new_len <= buf.capacity());
+ unsafe {
+ buf.set_len(new_len);
}
+ Poll::Ready(Ok(n))
+}
- ret
+/// Allocates more memory and ensures that the unused capacity is prepared for use
+/// with the `AsyncRead`.
+fn reserve(buf: &mut Vec<u8>, bytes: usize) {
+ if buf.capacity() - buf.len() >= bytes {
+ return;
+ }
+ buf.reserve(bytes);
+}
+
+/// Returns the unused capacity of the provided vector.
+fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] {
+ let uninit = bytes::BufMut::bytes_mut(buf);
+ unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) }
}
impl<A> Future for ReadToEnd<'_, A>
@@ -95,19 +108,9 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let this = &mut *self;
- read_to_end_internal(Pin::new(&mut this.reader), cx, this.buf, this.start_len)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadToEnd<'_, PhantomPinned>>();
+ read_to_end_internal(me.buf, Pin::new(*me.reader), me.read, cx)
}
}
diff --git a/src/io/util/read_to_string.rs b/src/io/util/read_to_string.rs
index cab0505..e463203 100644
--- a/src/io/util/read_to_string.rs
+++ b/src/io/util/read_to_string.rs
@@ -1,58 +1,71 @@
+use crate::io::util::read_line::finish_string_read;
use crate::io::util::read_to_end::read_to_end_internal;
use crate::io::AsyncRead;
+use pin_project_lite::pin_project;
use std::future::Future;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};
-cfg_io_util! {
+pin_project! {
/// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToString<'a, R: ?Sized> {
reader: &'a mut R,
- buf: &'a mut String,
- bytes: Vec<u8>,
- start_len: usize,
+ // This is the buffer we were provided. It will be replaced with an empty string
+ // while reading to postpone utf-8 handling until after reading.
+ output: &'a mut String,
+ // The actual allocation of the string is moved into this vector instead.
+ buf: Vec<u8>,
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
+ read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
-pub(crate) fn read_to_string<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadToString<'a, R>
+pub(crate) fn read_to_string<'a, R>(
+ reader: &'a mut R,
+ string: &'a mut String,
+) -> ReadToString<'a, R>
where
R: AsyncRead + ?Sized + Unpin,
{
- let start_len = buf.len();
+ let buf = mem::replace(string, String::new()).into_bytes();
ReadToString {
reader,
- bytes: mem::replace(buf, String::new()).into_bytes(),
buf,
- start_len,
+ output: string,
+ read: 0,
+ _pin: PhantomPinned,
}
}
-fn read_to_string_internal<R: AsyncRead + ?Sized>(
+/// # Safety
+///
+/// Before first calling this method, the unused capacity must have been
+/// prepared for use with the provided AsyncRead. This can be done using the
+/// `prepare_buffer` function in `read_to_end.rs`.
+unsafe fn read_to_string_internal<R: AsyncRead + ?Sized>(
reader: Pin<&mut R>,
+ output: &mut String,
+ buf: &mut Vec<u8>,
+ read: &mut usize,
cx: &mut Context<'_>,
- buf: &mut String,
- bytes: &mut Vec<u8>,
- start_len: usize,
) -> Poll<io::Result<usize>> {
- let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?;
- match String::from_utf8(mem::replace(bytes, Vec::new())) {
- Ok(string) => {
- debug_assert!(buf.is_empty());
- *buf = string;
- Poll::Ready(Ok(ret))
- }
- Err(e) => {
- *bytes = e.into_bytes();
- Poll::Ready(Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "stream did not contain valid UTF-8",
- )))
- }
- }
+ let io_res = ready!(read_to_end_internal(buf, reader, read, cx));
+ let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
+
+ // At this point both buf and output are empty. The allocation is in utf8_res.
+
+ debug_assert!(buf.is_empty());
+ debug_assert!(output.is_empty());
+ finish_string_read(io_res, utf8_res, *read, output, true)
}
impl<A> Future for ReadToString<'_, A>
@@ -61,31 +74,10 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let Self {
- reader,
- buf,
- bytes,
- start_len,
- } = &mut *self;
- let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len);
- if let Poll::Ready(Err(_)) = ret {
- // Put back the original string.
- bytes.truncate(*start_len);
- **buf = String::from_utf8(mem::replace(bytes, Vec::new()))
- .expect("original string no longer utf-8");
- }
- ret
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadToString<'_, PhantomPinned>>();
+ // safety: The constructor of ReadToString called `prepare_buffer`.
+ unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) }
}
}
diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs
index 78dac8c..3599cff 100644
--- a/src/io/util/read_until.rs
+++ b/src/io/util/read_until.rs
@@ -1,12 +1,14 @@
use crate::io::AsyncBufRead;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method.
/// The delimeter is included in the resulting vector.
#[derive(Debug)]
@@ -15,9 +17,12 @@ cfg_io_util! {
reader: &'a mut R,
delimeter: u8,
buf: &'a mut Vec<u8>,
- /// The number of bytes appended to buf. This can be less than buf.len() if
- /// the buffer was not empty when the operation was started.
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -34,6 +39,7 @@ where
delimeter,
buf,
read: 0,
+ _pin: PhantomPinned,
}
}
@@ -66,24 +72,8 @@ pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>(
impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> {
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let Self {
- reader,
- delimeter,
- buf,
- read,
- } = &mut *self;
- read_until_internal(Pin::new(reader), cx, *delimeter, buf, read)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadUntil<'_, PhantomPinned>>();
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read)
}
}
diff --git a/src/io/util/repeat.rs b/src/io/util/repeat.rs
index eeef7cc..1142765 100644
--- a/src/io/util/repeat.rs
+++ b/src/io/util/repeat.rs
@@ -1,4 +1,4 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
use std::io;
use std::pin::Pin;
@@ -47,19 +47,17 @@ cfg_io_util! {
}
impl AsyncRead for Repeat {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- false
- }
#[inline]
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- for byte in &mut *buf {
- *byte = self.byte;
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ // TODO: could be faster, but should we unsafe it?
+ while buf.remaining() != 0 {
+ buf.put_slice(&[self.byte]);
}
- Poll::Ready(Ok(buf.len()))
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/util/shutdown.rs b/src/io/util/shutdown.rs
index 33ac0ac..6d30b00 100644
--- a/src/io/util/shutdown.rs
+++ b/src/io/util/shutdown.rs
@@ -1,18 +1,24 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future used to shutdown an I/O object.
///
/// Created by the [`AsyncWriteExt::shutdown`][shutdown] function.
/// [shutdown]: crate::io::AsyncWriteExt::shutdown
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Shutdown<'a, A: ?Sized> {
a: &'a mut A,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -21,7 +27,10 @@ pub(super) fn shutdown<A>(a: &mut A) -> Shutdown<'_, A>
where
A: AsyncWrite + Unpin + ?Sized,
{
- Shutdown { a }
+ Shutdown {
+ a,
+ _pin: PhantomPinned,
+ }
}
impl<A> Future for Shutdown<'_, A>
@@ -30,19 +39,8 @@ where
{
type Output = io::Result<()>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let me = &mut *self;
- Pin::new(&mut *me.a).poll_shutdown(cx)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Shutdown<'_, PhantomPinned>>();
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ Pin::new(me.a).poll_shutdown(cx)
}
}
diff --git a/src/io/util/split.rs b/src/io/util/split.rs
index f552ed5..492e26a 100644
--- a/src/io/util/split.rs
+++ b/src/io/util/split.rs
@@ -65,8 +65,7 @@ impl<R> Split<R>
where
R: AsyncBufRead,
{
- #[doc(hidden)]
- pub fn poll_next_segment(
+ fn poll_next_segment(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<Vec<u8>>>> {
diff --git a/src/io/util/stream_reader.rs b/src/io/util/stream_reader.rs
deleted file mode 100644
index b98f8bd..0000000
--- a/src/io/util/stream_reader.rs
+++ /dev/null
@@ -1,184 +0,0 @@
-use crate::io::{AsyncBufRead, AsyncRead};
-use crate::stream::Stream;
-use bytes::{Buf, BufMut};
-use pin_project_lite::pin_project;
-use std::io;
-use std::mem::MaybeUninit;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-pin_project! {
- /// Convert a stream of byte chunks into an [`AsyncRead`].
- ///
- /// This type is usually created using the [`stream_reader`] function.
- ///
- /// [`AsyncRead`]: crate::io::AsyncRead
- /// [`stream_reader`]: crate::io::stream_reader
- #[derive(Debug)]
- #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
- #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
- pub struct StreamReader<S, B> {
- #[pin]
- inner: S,
- chunk: Option<B>,
- }
-}
-
-/// Convert a stream of byte chunks into an [`AsyncRead`](crate::io::AsyncRead).
-///
-/// # Example
-///
-/// ```
-/// use bytes::Bytes;
-/// use tokio::io::{stream_reader, AsyncReadExt};
-/// # #[tokio::main]
-/// # async fn main() -> std::io::Result<()> {
-///
-/// // Create a stream from an iterator.
-/// let stream = tokio::stream::iter(vec![
-/// Ok(Bytes::from_static(&[0, 1, 2, 3])),
-/// Ok(Bytes::from_static(&[4, 5, 6, 7])),
-/// Ok(Bytes::from_static(&[8, 9, 10, 11])),
-/// ]);
-///
-/// // Convert it to an AsyncRead.
-/// let mut read = stream_reader(stream);
-///
-/// // Read five bytes from the stream.
-/// let mut buf = [0; 5];
-/// read.read_exact(&mut buf).await?;
-/// assert_eq!(buf, [0, 1, 2, 3, 4]);
-///
-/// // Read the rest of the current chunk.
-/// assert_eq!(read.read(&mut buf).await?, 3);
-/// assert_eq!(&buf[..3], [5, 6, 7]);
-///
-/// // Read the next chunk.
-/// assert_eq!(read.read(&mut buf).await?, 4);
-/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
-///
-/// // We have now reached the end.
-/// assert_eq!(read.read(&mut buf).await?, 0);
-///
-/// # Ok(())
-/// # }
-/// ```
-#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
-#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
-pub fn stream_reader<S, B>(stream: S) -> StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- StreamReader::new(stream)
-}
-
-impl<S, B> StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- /// Convert the provided stream into an `AsyncRead`.
- fn new(stream: S) -> Self {
- Self {
- inner: stream,
- chunk: None,
- }
- }
- /// Do we have a chunk and is it non-empty?
- fn has_chunk(self: Pin<&mut Self>) -> bool {
- if let Some(chunk) = self.project().chunk {
- chunk.remaining() > 0
- } else {
- false
- }
- }
-}
-
-impl<S, B> AsyncRead for StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- fn poll_read(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- if buf.is_empty() {
- return Poll::Ready(Ok(0));
- }
-
- let inner_buf = match self.as_mut().poll_fill_buf(cx) {
- Poll::Ready(Ok(buf)) => buf,
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => return Poll::Pending,
- };
- let len = std::cmp::min(inner_buf.len(), buf.len());
- (&mut buf[..len]).copy_from_slice(&inner_buf[..len]);
-
- self.consume(len);
- Poll::Ready(Ok(len))
- }
- fn poll_read_buf<BM: BufMut>(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut BM,
- ) -> Poll<io::Result<usize>>
- where
- Self: Sized,
- {
- if !buf.has_remaining_mut() {
- return Poll::Ready(Ok(0));
- }
-
- let inner_buf = match self.as_mut().poll_fill_buf(cx) {
- Poll::Ready(Ok(buf)) => buf,
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => return Poll::Pending,
- };
- let len = std::cmp::min(inner_buf.len(), buf.remaining_mut());
- buf.put_slice(&inner_buf[..len]);
-
- self.consume(len);
- Poll::Ready(Ok(len))
- }
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-}
-
-impl<S, B> AsyncBufRead for StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
- loop {
- if self.as_mut().has_chunk() {
- // This unwrap is very sad, but it can't be avoided.
- let buf = self.project().chunk.as_ref().unwrap().bytes();
- return Poll::Ready(Ok(buf));
- } else {
- match self.as_mut().project().inner.poll_next(cx) {
- Poll::Ready(Some(Ok(chunk))) => {
- // Go around the loop in case the chunk is empty.
- *self.as_mut().project().chunk = Some(chunk);
- }
- Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)),
- Poll::Ready(None) => return Poll::Ready(Ok(&[])),
- Poll::Pending => return Poll::Pending,
- }
- }
- }
- }
- fn consume(self: Pin<&mut Self>, amt: usize) {
- if amt > 0 {
- self.project()
- .chunk
- .as_mut()
- .expect("No chunk present")
- .advance(amt);
- }
- }
-}
diff --git a/src/io/util/take.rs b/src/io/util/take.rs
index 5d6bd90..b5e90c9 100644
--- a/src/io/util/take.rs
+++ b/src/io/util/take.rs
@@ -1,7 +1,6 @@
-use crate::io::{AsyncBufRead, AsyncRead};
+use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use pin_project_lite::pin_project;
-use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{cmp, io};
@@ -76,24 +75,27 @@ impl<R: AsyncRead> Take<R> {
}
impl<R: AsyncRead> AsyncRead for Take<R> {
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.inner.prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<Result<usize, io::Error>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<Result<(), io::Error>> {
if self.limit_ == 0 {
- return Poll::Ready(Ok(0));
+ return Poll::Ready(Ok(()));
}
let me = self.project();
- let max = std::cmp::min(buf.len() as u64, *me.limit_) as usize;
- let n = ready!(me.inner.poll_read(cx, &mut buf[..max]))?;
+ let mut b = buf.take(*me.limit_ as usize);
+ ready!(me.inner.poll_read(cx, &mut b))?;
+ let n = b.filled().len();
+
+ // We need to update the original ReadBuf
+ unsafe {
+ buf.assume_init(n);
+ }
+ buf.advance(n);
*me.limit_ -= n as u64;
- Poll::Ready(Ok(n))
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/util/write.rs b/src/io/util/write.rs
index 433a421..92169eb 100644
--- a/src/io/util/write.rs
+++ b/src/io/util/write.rs
@@ -1,17 +1,22 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future to write some of the buffer to an `AsyncWrite`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Write<'a, W: ?Sized> {
writer: &'a mut W,
buf: &'a [u8],
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -21,7 +26,11 @@ pub(crate) fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> Write<'a, W>
where
W: AsyncWrite + Unpin + ?Sized,
{
- Write { writer, buf }
+ Write {
+ writer,
+ buf,
+ _pin: PhantomPinned,
+ }
}
impl<W> Future for Write<'_, W>
@@ -30,8 +39,8 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let me = self.project();
Pin::new(&mut *me.writer).poll_write(cx, me.buf)
}
}
diff --git a/src/io/util/write_all.rs b/src/io/util/write_all.rs
index 898006c..e59d41e 100644
--- a/src/io/util/write_all.rs
+++ b/src/io/util/write_all.rs
@@ -1,17 +1,22 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteAll<'a, W: ?Sized> {
writer: &'a mut W,
buf: &'a [u8],
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -19,7 +24,11 @@ pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a,
where
W: AsyncWrite + Unpin + ?Sized,
{
- WriteAll { writer, buf }
+ WriteAll {
+ writer,
+ buf,
+ _pin: PhantomPinned,
+ }
}
impl<W> Future for WriteAll<'_, W>
@@ -28,13 +37,13 @@ where
{
type Output = io::Result<()>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ let me = self.project();
while !me.buf.is_empty() {
- let n = ready!(Pin::new(&mut me.writer).poll_write(cx, me.buf))?;
+ let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?;
{
- let (_, rest) = mem::replace(&mut me.buf, &[]).split_at(n);
- me.buf = rest;
+ let (_, rest) = mem::replace(&mut *me.buf, &[]).split_at(n);
+ *me.buf = rest;
}
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
@@ -44,14 +53,3 @@ where
Poll::Ready(Ok(()))
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<WriteAll<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/write_buf.rs b/src/io/util/write_buf.rs
index cedfde6..1310e5c 100644
--- a/src/io/util/write_buf.rs
+++ b/src/io/util/write_buf.rs
@@ -1,18 +1,22 @@
use crate::io::AsyncWrite;
use bytes::Buf;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future to write some of the buffer to an `AsyncWrite`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteBuf<'a, W, B> {
writer: &'a mut W,
buf: &'a mut B,
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -23,7 +27,11 @@ where
W: AsyncWrite + Unpin,
B: Buf,
{
- WriteBuf { writer, buf }
+ WriteBuf {
+ writer,
+ buf,
+ _pin: PhantomPinned,
+ }
}
impl<W, B> Future for WriteBuf<'_, W, B>
@@ -33,8 +41,15 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
- Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf)
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let me = self.project();
+
+ if !me.buf.has_remaining() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?;
+ me.buf.advance(n);
+ Poll::Ready(Ok(n))
}
}
diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs
index ee992de..13bc191 100644
--- a/src/io/util/write_int.rs
+++ b/src/io/util/write_int.rs
@@ -4,6 +4,7 @@ use bytes::BufMut;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem::size_of;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -15,20 +16,25 @@ macro_rules! writer {
($name:ident, $ty:ty, $writer:ident, $bytes:expr) => {
pin_project! {
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<W> {
#[pin]
dst: W,
buf: [u8; $bytes],
written: u8,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
impl<W> $name<W> {
pub(crate) fn new(w: W, value: $ty) -> Self {
- let mut writer = $name {
+ let mut writer = Self {
buf: [0; $bytes],
written: 0,
dst: w,
+ _pin: PhantomPinned,
};
BufMut::$writer(&mut &mut writer.buf[..], value);
writer
@@ -72,16 +78,24 @@ macro_rules! writer8 {
($name:ident, $ty:ty) => {
pin_project! {
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<W> {
#[pin]
dst: W,
byte: $ty,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
impl<W> $name<W> {
pub(crate) fn new(dst: W, byte: $ty) -> Self {
- Self { dst, byte }
+ Self {
+ dst,
+ byte,
+ _pin: PhantomPinned,
+ }
}
}
diff --git a/src/lib.rs b/src/lib.rs
index 5775be3..66e266c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-#![doc(html_root_url = "https://docs.rs/tokio/0.2.22")]
+#![doc(html_root_url = "https://docs.rs/tokio/0.3.1")]
#![allow(
clippy::cognitive_complexity,
clippy::large_enum_variant,
@@ -10,22 +10,21 @@
rust_2018_idioms,
unreachable_pub
)]
-#![deny(intra_doc_link_resolution_failure)]
+#![cfg_attr(docsrs, deny(broken_intra_doc_links))]
#![doc(test(
no_crate_inject,
attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
))]
#![cfg_attr(docsrs, feature(doc_cfg))]
-#![cfg_attr(docsrs, feature(doc_alias))]
-//! A runtime for writing reliable, asynchronous, and slim applications.
+//! A runtime for writing reliable network applications without compromising speed.
//!
//! Tokio is an event-driven, non-blocking I/O platform for writing asynchronous
//! applications with the Rust programming language. At a high level, it
//! provides a few major components:
//!
//! * Tools for [working with asynchronous tasks][tasks], including
-//! [synchronization primitives and channels][sync] and [timeouts, delays, and
+//! [synchronization primitives and channels][sync] and [timeouts, sleeps, and
//! intervals][time].
//! * APIs for [performing asynchronous I/O][io], including [TCP and UDP][net] sockets,
//! [filesystem][fs] operations, and [process] and [signal] management.
@@ -45,7 +44,7 @@
//! [signal]: crate::signal
//! [fs]: crate::fs
//! [runtime]: crate::runtime
-//! [website]: https://tokio.rs/docs/overview/
+//! [website]: https://tokio.rs/tokio/tutorial
//!
//! # A Tour of Tokio
//!
@@ -58,59 +57,9 @@
//! enabling the `full` feature flag:
//!
//! ```toml
-//! tokio = { version = "0.2", features = ["full"] }
+//! tokio = { version = "0.3", features = ["full"] }
//! ```
//!
-//! ## Feature flags
-//!
-//! Tokio uses a set of [feature flags] to reduce the amount of compiled code. It
-//! is possible to just enable certain features over others. By default, Tokio
-//! does not enable any features but allows one to enable a subset for their use
-//! case. Below is a list of the available feature flags. You may also notice
-//! above each function, struct and trait there is listed one or more feature flags
-//! that are required for that item to be used. If you are new to Tokio it is
-//! recommended that you use the `full` feature flag which will enable all public APIs.
-//! Beware though that this will pull in many extra dependencies that you may not
-//! need.
-//!
-//! - `full`: Enables all Tokio public API features listed below.
-//! - `rt-core`: Enables `tokio::spawn` and the basic (single-threaded) scheduler.
-//! - `rt-threaded`: Enables the heavier, multi-threaded, work-stealing scheduler.
-//! - `rt-util`: Enables non-scheduler utilities.
-//! - `io-driver`: Enables the `mio` based IO driver.
-//! - `io-util`: Enables the IO based `Ext` traits.
-//! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types.
-//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`.
-//! - `tcp`: Enables all `tokio::net::tcp` types.
-//! - `udp`: Enables all `tokio::net::udp` types.
-//! - `uds`: Enables all `tokio::net::unix` types.
-//! - `time`: Enables `tokio::time` types and allows the schedulers to enable
-//! the built in timer.
-//! - `process`: Enables `tokio::process` types.
-//! - `macros`: Enables `#[tokio::main]` and `#[tokio::test]` macros.
-//! - `sync`: Enables all `tokio::sync` types.
-//! - `stream`: Enables optional `Stream` implementations for types within Tokio.
-//! - `signal`: Enables all `tokio::signal` types.
-//! - `fs`: Enables `tokio::fs` types.
-//! - `dns`: Enables async `tokio::net::ToSocketAddrs`.
-//! - `test-util`: Enables testing based infrastructure for the Tokio runtime.
-//! - `blocking`: Enables `block_in_place` and `spawn_blocking`.
-//!
-//! _Note: `AsyncRead` and `AsyncWrite` traits do not require any features and are
-//! always available._
-//!
-//! ### Internal features
-//!
-//! These features do not expose any new API, but influence internal
-//! implementation aspects of Tokio, and can pull in additional
-//! dependencies. They are not included in `full`:
-//!
-//! - `parking_lot`: As a potential optimization, use the _parking_lot_ crate's
-//! synchronization primitives internally. MSRV may increase according to the
-//! _parking_lot_ release in use.
-//!
-//! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section
-//!
//! ### Authoring applications
//!
//! Tokio is great for writing applications and most users in this case shouldn't
@@ -123,7 +72,7 @@
//! This example shows the quickest way to get started with Tokio.
//!
//! ```toml
-//! tokio = { version = "0.2", features = ["full"] }
+//! tokio = { version = "0.3", features = ["full"] }
//! ```
//!
//! ### Authoring libraries
@@ -139,7 +88,7 @@
//! needs to `tokio::spawn` and use a `TcpStream`.
//!
//! ```toml
-//! tokio = { version = "0.2", features = ["rt-core", "tcp"] }
+//! tokio = { version = "0.3", features = ["rt", "net"] }
//! ```
//!
//! ## Working With Tasks
@@ -153,7 +102,7 @@
//! * Functions for [running blocking operations][blocking] in an asynchronous
//! task context.
//!
-//! The [`tokio::task`] module is present only when the "rt-core" feature flag
+//! The [`tokio::task`] module is present only when the "rt" feature flag
//! is enabled.
//!
//! [tasks]: task/index.html#what-are-tasks
@@ -184,25 +133,26 @@
//!
//! The [`tokio::time`] module provides utilities for tracking time and
//! scheduling work. This includes functions for setting [timeouts][timeout] for
-//! tasks, [delaying][delay] work to run in the future, or [repeating an operation at an
+//! tasks, [sleeping][sleep] work to run in the future, or [repeating an operation at an
//! interval][interval].
//!
//! In order to use `tokio::time`, the "time" feature flag must be enabled.
//!
//! [`tokio::time`]: crate::time
-//! [delay]: crate::time::delay_for()
+//! [sleep]: crate::time::sleep()
//! [interval]: crate::time::interval()
//! [timeout]: crate::time::timeout()
//!
//! Finally, Tokio provides a _runtime_ for executing asynchronous tasks. Most
//! applications can use the [`#[tokio::main]`][main] macro to run their code on the
-//! Tokio runtime. In use-cases where manual control over the runtime is
-//! required, the [`tokio::runtime`] module provides APIs for configuring and
-//! managing runtimes.
-//!
-//! Using the runtime requires the "rt-core" or "rt-threaded" feature flags, to
-//! enable the basic [single-threaded scheduler][rt-core] and the [thread-pool
-//! scheduler][rt-threaded], respectively. See the [`runtime` module
+//! Tokio runtime. However, this macro provides only basic configuration options. As
+//! an alternative, the [`tokio::runtime`] module provides more powerful APIs for configuring
+//! and managing runtimes. You should use that module if the `#[tokio::main]` macro doesn't
+//! provide the functionality you need.
+//!
+//! Using the runtime requires the "rt" or "rt-multi-thread" feature flags, to
+//! enable the basic [single-threaded scheduler][rt] and the [thread-pool
+//! scheduler][rt-multi-thread], respectively. See the [`runtime` module
//! documentation][rt-features] for details. In addition, the "macros" feature
//! flag enables the `#[tokio::main]` and `#[tokio::test]` attributes.
//!
@@ -210,8 +160,8 @@
//! [`tokio::runtime`]: crate::runtime
//! [`Builder`]: crate::runtime::Builder
//! [`Runtime`]: crate::runtime::Runtime
-//! [rt-core]: runtime/index.html#basic-scheduler
-//! [rt-threaded]: runtime/index.html#threaded-scheduler
+//! [rt]: runtime/index.html#basic-scheduler
+//! [rt-multi-thread]: runtime/index.html#threaded-scheduler
//! [rt-features]: runtime/index.html#runtime-scheduler
//!
//! ## CPU-bound tasks and blocking code
@@ -268,8 +218,7 @@
//! the [`AsyncRead`], [`AsyncWrite`], and [`AsyncBufRead`] traits. In addition,
//! when the "io-util" feature flag is enabled, it also provides combinators and
//! functions for working with these traits, forming as an asynchronous
-//! counterpart to [`std::io`]. When the "io-driver" feature flag is enabled, it
-//! also provides utilities for library authors implementing I/O resources.
+//! counterpart to [`std::io`].
//!
//! Tokio also includes APIs for performing various kinds of I/O and interacting
//! with the operating system asynchronously. These include:
@@ -307,7 +256,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
@@ -337,6 +286,50 @@
//! }
//! }
//! ```
+//!
+//! ## Feature flags
+//!
+//! Tokio uses a set of [feature flags] to reduce the amount of compiled code. It
+//! is possible to just enable certain features over others. By default, Tokio
+//! does not enable any features but allows one to enable a subset for their use
+//! case. Below is a list of the available feature flags. You may also notice
+//! above each function, struct and trait there is listed one or more feature flags
+//! that are required for that item to be used. If you are new to Tokio it is
+//! recommended that you use the `full` feature flag which will enable all public APIs.
+//! Beware though that this will pull in many extra dependencies that you may not
+//! need.
+//!
+//! - `full`: Enables all Tokio public API features listed below.
+//! - `rt`: Enables `tokio::spawn`, the basic (current thread) scheduler,
+//! and non-scheduler utilities.
+//! - `rt-multi-thread`: Enables the heavier, multi-threaded, work-stealing scheduler.
+//! - `io-util`: Enables the IO based `Ext` traits.
+//! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types.
+//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`.
+//! - `time`: Enables `tokio::time` types and allows the schedulers to enable
+//! the built in timer.
+//! - `process`: Enables `tokio::process` types.
+//! - `macros`: Enables `#[tokio::main]` and `#[tokio::test]` macros.
+//! - `sync`: Enables all `tokio::sync` types.
+//! - `stream`: Enables optional `Stream` implementations for types within Tokio.
+//! - `signal`: Enables all `tokio::signal` types.
+//! - `fs`: Enables `tokio::fs` types.
+//! - `test-util`: Enables testing based infrastructure for the Tokio runtime.
+//!
+//! _Note: `AsyncRead` and `AsyncWrite` traits do not require any features and are
+//! always available._
+//!
+//! ### Internal features
+//!
+//! These features do not expose any new API, but influence internal
+//! implementation aspects of Tokio, and can pull in additional
+//! dependencies.
+//!
+//! - `parking_lot`: As a potential optimization, use the _parking_lot_ crate's
+//! synchronization primitives internally. MSRV may increase according to the
+//! _parking_lot_ release in use.
+//!
+//! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section
// Includes re-exports used by macros.
//
@@ -350,8 +343,7 @@ cfg_fs! {
pub mod fs;
}
-#[doc(hidden)]
-pub mod future;
+mod future;
pub mod io;
pub mod net;
@@ -365,7 +357,12 @@ cfg_process! {
pub mod process;
}
-pub mod runtime;
+#[cfg(any(feature = "net", feature = "fs", feature = "io-std"))]
+mod blocking;
+
+cfg_rt! {
+ pub mod runtime;
+}
pub(crate) mod coop;
@@ -373,6 +370,13 @@ cfg_signal! {
pub mod signal;
}
+cfg_signal_internal! {
+ #[cfg(not(feature = "signal"))]
+ #[allow(dead_code)]
+ #[allow(unreachable_pub)]
+ pub(crate) mod signal;
+}
+
cfg_stream! {
pub mod stream;
}
@@ -384,8 +388,8 @@ cfg_not_sync! {
mod sync;
}
-cfg_rt_core! {
- pub mod task;
+pub mod task;
+cfg_rt! {
pub use task::spawn;
}
@@ -402,31 +406,31 @@ cfg_macros! {
#[doc(hidden)]
pub use tokio_macros::select_priv_declare_output_enum;
- doc_rt_core! {
- cfg_rt_threaded! {
+ cfg_rt! {
+ cfg_rt_multi_thread! {
// This is the docs.rs case (with all features) so make sure macros
// is included in doc(cfg).
#[cfg(not(test))] // Work around for rust-lang/rust#62127
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
- pub use tokio_macros::main_threaded as main;
+ pub use tokio_macros::main;
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
- pub use tokio_macros::test_threaded as test;
+ pub use tokio_macros::test;
}
- cfg_not_rt_threaded! {
+ cfg_not_rt_multi_thread! {
#[cfg(not(test))] // Work around for rust-lang/rust#62127
- pub use tokio_macros::main_basic as main;
- pub use tokio_macros::test_basic as test;
+ pub use tokio_macros::main_rt as main;
+ pub use tokio_macros::test_rt as test;
}
}
- // Maintains old behavior
- cfg_not_rt_core! {
+ // Always fail if rt is not enabled.
+ cfg_not_rt! {
#[cfg(not(test))]
- pub use tokio_macros::main;
- pub use tokio_macros::test;
+ pub use tokio_macros::main_fail as main;
+ pub use tokio_macros::test_fail as test;
}
}
diff --git a/src/loom/mocked.rs b/src/loom/mocked.rs
index 7891395..367d59b 100644
--- a/src/loom/mocked.rs
+++ b/src/loom/mocked.rs
@@ -1,5 +1,32 @@
pub(crate) use loom::*;
+pub(crate) mod sync {
+
+ pub(crate) use loom::sync::MutexGuard;
+
+ #[derive(Debug)]
+ pub(crate) struct Mutex<T>(loom::sync::Mutex<T>);
+
+ #[allow(dead_code)]
+ impl<T> Mutex<T> {
+ #[inline]
+ pub(crate) fn new(t: T) -> Mutex<T> {
+ Mutex(loom::sync::Mutex::new(t))
+ }
+
+ #[inline]
+ pub(crate) fn lock(&self) -> MutexGuard<'_, T> {
+ self.0.lock().unwrap()
+ }
+
+ #[inline]
+ pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
+ self.0.try_lock().ok()
+ }
+ }
+ pub(crate) use loom::sync::*;
+}
+
pub(crate) mod rand {
pub(crate) fn seed() -> u64 {
1
diff --git a/src/loom/mod.rs b/src/loom/mod.rs
index 56a41f2..5957b53 100644
--- a/src/loom/mod.rs
+++ b/src/loom/mod.rs
@@ -1,6 +1,8 @@
//! This module abstracts over `loom` and `std::sync` depending on whether we
//! are running tests or not.
+#![allow(unused)]
+
#[cfg(not(all(test, loom)))]
mod std;
#[cfg(not(all(test, loom)))]
diff --git a/src/loom/std/atomic_ptr.rs b/src/loom/std/atomic_ptr.rs
index f7fd56c..236645f 100644
--- a/src/loom/std/atomic_ptr.rs
+++ b/src/loom/std/atomic_ptr.rs
@@ -1,5 +1,5 @@
use std::fmt;
-use std::ops::Deref;
+use std::ops::{Deref, DerefMut};
/// `AtomicPtr` providing an additional `load_unsync` function.
pub(crate) struct AtomicPtr<T> {
@@ -21,6 +21,12 @@ impl<T> Deref for AtomicPtr<T> {
}
}
+impl<T> DerefMut for AtomicPtr<T> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.inner
+ }
+}
+
impl<T> fmt::Debug for AtomicPtr<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(fmt)
diff --git a/src/loom/std/atomic_u16.rs b/src/loom/std/atomic_u16.rs
index 7039097..c1c5312 100644
--- a/src/loom/std/atomic_u16.rs
+++ b/src/loom/std/atomic_u16.rs
@@ -11,7 +11,7 @@ unsafe impl Send for AtomicU16 {}
unsafe impl Sync for AtomicU16 {}
impl AtomicU16 {
- pub(crate) fn new(val: u16) -> AtomicU16 {
+ pub(crate) const fn new(val: u16) -> AtomicU16 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU16::new(val));
AtomicU16 { inner }
}
diff --git a/src/loom/std/atomic_u32.rs b/src/loom/std/atomic_u32.rs
index 6f786c5..61f95fb 100644
--- a/src/loom/std/atomic_u32.rs
+++ b/src/loom/std/atomic_u32.rs
@@ -11,7 +11,7 @@ unsafe impl Send for AtomicU32 {}
unsafe impl Sync for AtomicU32 {}
impl AtomicU32 {
- pub(crate) fn new(val: u32) -> AtomicU32 {
+ pub(crate) const fn new(val: u32) -> AtomicU32 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU32::new(val));
AtomicU32 { inner }
}
diff --git a/src/loom/std/atomic_u8.rs b/src/loom/std/atomic_u8.rs
index 4fcd0df..408aea3 100644
--- a/src/loom/std/atomic_u8.rs
+++ b/src/loom/std/atomic_u8.rs
@@ -11,7 +11,7 @@ unsafe impl Send for AtomicU8 {}
unsafe impl Sync for AtomicU8 {}
impl AtomicU8 {
- pub(crate) fn new(val: u8) -> AtomicU8 {
+ pub(crate) const fn new(val: u8) -> AtomicU8 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU8::new(val));
AtomicU8 { inner }
}
diff --git a/src/loom/std/atomic_usize.rs b/src/loom/std/atomic_usize.rs
index 0fe998f..0d5f36e 100644
--- a/src/loom/std/atomic_usize.rs
+++ b/src/loom/std/atomic_usize.rs
@@ -11,7 +11,7 @@ unsafe impl Send for AtomicUsize {}
unsafe impl Sync for AtomicUsize {}
impl AtomicUsize {
- pub(crate) fn new(val: usize) -> AtomicUsize {
+ pub(crate) const fn new(val: usize) -> AtomicUsize {
let inner = UnsafeCell::new(std::sync::atomic::AtomicUsize::new(val));
AtomicUsize { inner }
}
diff --git a/src/loom/std/mod.rs b/src/loom/std/mod.rs
index 60ee56a..9525286 100644
--- a/src/loom/std/mod.rs
+++ b/src/loom/std/mod.rs
@@ -6,6 +6,7 @@ mod atomic_u32;
mod atomic_u64;
mod atomic_u8;
mod atomic_usize;
+mod mutex;
#[cfg(feature = "parking_lot")]
mod parking_lot;
mod unsafe_cell;
@@ -14,7 +15,12 @@ pub(crate) mod cell {
pub(crate) use super::unsafe_cell::UnsafeCell;
}
-#[cfg(any(feature = "sync", feature = "io-driver"))]
+#[cfg(any(
+ feature = "net",
+ feature = "process",
+ feature = "signal",
+ feature = "sync",
+))]
pub(crate) mod future {
pub(crate) use crate::sync::AtomicWaker;
}
@@ -55,9 +61,10 @@ pub(crate) mod sync {
#[cfg(not(feature = "parking_lot"))]
#[allow(unused_imports)]
- pub(crate) use std::sync::{
- Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult,
- };
+ pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult};
+
+ #[cfg(not(feature = "parking_lot"))]
+ pub(crate) use crate::loom::std::mutex::Mutex;
pub(crate) mod atomic {
pub(crate) use crate::loom::std::atomic_ptr::AtomicPtr;
@@ -72,12 +79,12 @@ pub(crate) mod sync {
}
pub(crate) mod sys {
- #[cfg(feature = "rt-threaded")]
+ #[cfg(feature = "rt-multi-thread")]
pub(crate) fn num_cpus() -> usize {
usize::max(1, num_cpus::get())
}
- #[cfg(not(feature = "rt-threaded"))]
+ #[cfg(not(feature = "rt-multi-thread"))]
pub(crate) fn num_cpus() -> usize {
1
}
diff --git a/src/loom/std/mutex.rs b/src/loom/std/mutex.rs
new file mode 100644
index 0000000..bf14d62
--- /dev/null
+++ b/src/loom/std/mutex.rs
@@ -0,0 +1,31 @@
+use std::sync::{self, MutexGuard, TryLockError};
+
+/// Adapter for `std::Mutex` that removes the poisoning aspects
+// from its api
+#[derive(Debug)]
+pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>);
+
+#[allow(dead_code)]
+impl<T> Mutex<T> {
+ #[inline]
+ pub(crate) fn new(t: T) -> Mutex<T> {
+ Mutex(sync::Mutex::new(t))
+ }
+
+ #[inline]
+ pub(crate) fn lock(&self) -> MutexGuard<'_, T> {
+ match self.0.lock() {
+ Ok(guard) => guard,
+ Err(p_err) => p_err.into_inner(),
+ }
+ }
+
+ #[inline]
+ pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
+ match self.0.try_lock() {
+ Ok(guard) => Some(guard),
+ Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()),
+ Err(TryLockError::WouldBlock) => None,
+ }
+ }
+}
diff --git a/src/loom/std/parking_lot.rs b/src/loom/std/parking_lot.rs
index 25d94af..c03190f 100644
--- a/src/loom/std/parking_lot.rs
+++ b/src/loom/std/parking_lot.rs
@@ -3,7 +3,7 @@
//!
//! This can be extended to additional types/methods as required.
-use std::sync::{LockResult, TryLockError, TryLockResult};
+use std::sync::LockResult;
use std::time::Duration;
// Types that do not need wrapping
@@ -27,16 +27,20 @@ impl<T> Mutex<T> {
}
#[inline]
- pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
- Ok(self.0.lock())
+ #[cfg(all(feature = "parking_lot", not(all(loom, test)),))]
+ #[cfg_attr(docsrs, doc(cfg(all(feature = "parking_lot",))))]
+ pub(crate) const fn const_new(t: T) -> Mutex<T> {
+ Mutex(parking_lot::const_mutex(t))
}
#[inline]
- pub(crate) fn try_lock(&self) -> TryLockResult<MutexGuard<'_, T>> {
- match self.0.try_lock() {
- Some(guard) => Ok(guard),
- None => Err(TryLockError::WouldBlock),
- }
+ pub(crate) fn lock(&self) -> MutexGuard<'_, T> {
+ self.0.lock()
+ }
+
+ #[inline]
+ pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
+ self.0.try_lock()
}
// Note: Additional methods `is_poisoned` and `into_inner`, can be
diff --git a/src/loom/std/unsafe_cell.rs b/src/loom/std/unsafe_cell.rs
index f2b03d8..66c1d79 100644
--- a/src/loom/std/unsafe_cell.rs
+++ b/src/loom/std/unsafe_cell.rs
@@ -2,7 +2,7 @@
pub(crate) struct UnsafeCell<T>(std::cell::UnsafeCell<T>);
impl<T> UnsafeCell<T> {
- pub(crate) fn new(data: T) -> UnsafeCell<T> {
+ pub(crate) const fn new(data: T) -> UnsafeCell<T> {
UnsafeCell(std::cell::UnsafeCell::new(data))
}
diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs
index 4b77544..2792911 100644
--- a/src/macros/cfg.rs
+++ b/src/macros/cfg.rs
@@ -1,97 +1,30 @@
#![allow(unused_macros)]
-macro_rules! cfg_resource_drivers {
- ($($item:item)*) => {
- $(
- #[cfg(any(feature = "io-driver", feature = "time"))]
- $item
- )*
- }
-}
-
-macro_rules! cfg_blocking {
- ($($item:item)*) => {
- $(
- #[cfg(feature = "blocking")]
- #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
- $item
- )*
- }
-}
-
-/// Enables blocking API internals
-macro_rules! cfg_blocking_impl {
- ($($item:item)*) => {
- $(
- #[cfg(any(
- feature = "blocking",
- feature = "fs",
- feature = "dns",
- feature = "io-std",
- feature = "rt-threaded",
- ))]
- $item
- )*
- }
-}
-
-/// Enables blocking API internals
-macro_rules! cfg_blocking_impl_or_task {
- ($($item:item)*) => {
- $(
- #[cfg(any(
- feature = "blocking",
- feature = "fs",
- feature = "dns",
- feature = "io-std",
- feature = "rt-threaded",
- feature = "task",
- ))]
- $item
- )*
- }
-}
-
/// Enables enter::block_on
macro_rules! cfg_block_on {
($($item:item)*) => {
$(
#[cfg(any(
- feature = "blocking",
feature = "fs",
- feature = "dns",
+ feature = "net",
feature = "io-std",
- feature = "rt-core",
+ feature = "rt",
))]
$item
)*
}
}
-/// Enables blocking API internals
-macro_rules! cfg_not_blocking_impl {
- ($($item:item)*) => {
- $(
- #[cfg(not(any(
- feature = "blocking",
- feature = "fs",
- feature = "dns",
- feature = "io-std",
- feature = "rt-threaded",
- )))]
- $item
- )*
- }
-}
-
/// Enables internal `AtomicWaker` impl
macro_rules! cfg_atomic_waker_impl {
($($item:item)*) => {
$(
#[cfg(any(
- feature = "io-driver",
+ feature = "net",
+ feature = "process",
+ feature = "rt",
+ feature = "signal",
feature = "time",
- all(feature = "rt-core", feature = "rt-util")
))]
#[cfg(not(loom))]
$item
@@ -99,16 +32,6 @@ macro_rules! cfg_atomic_waker_impl {
}
}
-macro_rules! cfg_dns {
- ($($item:item)*) => {
- $(
- #[cfg(feature = "dns")]
- #[cfg_attr(docsrs, doc(cfg(feature = "dns")))]
- $item
- )*
- }
-}
-
macro_rules! cfg_fs {
($($item:item)*) => {
$(
@@ -128,8 +51,16 @@ macro_rules! cfg_io_blocking {
macro_rules! cfg_io_driver {
($($item:item)*) => {
$(
- #[cfg(feature = "io-driver")]
- #[cfg_attr(docsrs, doc(cfg(feature = "io-driver")))]
+ #[cfg(any(
+ feature = "net",
+ feature = "process",
+ all(unix, feature = "signal"),
+ ))]
+ #[cfg_attr(docsrs, doc(cfg(any(
+ feature = "net",
+ feature = "process",
+ all(unix, feature = "signal"),
+ ))))]
$item
)*
}
@@ -138,7 +69,20 @@ macro_rules! cfg_io_driver {
macro_rules! cfg_not_io_driver {
($($item:item)*) => {
$(
- #[cfg(not(feature = "io-driver"))]
+ #[cfg(not(any(
+ feature = "net",
+ feature = "process",
+ all(unix, feature = "signal"),
+ )))]
+ $item
+ )*
+ }
+}
+
+macro_rules! cfg_io_readiness {
+ ($($item:item)*) => {
+ $(
+ #[cfg(feature = "net")]
$item
)*
}
@@ -193,115 +137,142 @@ macro_rules! cfg_macros {
}
}
-macro_rules! cfg_process {
+macro_rules! cfg_net {
($($item:item)*) => {
$(
- #[cfg(feature = "process")]
- #[cfg_attr(docsrs, doc(cfg(feature = "process")))]
- #[cfg(not(loom))]
+ #[cfg(feature = "net")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "net")))]
$item
)*
}
}
-macro_rules! cfg_signal {
+macro_rules! cfg_net_unix {
($($item:item)*) => {
$(
- #[cfg(feature = "signal")]
- #[cfg_attr(docsrs, doc(cfg(feature = "signal")))]
- #[cfg(not(loom))]
+ #[cfg(all(unix, feature = "net"))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "net")))]
$item
)*
}
}
-macro_rules! cfg_stream {
+macro_rules! cfg_process {
($($item:item)*) => {
$(
- #[cfg(feature = "stream")]
- #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
+ #[cfg(feature = "process")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "process")))]
+ #[cfg(not(loom))]
$item
)*
}
}
-macro_rules! cfg_sync {
+macro_rules! cfg_process_driver {
+ ($($item:item)*) => {
+ #[cfg(unix)]
+ #[cfg(not(loom))]
+ cfg_process! { $($item)* }
+ }
+}
+
+macro_rules! cfg_not_process_driver {
($($item:item)*) => {
$(
- #[cfg(feature = "sync")]
- #[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
+ #[cfg(not(all(unix, not(loom), feature = "process")))]
$item
)*
}
}
-macro_rules! cfg_not_sync {
+macro_rules! cfg_signal {
($($item:item)*) => {
- $( #[cfg(not(feature = "sync"))] $item )*
+ $(
+ #[cfg(feature = "signal")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "signal")))]
+ #[cfg(not(loom))]
+ $item
+ )*
}
}
-macro_rules! cfg_rt_core {
+macro_rules! cfg_signal_internal {
($($item:item)*) => {
$(
- #[cfg(feature = "rt-core")]
+ #[cfg(any(feature = "signal", all(unix, feature = "process")))]
+ #[cfg(not(loom))]
$item
)*
}
}
-macro_rules! doc_rt_core {
+macro_rules! cfg_not_signal_internal {
($($item:item)*) => {
$(
- #[cfg(feature = "rt-core")]
- #[cfg_attr(docsrs, doc(cfg(feature = "rt-core")))]
+ #[cfg(any(loom, not(unix), not(any(feature = "signal", all(unix, feature = "process")))))]
$item
)*
}
}
-macro_rules! cfg_not_rt_core {
+macro_rules! cfg_stream {
($($item:item)*) => {
- $( #[cfg(not(feature = "rt-core"))] $item )*
+ $(
+ #[cfg(feature = "stream")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
+ $item
+ )*
}
}
-macro_rules! cfg_rt_threaded {
+macro_rules! cfg_sync {
($($item:item)*) => {
$(
- #[cfg(feature = "rt-threaded")]
- #[cfg_attr(docsrs, doc(cfg(feature = "rt-threaded")))]
+ #[cfg(feature = "sync")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
$item
)*
}
}
-macro_rules! cfg_rt_util {
+macro_rules! cfg_not_sync {
+ ($($item:item)*) => {
+ $( #[cfg(not(feature = "sync"))] $item )*
+ }
+}
+
+macro_rules! cfg_rt {
($($item:item)*) => {
$(
- #[cfg(feature = "rt-util")]
- #[cfg_attr(docsrs, doc(cfg(feature = "rt-util")))]
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
$item
)*
}
}
-macro_rules! cfg_not_rt_threaded {
+macro_rules! cfg_not_rt {
($($item:item)*) => {
- $( #[cfg(not(feature = "rt-threaded"))] $item )*
+ $( #[cfg(not(feature = "rt"))] $item )*
}
}
-macro_rules! cfg_tcp {
+macro_rules! cfg_rt_multi_thread {
($($item:item)*) => {
$(
- #[cfg(feature = "tcp")]
- #[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
+ #[cfg(feature = "rt-multi-thread")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))]
$item
)*
}
}
+macro_rules! cfg_not_rt_multi_thread {
+ ($($item:item)*) => {
+ $( #[cfg(not(feature = "rt-multi-thread"))] $item )*
+ }
+}
+
macro_rules! cfg_test_util {
($($item:item)*) => {
$(
@@ -334,36 +305,6 @@ macro_rules! cfg_not_time {
}
}
-macro_rules! cfg_udp {
- ($($item:item)*) => {
- $(
- #[cfg(feature = "udp")]
- #[cfg_attr(docsrs, doc(cfg(feature = "udp")))]
- $item
- )*
- }
-}
-
-macro_rules! cfg_uds {
- ($($item:item)*) => {
- $(
- #[cfg(all(unix, feature = "uds"))]
- #[cfg_attr(docsrs, doc(cfg(feature = "uds")))]
- $item
- )*
- }
-}
-
-macro_rules! cfg_unstable {
- ($($item:item)*) => {
- $(
- #[cfg(tokio_unstable)]
- #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
- $item
- )*
- }
-}
-
macro_rules! cfg_trace {
($($item:item)*) => {
$(
@@ -387,16 +328,15 @@ macro_rules! cfg_coop {
($($item:item)*) => {
$(
#[cfg(any(
- feature = "blocking",
- feature = "dns",
feature = "fs",
- feature = "io-driver",
feature = "io-std",
+ feature = "net",
feature = "process",
- feature = "rt-core",
+ feature = "rt",
+ feature = "signal",
feature = "sync",
feature = "stream",
- feature = "time"
+ feature = "time",
))]
$item
)*
diff --git a/src/macros/mod.rs b/src/macros/mod.rs
index 2643c36..b0af521 100644
--- a/src/macros/mod.rs
+++ b/src/macros/mod.rs
@@ -16,7 +16,7 @@ mod ready;
mod thread_local;
#[macro_use]
-#[cfg(feature = "rt-core")]
+#[cfg(feature = "rt")]
pub(crate) mod scoped_tls;
cfg_macros! {
diff --git a/src/macros/scoped_tls.rs b/src/macros/scoped_tls.rs
index 886f9d4..a00aae2 100644
--- a/src/macros/scoped_tls.rs
+++ b/src/macros/scoped_tls.rs
@@ -23,9 +23,7 @@ macro_rules! scoped_thread_local {
/// Type representing a thread local storage key corresponding to a reference
/// to the type parameter `T`.
pub(crate) struct ScopedKey<T> {
- #[doc(hidden)]
pub(crate) inner: &'static LocalKey<Cell<*const ()>>,
- #[doc(hidden)]
pub(crate) _marker: marker::PhantomData<T>,
}
diff --git a/src/macros/select.rs b/src/macros/select.rs
index 52c8fdd..b63abdd 100644
--- a/src/macros/select.rs
+++ b/src/macros/select.rs
@@ -63,9 +63,9 @@
/// Given that `if` preconditions are used to disable `select!` branches, some
/// caution must be used to avoid missing values.
///
-/// For example, here is **incorrect** usage of `delay` with `if`. The objective
+/// For example, here is **incorrect** usage of `sleep` with `if`. The objective
/// is to repeatedly run an asynchronous task for up to 50 milliseconds.
-/// However, there is a potential for the `delay` completion to be missed.
+/// However, there is a potential for the `sleep` completion to be missed.
///
/// ```no_run
/// use tokio::time::{self, Duration};
@@ -76,11 +76,11 @@
///
/// #[tokio::main]
/// async fn main() {
-/// let mut delay = time::delay_for(Duration::from_millis(50));
+/// let mut sleep = time::sleep(Duration::from_millis(50));
///
-/// while !delay.is_elapsed() {
+/// while !sleep.is_elapsed() {
/// tokio::select! {
-/// _ = &mut delay, if !delay.is_elapsed() => {
+/// _ = &mut sleep, if !sleep.is_elapsed() => {
/// println!("operation timed out");
/// }
/// _ = some_async_work() => {
@@ -91,11 +91,11 @@
/// }
/// ```
///
-/// In the above example, `delay.is_elapsed()` may return `true` even if
-/// `delay.poll()` never returned `Ready`. This opens up a potential race
-/// condition where `delay` expires between the `while !delay.is_elapsed()`
+/// In the above example, `sleep.is_elapsed()` may return `true` even if
+/// `sleep.poll()` never returned `Ready`. This opens up a potential race
+/// condition where `sleep` expires between the `while !sleep.is_elapsed()`
/// check and the call to `select!` resulting in the `some_async_work()` call to
-/// run uninterrupted despite the delay having elapsed.
+/// run uninterrupted despite the sleep having elapsed.
///
/// One way to write the above example without the race would be:
///
@@ -103,17 +103,17 @@
/// use tokio::time::{self, Duration};
///
/// async fn some_async_work() {
-/// # time::delay_for(Duration::from_millis(10)).await;
+/// # time::sleep(Duration::from_millis(10)).await;
/// // do work
/// }
///
/// #[tokio::main]
/// async fn main() {
-/// let mut delay = time::delay_for(Duration::from_millis(50));
+/// let mut sleep = time::sleep(Duration::from_millis(50));
///
/// loop {
/// tokio::select! {
-/// _ = &mut delay => {
+/// _ = &mut sleep => {
/// println!("operation timed out");
/// break;
/// }
@@ -226,7 +226,7 @@
/// #[tokio::main]
/// async fn main() {
/// let mut stream = stream::iter(vec![1, 2, 3]);
-/// let mut delay = time::delay_for(Duration::from_secs(1));
+/// let mut sleep = time::sleep(Duration::from_secs(1));
///
/// loop {
/// tokio::select! {
@@ -237,7 +237,7 @@
/// break;
/// }
/// }
-/// _ = &mut delay => {
+/// _ = &mut sleep => {
/// println!("timeout");
/// break;
/// }
@@ -366,6 +366,7 @@ macro_rules! select {
}
match branch {
$(
+ #[allow(unreachable_code)]
$crate::count!( $($skip)* ) => {
// First, if the future has previously been
// disabled, do not poll it again. This is done
@@ -403,6 +404,7 @@ macro_rules! select {
// The future returned a value, check if matches
// the specified pattern.
#[allow(unused_variables)]
+ #[allow(unused_mut)]
match &out {
$bind => {}
_ => continue,
diff --git a/src/macros/support.rs b/src/macros/support.rs
index fc1cdfc..7f11bc6 100644
--- a/src/macros/support.rs
+++ b/src/macros/support.rs
@@ -1,5 +1,6 @@
cfg_macros! {
- pub use crate::future::{maybe_done, poll_fn};
+ pub use crate::future::poll_fn;
+ pub use crate::future::maybe_done::maybe_done;
pub use crate::util::thread_rng_n;
}
diff --git a/src/net/addr.rs b/src/net/addr.rs
index 5ba898a..7cbe531 100644
--- a/src/net/addr.rs
+++ b/src/net/addr.rs
@@ -9,7 +9,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV
///
/// Implementations of `ToSocketAddrs` for string types require a DNS lookup.
/// These implementations are only provided when Tokio is used with the
-/// **`dns`** feature flag.
+/// **`net`** feature flag.
///
/// # Calling
///
@@ -23,6 +23,15 @@ pub trait ToSocketAddrs: sealed::ToSocketAddrsPriv {}
type ReadyFuture<T> = future::Ready<io::Result<T>>;
+cfg_net! {
+ pub(crate) fn to_socket_addrs<T>(arg: T) -> T::Future
+ where
+ T: ToSocketAddrs,
+ {
+ arg.to_socket_addrs(sealed::Internal)
+ }
+}
+
// ===== impl &impl ToSocketAddrs =====
impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {}
@@ -34,8 +43,8 @@ where
type Iter = T::Iter;
type Future = T::Future;
- fn to_socket_addrs(&self) -> Self::Future {
- (**self).to_socket_addrs()
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ (**self).to_socket_addrs(sealed::Internal)
}
}
@@ -47,7 +56,7 @@ impl sealed::ToSocketAddrsPriv for SocketAddr {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let iter = Some(*self).into_iter();
future::ok(iter)
}
@@ -61,8 +70,8 @@ impl sealed::ToSocketAddrsPriv for SocketAddrV4 {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
- SocketAddr::V4(*self).to_socket_addrs()
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ SocketAddr::V4(*self).to_socket_addrs(sealed::Internal)
}
}
@@ -74,8 +83,8 @@ impl sealed::ToSocketAddrsPriv for SocketAddrV6 {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
- SocketAddr::V6(*self).to_socket_addrs()
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ SocketAddr::V6(*self).to_socket_addrs(sealed::Internal)
}
}
@@ -87,7 +96,7 @@ impl sealed::ToSocketAddrsPriv for (IpAddr, u16) {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let iter = Some(SocketAddr::from(*self)).into_iter();
future::ok(iter)
}
@@ -101,9 +110,9 @@ impl sealed::ToSocketAddrsPriv for (Ipv4Addr, u16) {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let (ip, port) = *self;
- SocketAddrV4::new(ip, port).to_socket_addrs()
+ SocketAddrV4::new(ip, port).to_socket_addrs(sealed::Internal)
}
}
@@ -115,9 +124,9 @@ impl sealed::ToSocketAddrsPriv for (Ipv6Addr, u16) {
type Iter = std::option::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let (ip, port) = *self;
- SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
+ SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs(sealed::Internal)
}
}
@@ -129,13 +138,13 @@ impl sealed::ToSocketAddrsPriv for &[SocketAddr] {
type Iter = std::vec::IntoIter<SocketAddr>;
type Future = ReadyFuture<Self::Iter>;
- fn to_socket_addrs(&self) -> Self::Future {
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
let iter = self.to_vec().into_iter();
future::ok(iter)
}
}
-cfg_dns! {
+cfg_net! {
// ===== impl str =====
impl ToSocketAddrs for str {}
@@ -144,23 +153,23 @@ cfg_dns! {
type Iter = sealed::OneOrMore;
type Future = sealed::MaybeReady;
- fn to_socket_addrs(&self) -> Self::Future {
- use crate::runtime::spawn_blocking;
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ use crate::blocking::spawn_blocking;
use sealed::MaybeReady;
// First check if the input parses as a socket address
let res: Result<SocketAddr, _> = self.parse();
if let Ok(addr) = res {
- return MaybeReady::Ready(Some(addr));
+ return MaybeReady(sealed::State::Ready(Some(addr)));
}
// Run DNS lookup on the blocking pool
let s = self.to_owned();
- MaybeReady::Blocking(spawn_blocking(move || {
+ MaybeReady(sealed::State::Blocking(spawn_blocking(move || {
std::net::ToSocketAddrs::to_socket_addrs(&s)
- }))
+ })))
}
}
@@ -172,8 +181,8 @@ cfg_dns! {
type Iter = sealed::OneOrMore;
type Future = sealed::MaybeReady;
- fn to_socket_addrs(&self) -> Self::Future {
- use crate::runtime::spawn_blocking;
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ use crate::blocking::spawn_blocking;
use sealed::MaybeReady;
let (host, port) = *self;
@@ -183,21 +192,34 @@ cfg_dns! {
let addr = SocketAddrV4::new(addr, port);
let addr = SocketAddr::V4(addr);
- return MaybeReady::Ready(Some(addr));
+ return MaybeReady(sealed::State::Ready(Some(addr)));
}
if let Ok(addr) = host.parse::<Ipv6Addr>() {
let addr = SocketAddrV6::new(addr, port, 0, 0);
let addr = SocketAddr::V6(addr);
- return MaybeReady::Ready(Some(addr));
+ return MaybeReady(sealed::State::Ready(Some(addr)));
}
let host = host.to_owned();
- MaybeReady::Blocking(spawn_blocking(move || {
+ MaybeReady(sealed::State::Blocking(spawn_blocking(move || {
std::net::ToSocketAddrs::to_socket_addrs(&(&host[..], port))
- }))
+ })))
+ }
+ }
+
+ // ===== impl (String, u16) =====
+
+ impl ToSocketAddrs for (String, u16) {}
+
+ impl sealed::ToSocketAddrsPriv for (String, u16) {
+ type Iter = sealed::OneOrMore;
+ type Future = sealed::MaybeReady;
+
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ (self.0.as_str(), self.1).to_socket_addrs(sealed::Internal)
}
}
@@ -209,8 +231,8 @@ cfg_dns! {
type Iter = <str as sealed::ToSocketAddrsPriv>::Iter;
type Future = <str as sealed::ToSocketAddrsPriv>::Future;
- fn to_socket_addrs(&self) -> Self::Future {
- (&self[..]).to_socket_addrs()
+ fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future {
+ (&self[..]).to_socket_addrs(sealed::Internal)
}
}
}
@@ -224,27 +246,31 @@ pub(crate) mod sealed {
use std::io;
use std::net::SocketAddr;
- cfg_dns! {
- use crate::task::JoinHandle;
-
- use std::option;
- use std::pin::Pin;
- use std::task::{Context, Poll};
- use std::vec;
- }
-
#[doc(hidden)]
pub trait ToSocketAddrsPriv {
type Iter: Iterator<Item = SocketAddr> + Send + 'static;
type Future: Future<Output = io::Result<Self::Iter>> + Send + 'static;
- fn to_socket_addrs(&self) -> Self::Future;
+ fn to_socket_addrs(&self, internal: Internal) -> Self::Future;
}
- cfg_dns! {
+ #[allow(missing_debug_implementations)]
+ pub struct Internal;
+
+ cfg_net! {
+ use crate::blocking::JoinHandle;
+
+ use std::option;
+ use std::pin::Pin;
+ use std::task::{Context, Poll};
+ use std::vec;
+
#[doc(hidden)]
#[derive(Debug)]
- pub enum MaybeReady {
+ pub struct MaybeReady(pub(super) State);
+
+ #[derive(Debug)]
+ pub(super) enum State {
Ready(Option<SocketAddr>),
Blocking(JoinHandle<io::Result<vec::IntoIter<SocketAddr>>>),
}
@@ -260,12 +286,12 @@ pub(crate) mod sealed {
type Output = io::Result<OneOrMore>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- match *self {
- MaybeReady::Ready(ref mut i) => {
+ match self.0 {
+ State::Ready(ref mut i) => {
let iter = OneOrMore::One(i.take().into_iter());
Poll::Ready(Ok(iter))
}
- MaybeReady::Blocking(ref mut rx) => {
+ State::Blocking(ref mut rx) => {
let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
Poll::Ready(res)
diff --git a/src/net/lookup_host.rs b/src/net/lookup_host.rs
index 3098b46..2886184 100644
--- a/src/net/lookup_host.rs
+++ b/src/net/lookup_host.rs
@@ -1,5 +1,5 @@
-cfg_dns! {
- use crate::net::addr::ToSocketAddrs;
+cfg_net! {
+ use crate::net::addr::{self, ToSocketAddrs};
use std::io;
use std::net::SocketAddr;
@@ -33,6 +33,6 @@ cfg_dns! {
where
T: ToSocketAddrs
{
- host.to_socket_addrs().await
+ addr::to_socket_addrs(host).await
}
}
diff --git a/src/net/mod.rs b/src/net/mod.rs
index eb24ac0..b7365e6 100644
--- a/src/net/mod.rs
+++ b/src/net/mod.rs
@@ -23,27 +23,26 @@
//! [`UnixDatagram`]: UnixDatagram
mod addr;
+#[cfg(feature = "net")]
+pub(crate) use addr::to_socket_addrs;
pub use addr::ToSocketAddrs;
-cfg_dns! {
+cfg_net! {
mod lookup_host;
pub use lookup_host::lookup_host;
-}
-cfg_tcp! {
pub mod tcp;
pub use tcp::listener::TcpListener;
+ pub use tcp::socket::TcpSocket;
pub use tcp::stream::TcpStream;
-}
-cfg_udp! {
pub mod udp;
pub use udp::socket::UdpSocket;
}
-cfg_uds! {
+cfg_net_unix! {
pub mod unix;
- pub use unix::datagram::UnixDatagram;
+ pub use unix::datagram::socket::UnixDatagram;
pub use unix::listener::UnixListener;
pub use unix::stream::UnixStream;
}
diff --git a/src/net/tcp/incoming.rs b/src/net/tcp/incoming.rs
deleted file mode 100644
index 062be1e..0000000
--- a/src/net/tcp/incoming.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use crate::net::tcp::{TcpListener, TcpStream};
-
-use std::io;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-/// Stream returned by the `TcpListener::incoming` function representing the
-/// stream of sockets received from a listener.
-#[must_use = "streams do nothing unless polled"]
-#[derive(Debug)]
-pub struct Incoming<'a> {
- inner: &'a mut TcpListener,
-}
-
-impl Incoming<'_> {
- pub(crate) fn new(listener: &mut TcpListener) -> Incoming<'_> {
- Incoming { inner: listener }
- }
-
- /// Attempts to poll `TcpStream` by polling inner `TcpListener` to accept
- /// connection.
- ///
- /// If `TcpListener` isn't ready yet, `Poll::Pending` is returned and
- /// current task will be notified by a waker.
- pub fn poll_accept(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<TcpStream>> {
- let (socket, _) = ready!(self.inner.poll_accept(cx))?;
- Poll::Ready(Ok(socket))
- }
-}
-
-#[cfg(feature = "stream")]
-impl crate::stream::Stream for Incoming<'_> {
- type Item = io::Result<TcpStream>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- let (socket, _) = ready!(self.inner.poll_accept(cx))?;
- Poll::Ready(Some(Ok(socket)))
- }
-}
diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs
index fd79b25..3f9bca0 100644
--- a/src/net/tcp/listener.rs
+++ b/src/net/tcp/listener.rs
@@ -1,7 +1,6 @@
-use crate::future::poll_fn;
use crate::io::PollEvented;
-use crate::net::tcp::{Incoming, TcpStream};
-use crate::net::ToSocketAddrs;
+use crate::net::tcp::TcpStream;
+use crate::net::{to_socket_addrs, ToSocketAddrs};
use std::convert::TryFrom;
use std::fmt;
@@ -9,7 +8,7 @@ use std::io;
use std::net::{self, SocketAddr};
use std::task::{Context, Poll};
-cfg_tcp! {
+cfg_net! {
/// A TCP socket server, listening for connections.
///
/// You can accept a new connection by using the [`accept`](`TcpListener::accept`) method. Alternatively `TcpListener`
@@ -40,7 +39,7 @@ cfg_tcp! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+ /// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
@@ -81,7 +80,7 @@ impl TcpListener {
/// method.
///
/// The address type can be any implementor of the [`ToSocketAddrs`] trait.
- /// Note that strings only implement this trait when the **`dns`** feature
+ /// Note that strings only implement this trait when the **`net`** feature
/// is enabled, as strings may contain domain names that need to be resolved.
///
/// If `addr` yields multiple addresses, bind will be attempted with each of
@@ -110,27 +109,8 @@ impl TcpListener {
/// Ok(())
/// }
/// ```
- ///
- /// Without the `dns` feature:
- ///
- /// ```no_run
- /// use tokio::net::TcpListener;
- /// use std::net::Ipv4Addr;
- ///
- /// use std::io;
- ///
- /// #[tokio::main]
- /// async fn main() -> io::Result<()> {
- /// let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 2345)).await?;
- ///
- /// // use the listener
- ///
- /// # let _ = listener;
- /// Ok(())
- /// }
- /// ```
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> {
- let addrs = addr.to_socket_addrs().await?;
+ let addrs = to_socket_addrs(addr).await?;
let mut last_err = None;
@@ -150,7 +130,7 @@ impl TcpListener {
}
fn bind_addr(addr: SocketAddr) -> io::Result<TcpListener> {
- let listener = mio::net::TcpListener::bind(&addr)?;
+ let listener = mio::net::TcpListener::bind(addr)?;
TcpListener::new(listener)
}
@@ -171,7 +151,7 @@ impl TcpListener {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+ /// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// match listener.accept().await {
/// Ok((_socket, addr)) => println!("new client: {:?}", addr),
@@ -181,66 +161,53 @@ impl TcpListener {
/// Ok(())
/// }
/// ```
- pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> {
- poll_fn(|cx| self.poll_accept(cx)).await
+ pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
+ let (mio, addr) = self
+ .io
+ .async_io(mio::Interest::READABLE, |sock| sock.accept())
+ .await?;
+
+ let stream = TcpStream::new(mio)?;
+ Ok((stream, addr))
}
- /// Attempts to poll `SocketAddr` and `TcpStream` bound to this address.
+ /// Polls to accept a new incoming connection to this listener.
///
- /// In case if I/O resource isn't ready yet, `Poll::Pending` is returned and
+ /// If there is no connection to accept, `Poll::Pending` is returned and the
/// current task will be notified by a waker.
- pub fn poll_accept(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
- let (io, addr) = ready!(self.poll_accept_std(cx))?;
-
- let io = mio::net::TcpStream::from_stream(io)?;
- let io = TcpStream::new(io)?;
-
- Poll::Ready(Ok((io, addr)))
- }
-
- fn poll_accept_std(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<(net::TcpStream, SocketAddr)>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
+ ///
+ /// When ready, the most recent task that called `poll_accept` is notified.
+ /// The caller is responsble to ensure that `poll_accept` is called from a
+ /// single task. Failing to do this could result in tasks hanging.
+ pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
+ loop {
+ let ev = ready!(self.io.poll_read_ready(cx))?;
- match self.io.get_ref().accept_std() {
- Ok(pair) => Poll::Ready(Ok(pair)),
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
+ match self.io.get_ref().accept() {
+ Ok((io, addr)) => {
+ let io = TcpStream::new(io)?;
+ return Poll::Ready(Ok((io, addr)));
+ }
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ Err(e) => return Poll::Ready(Err(e)),
}
- Err(e) => Poll::Ready(Err(e)),
}
}
/// Creates a new TCP listener from the standard library's TCP listener.
///
- /// This method can be used when the `Handle::tcp_listen` method isn't
- /// sufficient because perhaps some more configuration is needed in terms of
- /// before the calls to `bind` and `listen`.
+ /// This function is intended to be used to wrap a TCP listener from the
+ /// standard library in the Tokio equivalent. The conversion assumes nothing
+ /// about the underlying listener; it is left up to the user to set it in
+ /// non-blocking mode.
///
/// This API is typically paired with the `net2` crate and the `TcpBuilder`
/// type to build up and customize a listener before it's shipped off to the
/// backing event loop. This allows configuration of options like
/// `SO_REUSEPORT`, binding to multiple addresses, etc.
///
- /// The `addr` argument here is one of the addresses that `listener` is
- /// bound to and the listener will only be guaranteed to accept connections
- /// of the same address type currently.
- ///
- /// The platform specific behavior of this function looks like:
- ///
- /// * On Unix, the socket is placed into nonblocking mode and connections
- /// can be accepted as normal
- ///
- /// * On Windows, the address is stored internally and all future accepts
- /// will only be for the same IP version as `addr` specified. That is, if
- /// `addr` is an IPv4 address then all sockets accepted will be IPv4 as
- /// well (same for IPv6).
///
/// # Examples
///
@@ -262,14 +229,14 @@ impl TcpListener {
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
- let io = mio::net::TcpListener::from_std(listener)?;
+ let io = mio::net::TcpListener::from_std(listener);
let io = PollEvented::new(io)?;
Ok(TcpListener { io })
}
- fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> {
+ pub(crate) fn new(listener: mio::net::TcpListener) -> io::Result<TcpListener> {
let io = PollEvented::new(listener)?;
Ok(TcpListener { io })
}
@@ -301,46 +268,6 @@ impl TcpListener {
self.io.get_ref().local_addr()
}
- /// Returns a stream over the connections being received on this listener.
- ///
- /// Note that `TcpListener` also directly implements `Stream`.
- ///
- /// The returned stream will never return `None` and will also not yield the
- /// peer's `SocketAddr` structure. Iterating over it is equivalent to
- /// calling accept in a loop.
- ///
- /// # Errors
- ///
- /// Note that accepting a connection can lead to various errors and not all
- /// of them are necessarily fatal ‒ for example having too many open file
- /// descriptors or the other side closing the connection while it waits in
- /// an accept queue. These would terminate the stream if not handled in any
- /// way.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::{net::TcpListener, stream::StreamExt};
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
- /// let mut incoming = listener.incoming();
- ///
- /// while let Some(stream) = incoming.next().await {
- /// match stream {
- /// Ok(stream) => {
- /// println!("new client!");
- /// }
- /// Err(e) => { /* connection failed */ }
- /// }
- /// }
- /// }
- /// ```
- pub fn incoming(&mut self) -> Incoming<'_> {
- Incoming::new(self)
- }
-
/// Gets the value of the `IP_TTL` option for this socket.
///
/// For more information about this option, see [`set_ttl`].
@@ -398,29 +325,12 @@ impl TcpListener {
impl crate::stream::Stream for TcpListener {
type Item = io::Result<TcpStream>;
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Self::Item>> {
+ fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (socket, _) = ready!(self.poll_accept(cx))?;
Poll::Ready(Some(Ok(socket)))
}
}
-impl TryFrom<TcpListener> for mio::net::TcpListener {
- type Error = io::Error;
-
- /// Consumes value, returning the mio I/O object.
- ///
- /// See [`PollEvented::into_inner`] for more details about
- /// resource deregistration that happens during the call.
- ///
- /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
- fn try_from(value: TcpListener) -> Result<Self, Self::Error> {
- value.io.into_inner()
- }
-}
-
impl TryFrom<net::TcpListener> for TcpListener {
type Error = io::Error;
@@ -453,14 +363,12 @@ mod sys {
#[cfg(windows)]
mod sys {
- // TODO: let's land these upstream with mio and then we can add them here.
- //
- // use std::os::windows::prelude::*;
- // use super::{TcpListener;
- //
- // impl AsRawHandle for TcpListener {
- // fn as_raw_handle(&self) -> RawHandle {
- // self.listener.io().as_raw_handle()
- // }
- // }
+ use super::TcpListener;
+ use std::os::windows::prelude::*;
+
+ impl AsRawSocket for TcpListener {
+ fn as_raw_socket(&self) -> RawSocket {
+ self.io.get_ref().as_raw_socket()
+ }
+ }
}
diff --git a/src/net/tcp/mod.rs b/src/net/tcp/mod.rs
index 7ad36eb..7f0f6d9 100644
--- a/src/net/tcp/mod.rs
+++ b/src/net/tcp/mod.rs
@@ -1,10 +1,8 @@
//! TCP utility types
pub(crate) mod listener;
-pub(crate) use listener::TcpListener;
-mod incoming;
-pub use incoming::Incoming;
+pub(crate) mod socket;
mod split;
pub use split::{ReadHalf, WriteHalf};
diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs
new file mode 100644
index 0000000..5b0f802
--- /dev/null
+++ b/src/net/tcp/socket.rs
@@ -0,0 +1,349 @@
+use crate::net::{TcpListener, TcpStream};
+
+use std::fmt;
+use std::io;
+use std::net::SocketAddr;
+
+#[cfg(unix)]
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+#[cfg(windows)]
+use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
+
+/// A TCP socket that has not yet been converted to a `TcpStream` or
+/// `TcpListener`.
+///
+/// `TcpSocket` wraps an operating system socket and enables the caller to
+/// configure the socket before establishing a TCP connection or accepting
+/// inbound connections. The caller is able to set socket option and explicitly
+/// bind the socket with a socket address.
+///
+/// The underlying socket is closed when the `TcpSocket` value is dropped.
+///
+/// `TcpSocket` should only be used directly if the default configuration used
+/// by `TcpStream::connect` and `TcpListener::bind` does not meet the required
+/// use case.
+///
+/// Calling `TcpStream::connect("127.0.0.1:8080")` is equivalent to:
+///
+/// ```no_run
+/// use tokio::net::TcpSocket;
+///
+/// use std::io;
+///
+/// #[tokio::main]
+/// async fn main() -> io::Result<()> {
+/// let addr = "127.0.0.1:8080".parse().unwrap();
+///
+/// let socket = TcpSocket::new_v4()?;
+/// let stream = socket.connect(addr).await?;
+/// # drop(stream);
+///
+/// Ok(())
+/// }
+/// ```
+///
+/// Calling `TcpListener::bind("127.0.0.1:8080")` is equivalent to:
+///
+/// ```no_run
+/// use tokio::net::TcpSocket;
+///
+/// use std::io;
+///
+/// #[tokio::main]
+/// async fn main() -> io::Result<()> {
+/// let addr = "127.0.0.1:8080".parse().unwrap();
+///
+/// let socket = TcpSocket::new_v4()?;
+/// // On platforms with Berkeley-derived sockets, this allows to quickly
+/// // rebind a socket, without needing to wait for the OS to clean up the
+/// // previous one.
+/// //
+/// // On Windows, this allows rebinding sockets which are actively in use,
+/// // which allows “socket hijacking”, so we explicitly don't set it here.
+/// // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
+/// socket.set_reuseaddr(true)?;
+/// socket.bind(addr)?;
+///
+/// let listener = socket.listen(1024)?;
+/// # drop(listener);
+///
+/// Ok(())
+/// }
+/// ```
+///
+/// Setting socket options not explicitly provided by `TcpSocket` may be done by
+/// accessing the `RawFd`/`RawSocket` using [`AsRawFd`]/[`AsRawSocket`] and
+/// setting the option with a crate like [`socket2`].
+///
+/// [`RawFd`]: https://doc.rust-lang.org/std/os/unix/io/type.RawFd.html
+/// [`RawSocket`]: https://doc.rust-lang.org/std/os/windows/io/type.RawSocket.html
+/// [`AsRawFd`]: https://doc.rust-lang.org/std/os/unix/io/trait.AsRawFd.html
+/// [`AsRawSocket`]: https://doc.rust-lang.org/std/os/windows/io/trait.AsRawSocket.html
+/// [`socket2`]: https://docs.rs/socket2/
+pub struct TcpSocket {
+ inner: mio::net::TcpSocket,
+}
+
+impl TcpSocket {
+ /// Create a new socket configured for IPv4.
+ ///
+ /// Calls `socket(2)` with `AF_INET` and `SOCK_STREAM`.
+ ///
+ /// # Returns
+ ///
+ /// On success, the newly created `TcpSocket` is returned. If an error is
+ /// encountered, it is returned instead.
+ ///
+ /// # Examples
+ ///
+ /// Create a new IPv4 socket and start listening.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(128)?;
+ /// # drop(listener);
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn new_v4() -> io::Result<TcpSocket> {
+ let inner = mio::net::TcpSocket::new_v4()?;
+ Ok(TcpSocket { inner })
+ }
+
+ /// Create a new socket configured for IPv6.
+ ///
+ /// Calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`.
+ ///
+ /// # Returns
+ ///
+ /// On success, the newly created `TcpSocket` is returned. If an error is
+ /// encountered, it is returned instead.
+ ///
+ /// # Examples
+ ///
+ /// Create a new IPv6 socket and start listening.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "[::1]:8080".parse().unwrap();
+ /// let socket = TcpSocket::new_v6()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(128)?;
+ /// # drop(listener);
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn new_v6() -> io::Result<TcpSocket> {
+ let inner = mio::net::TcpSocket::new_v6()?;
+ Ok(TcpSocket { inner })
+ }
+
+ /// Allow the socket to bind to an in-use address.
+ ///
+ /// Behavior is platform specific. Refer to the target platform's
+ /// documentation for more details.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.set_reuseaddr(true)?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(1024)?;
+ /// # drop(listener);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
+ self.inner.set_reuseaddr(reuseaddr)
+ }
+
+ /// Bind the socket to the given address.
+ ///
+ /// This calls the `bind(2)` operating-system function. Behavior is
+ /// platform specific. Refer to the target platform's documentation for more
+ /// details.
+ ///
+ /// # Examples
+ ///
+ /// Bind a socket before listening.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(1024)?;
+ /// # drop(listener);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
+ self.inner.bind(addr)
+ }
+
+ /// Establish a TCP connection with a peer at the specified socket address.
+ ///
+ /// The `TcpSocket` is consumed. Once the connection is established, a
+ /// connected [`TcpStream`] is returned. If the connection fails, the
+ /// encountered error is returned.
+ ///
+ /// [`TcpStream`]: TcpStream
+ ///
+ /// This calls the `connect(2)` operating-system function. Behavior is
+ /// platform specific. Refer to the target platform's documentation for more
+ /// details.
+ ///
+ /// # Examples
+ ///
+ /// Connecting to a peer.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// let stream = socket.connect(addr).await?;
+ /// # drop(stream);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
+ let mio = self.inner.connect(addr)?;
+ TcpStream::connect_mio(mio).await
+ }
+
+ /// Convert the socket into a `TcpListener`.
+ ///
+ /// `backlog` defines the maximum number of pending connections are queued
+ /// by the operating system at any given time. Connection are removed from
+ /// the queue with [`TcpListener::accept`]. When the queue is full, the
+ /// operationg-system will start rejecting connections.
+ ///
+ /// [`TcpListener::accept`]: TcpListener::accept
+ ///
+ /// This calls the `listen(2)` operating-system function, marking the socket
+ /// as a passive socket. Behavior is platform specific. Refer to the target
+ /// platform's documentation for more details.
+ ///
+ /// # Examples
+ ///
+ /// Create a `TcpListener`.
+ ///
+ /// ```no_run
+ /// use tokio::net::TcpSocket;
+ ///
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let addr = "127.0.0.1:8080".parse().unwrap();
+ ///
+ /// let socket = TcpSocket::new_v4()?;
+ /// socket.bind(addr)?;
+ ///
+ /// let listener = socket.listen(1024)?;
+ /// # drop(listener);
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn listen(self, backlog: u32) -> io::Result<TcpListener> {
+ let mio = self.inner.listen(backlog)?;
+ TcpListener::new(mio)
+ }
+}
+
+impl fmt::Debug for TcpSocket {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.inner.fmt(fmt)
+ }
+}
+
+#[cfg(unix)]
+impl AsRawFd for TcpSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.inner.as_raw_fd()
+ }
+}
+
+#[cfg(unix)]
+impl FromRawFd for TcpSocket {
+ /// Converts a `RawFd` to a `TcpSocket`.
+ ///
+ /// # Notes
+ ///
+ /// The caller is responsible for ensuring that the socket is in
+ /// non-blocking mode.
+ unsafe fn from_raw_fd(fd: RawFd) -> TcpSocket {
+ let inner = mio::net::TcpSocket::from_raw_fd(fd);
+ TcpSocket { inner }
+ }
+}
+
+#[cfg(windows)]
+impl IntoRawSocket for TcpSocket {
+ fn into_raw_socket(self) -> RawSocket {
+ self.inner.into_raw_socket()
+ }
+}
+
+#[cfg(windows)]
+impl AsRawSocket for TcpSocket {
+ fn as_raw_socket(&self) -> RawSocket {
+ self.inner.as_raw_socket()
+ }
+}
+
+#[cfg(windows)]
+impl FromRawSocket for TcpSocket {
+ /// Converts a `RawSocket` to a `TcpStream`.
+ ///
+ /// # Notes
+ ///
+ /// The caller is responsible for ensuring that the socket is in
+ /// non-blocking mode.
+ unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket {
+ let inner = mio::net::TcpSocket::from_raw_socket(socket);
+ TcpSocket { inner }
+ }
+}
diff --git a/src/net/tcp/split.rs b/src/net/tcp/split.rs
index 469056a..9a257f8 100644
--- a/src/net/tcp/split.rs
+++ b/src/net/tcp/split.rs
@@ -9,17 +9,15 @@
//! level.
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::TcpStream;
-use bytes::Buf;
use std::io;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::task::{Context, Poll};
-/// Read half of a [`TcpStream`], created by [`split`].
+/// Borrowed read half of a [`TcpStream`], created by [`split`].
///
/// Reading from a `ReadHalf` is usually done using the convenience methods found on the
/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude].
@@ -31,12 +29,12 @@ use std::task::{Context, Poll};
#[derive(Debug)]
pub struct ReadHalf<'a>(&'a TcpStream);
-/// Write half of a [`TcpStream`], created by [`split`].
+/// Borrowed write half of a [`TcpStream`], created by [`split`].
///
/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will
/// shut down the TCP stream in the write direction.
///
-/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
+/// Writing to an `WriteHalf` is usually done using the convenience methods found
/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude].
///
/// [`TcpStream`]: TcpStream
@@ -83,7 +81,7 @@ impl ReadHalf<'_> {
///
/// [`TcpStream::poll_peek`]: TcpStream::poll_peek
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
- self.0.poll_peek2(cx, buf)
+ self.0.poll_peek(cx, buf)
}
/// Receives data on the socket from the remote address to which it is
@@ -131,15 +129,11 @@ impl ReadHalf<'_> {
}
impl AsyncRead for ReadHalf<'_> {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.0.poll_read_priv(cx, buf)
}
}
@@ -153,14 +147,6 @@ impl AsyncWrite for WriteHalf<'_> {
self.0.poll_write_priv(cx, buf)
}
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- self.0.poll_write_buf_priv(cx, buf)
- }
-
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs
index 3f6ee33..4b4e263 100644
--- a/src/net/tcp/split_owned.rs
+++ b/src/net/tcp/split_owned.rs
@@ -9,12 +9,10 @@
//! level.
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::TcpStream;
-use bytes::Buf;
use std::error::Error;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::sync::Arc;
@@ -37,10 +35,9 @@ pub struct OwnedReadHalf {
/// Owned write half of a [`TcpStream`], created by [`into_split`].
///
-/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will
-/// shut down the TCP stream in the write direction.
-///
-/// Dropping the write half will shutdown the write half of the TCP stream.
+/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
+/// shut down the TCP stream in the write direction. Dropping the write half
+/// will also shut down the write half of the TCP stream.
///
/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude].
@@ -77,13 +74,13 @@ pub(crate) fn reunite(
write.forget();
// This unwrap cannot fail as the api does not allow creating more than two Arcs,
// and we just dropped the other half.
- Ok(Arc::try_unwrap(read.inner).expect("Too many handles to Arc"))
+ Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
} else {
Err(ReuniteError(read, write))
}
}
-/// Error indicating two halves were not from the same socket, and thus could
+/// Error indicating that two halves were not from the same socket, and thus could
/// not be reunited.
#[derive(Debug)]
pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
@@ -139,7 +136,7 @@ impl OwnedReadHalf {
///
/// [`TcpStream::poll_peek`]: TcpStream::poll_peek
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
- self.inner.poll_peek2(cx, buf)
+ self.inner.poll_peek(cx, buf)
}
/// Receives data on the socket from the remote address to which it is
@@ -187,15 +184,11 @@ impl OwnedReadHalf {
}
impl AsyncRead for OwnedReadHalf {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.inner.poll_read_priv(cx, buf)
}
}
@@ -210,7 +203,9 @@ impl OwnedWriteHalf {
reunite(other, self)
}
- /// Drop the write half, but don't issue a TCP shutdown.
+ /// Destroy the write half, but don't close the write half of the stream
+ /// until the read half is dropped. If the read half has already been
+ /// dropped, this closes the stream.
pub fn forget(mut self) {
self.shutdown_on_drop = false;
drop(self);
@@ -234,14 +229,6 @@ impl AsyncWrite for OwnedWriteHalf {
self.inner.poll_write_priv(cx, buf)
}
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- self.inner.poll_write_buf_priv(cx, buf)
- }
-
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
@@ -250,7 +237,11 @@ impl AsyncWrite for OwnedWriteHalf {
// `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
- self.inner.shutdown(Shutdown::Write).into()
+ let res = self.inner.shutdown(Shutdown::Write);
+ if res.is_ok() {
+ Pin::into_inner(self).shutdown_on_drop = false;
+ }
+ res.into()
}
}
diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs
index cc81e11..f90e9a3 100644
--- a/src/net/tcp/stream.rs
+++ b/src/net/tcp/stream.rs
@@ -1,21 +1,17 @@
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite, PollEvented};
+use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf};
use crate::net::tcp::split::{split, ReadHalf, WriteHalf};
use crate::net::tcp::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
-use crate::net::ToSocketAddrs;
+use crate::net::{to_socket_addrs, ToSocketAddrs};
-use bytes::Buf;
-use iovec::IoVec;
use std::convert::TryFrom;
use std::fmt;
use std::io::{self, Read, Write};
-use std::mem::MaybeUninit;
-use std::net::{self, Shutdown, SocketAddr};
+use std::net::{Shutdown, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
-use std::time::Duration;
-cfg_tcp! {
+cfg_net! {
/// A TCP stream between a local and a remote socket.
///
/// A TCP stream can either be created by connecting to an endpoint, via the
@@ -26,8 +22,8 @@ cfg_tcp! {
/// traits. Examples import these traits through [the prelude].
///
/// [`connect`]: method@TcpStream::connect
- /// [accepting]: method@super::TcpListener::accept
- /// [listener]: struct@super::TcpListener
+ /// [accepting]: method@crate::net::TcpListener::accept
+ /// [listener]: struct@crate::net::TcpListener
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
/// [the prelude]: crate::prelude
@@ -65,7 +61,7 @@ impl TcpStream {
///
/// `addr` is an address of the remote host. Anything which implements the
/// [`ToSocketAddrs`] trait can be supplied as the address. Note that
- /// strings only implement this trait when the **`dns`** feature is enabled,
+ /// strings only implement this trait when the **`net`** feature is enabled,
/// as strings may contain domain names that need to be resolved.
///
/// If `addr` yields multiple addresses, connect will be attempted with each
@@ -94,32 +90,12 @@ impl TcpStream {
/// }
/// ```
///
- /// Without the `dns` feature:
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- /// use tokio::prelude::*;
- /// use std::error::Error;
- /// use std::net::Ipv4Addr;
- ///
- /// #[tokio::main]
- /// async fn main() -> Result<(), Box<dyn Error>> {
- /// // Connect to a peer
- /// let mut stream = TcpStream::connect((Ipv4Addr::new(127, 0, 0, 1), 8080)).await?;
- ///
- /// // Write some data.
- /// stream.write_all(b"hello world!").await?;
- ///
- /// Ok(())
- /// }
- /// ```
- ///
/// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait.
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
- let addrs = addr.to_socket_addrs().await?;
+ let addrs = to_socket_addrs(addr).await?;
let mut last_err = None;
@@ -140,7 +116,11 @@ impl TcpStream {
/// Establishes a connection to the specified `addr`.
async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> {
- let sys = mio::net::TcpStream::connect(&addr)?;
+ let sys = mio::net::TcpStream::connect(addr)?;
+ TcpStream::connect_mio(sys).await
+ }
+
+ pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> {
let stream = TcpStream::new(sys)?;
// Once we've connected, wait for the stream to be writable as
@@ -188,37 +168,13 @@ impl TcpStream {
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
- let io = mio::net::TcpStream::from_stream(stream)?;
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ pub fn from_std(stream: std::net::TcpStream) -> io::Result<TcpStream> {
+ let io = mio::net::TcpStream::from_std(stream);
let io = PollEvented::new(io)?;
Ok(TcpStream { io })
}
- // Connects `TcpStream` asynchronously that may be built with a net2 `TcpBuilder`.
- //
- // This should be removed in favor of some in-crate TcpSocket builder API.
- #[doc(hidden)]
- pub async fn connect_std(stream: net::TcpStream, addr: &SocketAddr) -> io::Result<TcpStream> {
- let io = mio::net::TcpStream::connect_stream(stream, addr)?;
- let io = PollEvented::new(io)?;
- let stream = TcpStream { io };
-
- // Once we've connected, wait for the stream to be writable as
- // that's when the actual connection has been initiated. Once we're
- // writable we check for `take_socket_error` to see if the connect
- // actually hit an error or not.
- //
- // If all that succeeded then we ship everything on up.
- poll_fn(|cx| stream.io.poll_write_ready(cx)).await?;
-
- if let Some(e) = stream.io.get_ref().take_error()? {
- return Err(e);
- }
-
- Ok(stream)
- }
-
/// Returns the local address that this stream is bound to.
///
/// # Examples
@@ -281,7 +237,7 @@ impl TcpStream {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut stream = TcpStream::connect("127.0.0.1:8000").await?;
+ /// let stream = TcpStream::connect("127.0.0.1:8000").await?;
/// let mut buf = [0; 10];
///
/// poll_fn(|cx| {
@@ -291,24 +247,17 @@ impl TcpStream {
/// Ok(())
/// }
/// ```
- pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
- self.poll_peek2(cx, buf)
- }
-
- pub(super) fn poll_peek2(
- &self,
- cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().peek(buf) {
- Ok(ret) => Poll::Ready(Ok(ret)),
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
+ pub fn poll_peek(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
+ loop {
+ let ev = ready!(self.io.poll_read_ready(cx))?;
+
+ match self.io.get_ref().peek(buf) {
+ Ok(ret) => return Poll::Ready(Ok(ret)),
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ Err(e) => return Poll::Ready(Err(e)),
}
- Err(e) => Poll::Ready(Err(e)),
}
}
@@ -349,8 +298,10 @@ impl TcpStream {
///
/// [`read`]: fn@crate::io::AsyncReadExt::read
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
- pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- poll_fn(|cx| self.poll_peek(cx, buf)).await
+ pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+ self.io
+ .async_io(mio::Interest::READABLE, |io| io.peek(buf))
+ .await
}
/// Shuts down the read, write, or both halves of this connection.
@@ -427,144 +378,6 @@ impl TcpStream {
self.io.get_ref().set_nodelay(nodelay)
}
- /// Gets the value of the `SO_RCVBUF` option on this socket.
- ///
- /// For more information about this option, see [`set_recv_buffer_size`].
- ///
- /// [`set_recv_buffer_size`]: TcpStream::set_recv_buffer_size
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// println!("{:?}", stream.recv_buffer_size()?);
- /// # Ok(())
- /// # }
- /// ```
- pub fn recv_buffer_size(&self) -> io::Result<usize> {
- self.io.get_ref().recv_buffer_size()
- }
-
- /// Sets the value of the `SO_RCVBUF` option on this socket.
- ///
- /// Changes the size of the operating system's receive buffer associated
- /// with the socket.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// stream.set_recv_buffer_size(100)?;
- /// # Ok(())
- /// # }
- /// ```
- pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
- self.io.get_ref().set_recv_buffer_size(size)
- }
-
- /// Gets the value of the `SO_SNDBUF` option on this socket.
- ///
- /// For more information about this option, see [`set_send_buffer_size`].
- ///
- /// [`set_send_buffer_size`]: TcpStream::set_send_buffer_size
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// println!("{:?}", stream.send_buffer_size()?);
- /// # Ok(())
- /// # }
- /// ```
- pub fn send_buffer_size(&self) -> io::Result<usize> {
- self.io.get_ref().send_buffer_size()
- }
-
- /// Sets the value of the `SO_SNDBUF` option on this socket.
- ///
- /// Changes the size of the operating system's send buffer associated with
- /// the socket.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// stream.set_send_buffer_size(100)?;
- /// # Ok(())
- /// # }
- /// ```
- pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
- self.io.get_ref().set_send_buffer_size(size)
- }
-
- /// Returns whether keepalive messages are enabled on this socket, and if so
- /// the duration of time between them.
- ///
- /// For more information about this option, see [`set_keepalive`].
- ///
- /// [`set_keepalive`]: TcpStream::set_keepalive
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// println!("{:?}", stream.keepalive()?);
- /// # Ok(())
- /// # }
- /// ```
- pub fn keepalive(&self) -> io::Result<Option<Duration>> {
- self.io.get_ref().keepalive()
- }
-
- /// Sets whether keepalive messages are enabled to be sent on this socket.
- ///
- /// On Unix, this option will set the `SO_KEEPALIVE` as well as the
- /// `TCP_KEEPALIVE` or `TCP_KEEPIDLE` option (depending on your platform).
- /// On Windows, this will set the `SIO_KEEPALIVE_VALS` option.
- ///
- /// If `None` is specified then keepalive messages are disabled, otherwise
- /// the duration specified will be the time to remain idle before sending a
- /// TCP keepalive probe.
- ///
- /// Some platforms specify this value in seconds, so sub-second
- /// specifications may be omitted.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// stream.set_keepalive(None)?;
- /// # Ok(())
- /// # }
- /// ```
- pub fn set_keepalive(&self, keepalive: Option<Duration>) -> io::Result<()> {
- self.io.get_ref().set_keepalive(keepalive)
- }
-
/// Gets the value of the `IP_TTL` option for this socket.
///
/// For more information about this option, see [`set_ttl`].
@@ -608,57 +421,9 @@ impl TcpStream {
self.io.get_ref().set_ttl(ttl)
}
- /// Reads the linger duration for this socket by getting the `SO_LINGER`
- /// option.
- ///
- /// For more information about this option, see [`set_linger`].
- ///
- /// [`set_linger`]: TcpStream::set_linger
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// println!("{:?}", stream.linger()?);
- /// # Ok(())
- /// # }
- /// ```
- pub fn linger(&self) -> io::Result<Option<Duration>> {
- self.io.get_ref().linger()
- }
-
- /// Sets the linger duration of this socket by setting the `SO_LINGER`
- /// option.
- ///
- /// This option controls the action taken when a stream has unsent messages
- /// and the stream is closed. If `SO_LINGER` is set, the system
- /// shall block the process until it can transmit the data or until the
- /// time expires.
- ///
- /// If `SO_LINGER` is not specified, and the stream is closed, the system
- /// handles the call in a way that allows the process to continue as quickly
- /// as possible.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::TcpStream;
- ///
- /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> {
- /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
- ///
- /// stream.set_linger(None)?;
- /// # Ok(())
- /// # }
- /// ```
- pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
- self.io.get_ref().set_linger(dur)
- }
-
+ // These lifetime markers also appear in the generated documentation, and make
+ // it more clear that this is a *borrowed* split.
+ #[allow(clippy::needless_lifetimes)]
/// Splits a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
///
@@ -666,7 +431,7 @@ impl TcpStream {
/// moved into independently spawned tasks.
///
/// [`into_split`]: TcpStream::into_split()
- pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
+ pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) {
split(self)
}
@@ -676,10 +441,11 @@ impl TcpStream {
/// Unlike [`split`], the owned halves can be moved to separate tasks, however
/// this comes at the cost of a heap allocation.
///
- /// **Note::** Dropping the write half will shutdown the write half of the TCP
- /// stream. This is equivalent to calling `shutdown(Write)` on the `TcpStream`.
+ /// **Note:** Dropping the write half will shut down the write half of the TCP
+ /// stream. This is equivalent to calling [`shutdown(Write)`] on the `TcpStream`.
///
/// [`split`]: TcpStream::split()
+ /// [`shutdown(Write)`]: fn@crate::net::TcpStream::shutdown
pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
split_owned(self)
}
@@ -698,16 +464,30 @@ impl TcpStream {
pub(crate) fn poll_read_priv(
&self,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().read(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ loop {
+ let ev = ready!(self.io.poll_read_ready(cx))?;
+
+ // Safety: `TcpStream::read` will not peek at the maybe uinitialized bytes.
+ let b = unsafe {
+ &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8])
+ };
+ match self.io.get_ref().read(b) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ Ok(n) => {
+ // Safety: We trust `TcpStream::read` to have filled up `n` bytes
+ // in the buffer.
+ unsafe {
+ buf.assume_init(n);
+ }
+ buf.advance(n);
+ return Poll::Ready(Ok(()));
+ }
+ Err(e) => return Poll::Ready(Err(e)),
}
- x => Poll::Ready(x),
}
}
@@ -716,143 +496,27 @@ impl TcpStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_write_ready(cx))?;
-
- match self.io.get_ref().write(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
+ loop {
+ let ev = ready!(self.io.poll_write_ready(cx))?;
+
+ match self.io.get_ref().write(buf) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ x => return Poll::Ready(x),
}
- x => Poll::Ready(x),
- }
- }
-
- pub(super) fn poll_write_buf_priv<B: Buf>(
- &self,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- use std::io::IoSlice;
-
- ready!(self.io.poll_write_ready(cx))?;
-
- // The `IoVec` (v0.1.x) type can't have a zero-length size, so create
- // a dummy version from a 1-length slice which we'll overwrite with
- // the `bytes_vectored` method.
- static S: &[u8] = &[0];
- const MAX_BUFS: usize = 64;
-
- // IoSlice isn't Copy, so we must expand this manually ;_;
- let mut slices: [IoSlice<'_>; MAX_BUFS] = [
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- IoSlice::new(S),
- ];
- let cnt = buf.bytes_vectored(&mut slices);
-
- let iovec = <&IoVec>::from(S);
- let mut vecs = [iovec; MAX_BUFS];
- for i in 0..cnt {
- vecs[i] = (*slices[i]).into();
- }
-
- match self.io.get_ref().write_bufs(&vecs[..cnt]) {
- Ok(n) => {
- buf.advance(n);
- Poll::Ready(Ok(n))
- }
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
- }
- Err(e) => Poll::Ready(Err(e)),
}
}
}
-impl TryFrom<TcpStream> for mio::net::TcpStream {
- type Error = io::Error;
-
- /// Consumes value, returning the mio I/O object.
- ///
- /// See [`PollEvented::into_inner`] for more details about
- /// resource deregistration that happens during the call.
- ///
- /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
- fn try_from(value: TcpStream) -> Result<Self, Self::Error> {
- value.io.into_inner()
- }
-}
-
-impl TryFrom<net::TcpStream> for TcpStream {
+impl TryFrom<std::net::TcpStream> for TcpStream {
type Error = io::Error;
/// Consumes stream, returning the tokio I/O object.
///
/// This is equivalent to
/// [`TcpStream::from_std(stream)`](TcpStream::from_std).
- fn try_from(stream: net::TcpStream) -> Result<Self, Self::Error> {
+ fn try_from(stream: std::net::TcpStream) -> Result<Self, Self::Error> {
Self::from_std(stream)
}
}
@@ -860,15 +524,11 @@ impl TryFrom<net::TcpStream> for TcpStream {
// ===== impl Read / Write =====
impl AsyncRead for TcpStream {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.poll_read_priv(cx, buf)
}
}
@@ -882,14 +542,6 @@ impl AsyncWrite for TcpStream {
self.poll_write_priv(cx, buf)
}
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- self.poll_write_buf_priv(cx, buf)
- }
-
#[inline]
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
// tcp flush is a no-op
@@ -922,14 +574,12 @@ mod sys {
#[cfg(windows)]
mod sys {
- // TODO: let's land these upstream with mio and then we can add them here.
- //
- // use std::os::windows::prelude::*;
- // use super::TcpStream;
- //
- // impl AsRawHandle for TcpStream {
- // fn as_raw_handle(&self) -> RawHandle {
- // self.io.get_ref().as_raw_handle()
- // }
- // }
+ use super::TcpStream;
+ use std::os::windows::prelude::*;
+
+ impl AsRawSocket for TcpStream {
+ fn as_raw_socket(&self) -> RawSocket {
+ self.io.get_ref().as_raw_socket()
+ }
+ }
}
diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs
index d43121a..c9bb0f8 100644
--- a/src/net/udp/mod.rs
+++ b/src/net/udp/mod.rs
@@ -1,7 +1,3 @@
//! UDP utility types.
pub(crate) mod socket;
-pub(crate) use socket::UdpSocket;
-
-mod split;
-pub use split::{RecvHalf, ReuniteError, SendHalf};
diff --git a/src/net/udp/socket.rs b/src/net/udp/socket.rs
index 97090a2..d13e92b 100644
--- a/src/net/udp/socket.rs
+++ b/src/net/udp/socket.rs
@@ -1,16 +1,110 @@
-use crate::future::poll_fn;
use crate::io::PollEvented;
-use crate::net::udp::split::{split, RecvHalf, SendHalf};
-use crate::net::ToSocketAddrs;
+use crate::net::{to_socket_addrs, ToSocketAddrs};
use std::convert::TryFrom;
use std::fmt;
use std::io;
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
-use std::task::{Context, Poll};
-cfg_udp! {
+cfg_net! {
/// A UDP socket
+ ///
+ /// UDP is "connectionless", unlike TCP. Meaning, regardless of what address you've bound to, a `UdpSocket`
+ /// is free to communicate with many different remotes. In tokio there are basically two main ways to use `UdpSocket`:
+ ///
+ /// * one to many: [`bind`](`UdpSocket::bind`) and use [`send_to`](`UdpSocket::send_to`)
+ /// and [`recv_from`](`UdpSocket::recv_from`) to communicate with many different addresses
+ /// * one to one: [`connect`](`UdpSocket::connect`) and associate with a single address, using [`send`](`UdpSocket::send`)
+ /// and [`recv`](`UdpSocket::recv`) to communicate only with that remote address
+ ///
+ /// `UdpSocket` can also be used concurrently to `send_to` and `recv_from` in different tasks,
+ /// all that's required is that you `Arc<UdpSocket>` and clone a reference for each task.
+ ///
+ /// # Streams
+ ///
+ /// If you need to listen over UDP and produce a [`Stream`](`crate::stream::Stream`), you can look
+ /// at [`UdpFramed`].
+ ///
+ /// [`UdpFramed`]: https://docs.rs/tokio-util/latest/tokio_util/udp/struct.UdpFramed.html
+ ///
+ /// # Example: one to many (bind)
+ ///
+ /// Using `bind` we can create a simple echo server that sends and recv's with many different clients:
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080").await?;
+ /// let mut buf = [0; 1024];
+ /// loop {
+ /// let (len, addr) = sock.recv_from(&mut buf).await?;
+ /// println!("{:?} bytes received from {:?}", len, addr);
+ ///
+ /// let len = sock.send_to(&buf[..len], addr).await?;
+ /// println!("{:?} bytes sent", len);
+ /// }
+ /// }
+ /// ```
+ ///
+ /// # Example: one to one (connect)
+ ///
+ /// Or using `connect` we can echo with a single remote address using `send` and `recv`:
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080").await?;
+ ///
+ /// let remote_addr = "127.0.0.1:59611";
+ /// sock.connect(remote_addr).await?;
+ /// let mut buf = [0; 1024];
+ /// loop {
+ /// let len = sock.recv(&mut buf).await?;
+ /// println!("{:?} bytes received from {:?}", len, remote_addr);
+ ///
+ /// let len = sock.send(&buf[..len]).await?;
+ /// println!("{:?} bytes sent", len);
+ /// }
+ /// }
+ /// ```
+ ///
+ /// # Example: Sending/Receiving concurrently
+ ///
+ /// Because `send_to` and `recv_from` take `&self`. It's perfectly alright to `Arc<UdpSocket>`
+ /// and share the references to multiple tasks, in order to send/receive concurrently. Here is
+ /// a similar "echo" example but that supports concurrent sending/receiving:
+ ///
+ /// ```no_run
+ /// use tokio::{net::UdpSocket, sync::mpsc};
+ /// use std::{io, net::SocketAddr, sync::Arc};
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?;
+ /// let r = Arc::new(sock);
+ /// let s = r.clone();
+ /// let (tx, mut rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(1_000);
+ ///
+ /// tokio::spawn(async move {
+ /// while let Some((bytes, addr)) = rx.recv().await {
+ /// let len = s.send_to(&bytes, &addr).await.unwrap();
+ /// println!("{:?} bytes sent", len);
+ /// }
+ /// });
+ ///
+ /// let mut buf = [0; 1024];
+ /// loop {
+ /// let (len, addr) = r.recv_from(&mut buf).await?;
+ /// println!("{:?} bytes received from {:?}", len, addr);
+ /// tx.send((buf[..len].to_vec(), addr)).await.unwrap();
+ /// }
+ /// }
+ /// ```
+ ///
pub struct UdpSocket {
io: PollEvented<mio::net::UdpSocket>,
}
@@ -19,8 +113,23 @@ cfg_udp! {
impl UdpSocket {
/// This function will create a new UDP socket and attempt to bind it to
/// the `addr` provided.
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080").await?;
+ /// // use `sock`
+ /// # let _ = sock;
+ /// Ok(())
+ /// }
+ /// ```
pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
- let addrs = addr.to_socket_addrs().await?;
+ let addrs = to_socket_addrs(addr).await?;
let mut last_err = None;
for addr in addrs {
@@ -39,7 +148,7 @@ impl UdpSocket {
}
fn bind_addr(addr: SocketAddr) -> io::Result<UdpSocket> {
- let sys = mio::net::UdpSocket::bind(&addr)?;
+ let sys = mio::net::UdpSocket::bind(addr)?;
UdpSocket::new(sys)
}
@@ -64,21 +173,45 @@ impl UdpSocket {
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::{io, net::SocketAddr};
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> io::Result<()> {
+ /// let addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
+ /// let std_sock = std::net::UdpSocket::bind(addr)?;
+ /// let sock = UdpSocket::from_std(std_sock)?;
+ /// // use `sock`
+ /// # Ok(())
+ /// # }
+ /// ```
pub fn from_std(socket: net::UdpSocket) -> io::Result<UdpSocket> {
- let io = mio::net::UdpSocket::from_socket(socket)?;
- let io = PollEvented::new(io)?;
- Ok(UdpSocket { io })
- }
-
- /// Splits the `UdpSocket` into a receive half and a send half. The two parts
- /// can be used to receive and send datagrams concurrently, even from two
- /// different tasks.
- pub fn split(self) -> (RecvHalf, SendHalf) {
- split(self)
+ let io = mio::net::UdpSocket::from_std(socket);
+ UdpSocket::new(io)
}
/// Returns the local address that this socket is bound to.
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::{io, net::SocketAddr};
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> io::Result<()> {
+ /// let addr = "0.0.0.0:8080".parse::<SocketAddr>().unwrap();
+ /// let sock = UdpSocket::bind(addr).await?;
+ /// // the address the socket is bound to
+ /// let local_addr = sock.local_addr()?;
+ /// # Ok(())
+ /// # }
+ /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io.get_ref().local_addr()
}
@@ -86,8 +219,29 @@ impl UdpSocket {
/// Connects the UDP socket setting the default destination for send() and
/// limiting packets that are read via recv from the address specified in
/// `addr`.
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::{io, net::SocketAddr};
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?;
+ ///
+ /// let remote_addr = "127.0.0.1:59600".parse::<SocketAddr>().unwrap();
+ /// sock.connect(remote_addr).await?;
+ /// let mut buf = [0u8; 32];
+ /// // recv from remote_addr
+ /// let len = sock.recv(&mut buf).await?;
+ /// // send to remote_addr
+ /// let _len = sock.send(&buf[..len]).await?;
+ /// # Ok(())
+ /// # }
+ /// ```
pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
- let addrs = addr.to_socket_addrs().await?;
+ let addrs = to_socket_addrs(addr).await?;
let mut last_err = None;
for addr in addrs {
@@ -112,31 +266,24 @@ impl UdpSocket {
/// will resolve to an error if the socket is not connected.
///
/// [`connect`]: method@Self::connect
- pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
- poll_fn(|cx| self.poll_send(cx, buf)).await
- }
-
- // Poll IO functions that takes `&self` are provided for the split API.
- //
- // They are not public because (taken from the doc of `PollEvented`):
- //
- // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
- // caller must ensure that there are at most two tasks that use a
- // `PollEvented` instance concurrently. One for reading and one for writing.
- // While violating this requirement is "safe" from a Rust memory model point
- // of view, it will result in unexpected behavior in the form of lost
- // notifications and tasks hanging.
- #[doc(hidden)]
- pub fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_write_ready(cx))?;
-
- match self.io.get_ref().send(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
+ pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
+ self.io
+ .async_io(mio::Interest::WRITABLE, |sock| sock.send(buf))
+ .await
+ }
+
+ /// Try to send data on the socket to the remote address to which it is
+ /// connected.
+ ///
+ /// # Returns
+ ///
+ /// If successfull, the number of bytes sent is returned. Users
+ /// should ensure that when the remote cannot receive, the
+ /// [`ErrorKind::WouldBlock`] is properly handled.
+ ///
+ /// [`ErrorKind::WouldBlock`]: std::io::ErrorKind::WouldBlock
+ pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> {
+ self.io.get_ref().send(buf)
}
/// Returns a future that receives a single datagram message on the socket from
@@ -151,21 +298,10 @@ impl UdpSocket {
/// will fail if the socket is not connected.
///
/// [`connect`]: method@Self::connect
- pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- poll_fn(|cx| self.poll_recv(cx, buf)).await
- }
-
- #[doc(hidden)]
- pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().recv(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
+ pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
+ self.io
+ .async_io(mio::Interest::READABLE, |sock| sock.recv(buf))
+ .await
}
/// Returns a future that sends data on the socket to the given address.
@@ -173,11 +309,27 @@ impl UdpSocket {
///
/// The future will resolve to an error if the IP version of the socket does
/// not match that of `target`.
- pub async fn send_to<A: ToSocketAddrs>(&mut self, buf: &[u8], target: A) -> io::Result<usize> {
- let mut addrs = target.to_socket_addrs().await?;
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::{io, net::SocketAddr};
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?;
+ /// let buf = b"hello world";
+ /// let remote_addr = "127.0.0.1:58000".parse::<SocketAddr>().unwrap();
+ /// let _len = sock.send_to(&buf[..], remote_addr).await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> io::Result<usize> {
+ let mut addrs = to_socket_addrs(target).await?;
match addrs.next() {
- Some(target) => poll_fn(|cx| self.poll_send_to(cx, buf, &target)).await,
+ Some(target) => self.send_to_addr(buf, target).await,
None => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"no addresses to send data to",
@@ -185,23 +337,42 @@ impl UdpSocket {
}
}
- // TODO: Public or not?
- #[doc(hidden)]
- pub fn poll_send_to(
- &self,
- cx: &mut Context<'_>,
- buf: &[u8],
- target: &SocketAddr,
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_write_ready(cx))?;
-
- match self.io.get_ref().send_to(buf, target) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
+ /// Try to send data on the socket to the given address, but if the send is blocked
+ /// this will return right away.
+ ///
+ /// # Returns
+ ///
+ /// If successfull, returns the number of bytes sent
+ ///
+ /// Users should ensure that when the remote cannot receive, the
+ /// [`ErrorKind::WouldBlock`] is properly handled. An error can also occur
+ /// if the IP version of the socket does not match that of `target`.
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::{io, net::SocketAddr};
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?;
+ /// let buf = b"hello world";
+ /// let remote_addr = "127.0.0.1:58000".parse::<SocketAddr>().unwrap();
+ /// let _len = sock.try_send_to(&buf[..], remote_addr)?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// [`ErrorKind::WouldBlock`]: std::io::ErrorKind::WouldBlock
+ pub fn try_send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
+ self.io.get_ref().send_to(buf, target)
+ }
+
+ async fn send_to_addr(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
+ self.io
+ .async_io(mio::Interest::WRITABLE, |sock| sock.send_to(buf, target))
+ .await
}
/// Returns a future that receives a single datagram on the socket. On success,
@@ -210,25 +381,26 @@ impl UdpSocket {
/// The function must be called with valid byte array `buf` of sufficient size
/// to hold the message bytes. If a message is too long to fit in the supplied
/// buffer, excess bytes may be discarded.
- pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
- poll_fn(|cx| self.poll_recv_from(cx, buf)).await
- }
-
- #[doc(hidden)]
- pub fn poll_recv_from(
- &self,
- cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<Result<(usize, SocketAddr), io::Error>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().recv_from(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
+ ///
+ /// # Example
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::{io, net::SocketAddr};
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("0.0.0.0:8080".parse::<SocketAddr>().unwrap()).await?;
+ /// let mut buf = [0u8; 32];
+ /// let (len, addr) = sock.recv_from(&mut buf).await?;
+ /// println!("received {:?} bytes from {:?}", len, addr);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+ self.io
+ .async_io(mio::Interest::READABLE, |sock| sock.recv_from(buf))
+ .await
}
/// Gets the value of the `SO_BROADCAST` option for this socket.
@@ -315,6 +487,20 @@ impl UdpSocket {
/// For more information about this option, see [`set_ttl`].
///
/// [`set_ttl`]: method@Self::set_ttl
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::io;
+ ///
+ /// # async fn dox() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("127.0.0.1:8080").await?;
+ ///
+ /// println!("{:?}", sock.ttl()?);
+ /// # Ok(())
+ /// # }
+ /// ```
pub fn ttl(&self) -> io::Result<u32> {
self.io.get_ref().ttl()
}
@@ -323,6 +509,20 @@ impl UdpSocket {
///
/// This value sets the time-to-live field that is used in every packet sent
/// from this socket.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::net::UdpSocket;
+ /// # use std::io;
+ ///
+ /// # async fn dox() -> io::Result<()> {
+ /// let sock = UdpSocket::bind("127.0.0.1:8080").await?;
+ /// sock.set_ttl(60)?;
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.io.get_ref().set_ttl(ttl)
}
@@ -366,28 +566,14 @@ impl UdpSocket {
}
}
-impl TryFrom<UdpSocket> for mio::net::UdpSocket {
- type Error = io::Error;
-
- /// Consumes value, returning the mio I/O object.
- ///
- /// See [`PollEvented::into_inner`] for more details about
- /// resource deregistration that happens during the call.
- ///
- /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
- fn try_from(value: UdpSocket) -> Result<Self, Self::Error> {
- value.io.into_inner()
- }
-}
-
-impl TryFrom<net::UdpSocket> for UdpSocket {
+impl TryFrom<std::net::UdpSocket> for UdpSocket {
type Error = io::Error;
/// Consumes stream, returning the tokio I/O object.
///
/// This is equivalent to
/// [`UdpSocket::from_std(stream)`](UdpSocket::from_std).
- fn try_from(stream: net::UdpSocket) -> Result<Self, Self::Error> {
+ fn try_from(stream: std::net::UdpSocket) -> Result<Self, Self::Error> {
Self::from_std(stream)
}
}
@@ -412,14 +598,12 @@ mod sys {
#[cfg(windows)]
mod sys {
- // TODO: let's land these upstream with mio and then we can add them here.
- //
- // use std::os::windows::prelude::*;
- // use super::UdpSocket;
- //
- // impl AsRawHandle for UdpSocket {
- // fn as_raw_handle(&self) -> RawHandle {
- // self.io.get_ref().as_raw_handle()
- // }
- // }
+ use super::UdpSocket;
+ use std::os::windows::prelude::*;
+
+ impl AsRawSocket for UdpSocket {
+ fn as_raw_socket(&self) -> RawSocket {
+ self.io.get_ref().as_raw_socket()
+ }
+ }
}
diff --git a/src/net/udp/split.rs b/src/net/udp/split.rs
deleted file mode 100644
index e8d434a..0000000
--- a/src/net/udp/split.rs
+++ /dev/null
@@ -1,148 +0,0 @@
-//! [`UdpSocket`](crate::net::UdpSocket) split support.
-//!
-//! The [`split`](method@crate::net::UdpSocket::split) method splits a
-//! `UdpSocket` into a receive half and a send half, which can be used to
-//! receive and send datagrams concurrently, even from two different tasks.
-//!
-//! The halves provide access to the underlying socket, implementing
-//! `AsRef<UdpSocket>`. This allows you to call `UdpSocket` methods that takes
-//! `&self`, e.g., to get local address, to get and set socket options, to join
-//! or leave multicast groups, etc.
-//!
-//! The halves can be reunited to the original socket with their `reunite`
-//! methods.
-
-use crate::future::poll_fn;
-use crate::net::udp::UdpSocket;
-
-use std::error::Error;
-use std::fmt;
-use std::io;
-use std::net::SocketAddr;
-use std::sync::Arc;
-
-/// The send half after [`split`](super::UdpSocket::split).
-///
-/// Use [`send_to`](method@Self::send_to) or [`send`](method@Self::send) to send
-/// datagrams.
-#[derive(Debug)]
-pub struct SendHalf(Arc<UdpSocket>);
-
-/// The recv half after [`split`](super::UdpSocket::split).
-///
-/// Use [`recv_from`](method@Self::recv_from) or [`recv`](method@Self::recv) to receive
-/// datagrams.
-#[derive(Debug)]
-pub struct RecvHalf(Arc<UdpSocket>);
-
-pub(crate) fn split(socket: UdpSocket) -> (RecvHalf, SendHalf) {
- let shared = Arc::new(socket);
- let send = shared.clone();
- let recv = shared;
- (RecvHalf(recv), SendHalf(send))
-}
-
-/// Error indicating two halves were not from the same socket, and thus could
-/// not be `reunite`d.
-#[derive(Debug)]
-pub struct ReuniteError(pub SendHalf, pub RecvHalf);
-
-impl fmt::Display for ReuniteError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
- f,
- "tried to reunite halves that are not from the same socket"
- )
- }
-}
-
-impl Error for ReuniteError {}
-
-fn reunite(s: SendHalf, r: RecvHalf) -> Result<UdpSocket, ReuniteError> {
- if Arc::ptr_eq(&s.0, &r.0) {
- drop(r);
- // Only two instances of the `Arc` are ever created, one for the
- // receiver and one for the sender, and those `Arc`s are never exposed
- // externally. And so when we drop one here, the other one must be the
- // only remaining one.
- Ok(Arc::try_unwrap(s.0).expect("udp: try_unwrap failed in reunite"))
- } else {
- Err(ReuniteError(s, r))
- }
-}
-
-impl RecvHalf {
- /// Attempts to put the two "halves" of a `UdpSocket` back together and
- /// recover the original socket. Succeeds only if the two "halves"
- /// originated from the same call to `UdpSocket::split`.
- pub fn reunite(self, other: SendHalf) -> Result<UdpSocket, ReuniteError> {
- reunite(other, self)
- }
-
- /// Returns a future that receives a single datagram on the socket. On success,
- /// the future resolves to the number of bytes read and the origin.
- ///
- /// The function must be called with valid byte array `buf` of sufficient size
- /// to hold the message bytes. If a message is too long to fit in the supplied
- /// buffer, excess bytes may be discarded.
- pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
- poll_fn(|cx| self.0.poll_recv_from(cx, buf)).await
- }
-
- /// Returns a future that receives a single datagram message on the socket from
- /// the remote address to which it is connected. On success, the future will resolve
- /// to the number of bytes read.
- ///
- /// The function must be called with valid byte array `buf` of sufficient size to
- /// hold the message bytes. If a message is too long to fit in the supplied buffer,
- /// excess bytes may be discarded.
- ///
- /// The [`connect`] method will connect this socket to a remote address. The future
- /// will fail if the socket is not connected.
- ///
- /// [`connect`]: super::UdpSocket::connect
- pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- poll_fn(|cx| self.0.poll_recv(cx, buf)).await
- }
-}
-
-impl SendHalf {
- /// Attempts to put the two "halves" of a `UdpSocket` back together and
- /// recover the original socket. Succeeds only if the two "halves"
- /// originated from the same call to `UdpSocket::split`.
- pub fn reunite(self, other: RecvHalf) -> Result<UdpSocket, ReuniteError> {
- reunite(self, other)
- }
-
- /// Returns a future that sends data on the socket to the given address.
- /// On success, the future will resolve to the number of bytes written.
- ///
- /// The future will resolve to an error if the IP version of the socket does
- /// not match that of `target`.
- pub async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
- poll_fn(|cx| self.0.poll_send_to(cx, buf, target)).await
- }
-
- /// Returns a future that sends data on the socket to the remote address to which it is connected.
- /// On success, the future will resolve to the number of bytes written.
- ///
- /// The [`connect`] method will connect this socket to a remote address. The future
- /// will resolve to an error if the socket is not connected.
- ///
- /// [`connect`]: super::UdpSocket::connect
- pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
- poll_fn(|cx| self.0.poll_send(cx, buf)).await
- }
-}
-
-impl AsRef<UdpSocket> for SendHalf {
- fn as_ref(&self) -> &UdpSocket {
- &self.0
- }
-}
-
-impl AsRef<UdpSocket> for RecvHalf {
- fn as_ref(&self) -> &UdpSocket {
- &self.0
- }
-}
diff --git a/src/net/unix/datagram.rs b/src/net/unix/datagram.rs
deleted file mode 100644
index ff0f424..0000000
--- a/src/net/unix/datagram.rs
+++ /dev/null
@@ -1,242 +0,0 @@
-use crate::future::poll_fn;
-use crate::io::PollEvented;
-
-use std::convert::TryFrom;
-use std::fmt;
-use std::io;
-use std::net::Shutdown;
-use std::os::unix::io::{AsRawFd, RawFd};
-use std::os::unix::net::{self, SocketAddr};
-use std::path::Path;
-use std::task::{Context, Poll};
-
-cfg_uds! {
- /// An I/O object representing a Unix datagram socket.
- pub struct UnixDatagram {
- io: PollEvented<mio_uds::UnixDatagram>,
- }
-}
-
-impl UnixDatagram {
- /// Creates a new `UnixDatagram` bound to the specified path.
- pub fn bind<P>(path: P) -> io::Result<UnixDatagram>
- where
- P: AsRef<Path>,
- {
- let socket = mio_uds::UnixDatagram::bind(path)?;
- UnixDatagram::new(socket)
- }
-
- /// Creates an unnamed pair of connected sockets.
- ///
- /// This function will create a pair of interconnected Unix sockets for
- /// communicating back and forth between one another. Each socket will
- /// be associated with the default event loop's handle.
- pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> {
- let (a, b) = mio_uds::UnixDatagram::pair()?;
- let a = UnixDatagram::new(a)?;
- let b = UnixDatagram::new(b)?;
-
- Ok((a, b))
- }
-
- /// Consumes a `UnixDatagram` in the standard library and returns a
- /// nonblocking `UnixDatagram` from this crate.
- ///
- /// The returned datagram will be associated with the given event loop
- /// specified by `handle` and is ready to perform I/O.
- ///
- /// # Panics
- ///
- /// This function panics if thread-local runtime is not set.
- ///
- /// The runtime is usually set implicitly when this function is called
- /// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn from_std(datagram: net::UnixDatagram) -> io::Result<UnixDatagram> {
- let socket = mio_uds::UnixDatagram::from_datagram(datagram)?;
- let io = PollEvented::new(socket)?;
- Ok(UnixDatagram { io })
- }
-
- fn new(socket: mio_uds::UnixDatagram) -> io::Result<UnixDatagram> {
- let io = PollEvented::new(socket)?;
- Ok(UnixDatagram { io })
- }
-
- /// Creates a new `UnixDatagram` which is not bound to any address.
- pub fn unbound() -> io::Result<UnixDatagram> {
- let socket = mio_uds::UnixDatagram::unbound()?;
- UnixDatagram::new(socket)
- }
-
- /// Connects the socket to the specified address.
- ///
- /// The `send` method may be used to send data to the specified address.
- /// `recv` and `recv_from` will only receive data from that address.
- pub fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> {
- self.io.get_ref().connect(path)
- }
-
- /// Sends data on the socket to the socket's peer.
- pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
- poll_fn(|cx| self.poll_send_priv(cx, buf)).await
- }
-
- // Poll IO functions that takes `&self` are provided for the split API.
- //
- // They are not public because (taken from the doc of `PollEvented`):
- //
- // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
- // caller must ensure that there are at most two tasks that use a
- // `PollEvented` instance concurrently. One for reading and one for writing.
- // While violating this requirement is "safe" from a Rust memory model point
- // of view, it will result in unexpected behavior in the form of lost
- // notifications and tasks hanging.
- pub(crate) fn poll_send_priv(
- &self,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_write_ready(cx))?;
-
- match self.io.get_ref().send(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
- }
-
- /// Receives data from the socket.
- pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- poll_fn(|cx| self.poll_recv_priv(cx, buf)).await
- }
-
- pub(crate) fn poll_recv_priv(
- &self,
- cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().recv(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
- }
-
- /// Sends data on the socket to the specified address.
- pub async fn send_to<P>(&mut self, buf: &[u8], target: P) -> io::Result<usize>
- where
- P: AsRef<Path> + Unpin,
- {
- poll_fn(|cx| self.poll_send_to_priv(cx, buf, target.as_ref())).await
- }
-
- pub(crate) fn poll_send_to_priv(
- &self,
- cx: &mut Context<'_>,
- buf: &[u8],
- target: &Path,
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_write_ready(cx))?;
-
- match self.io.get_ref().send_to(buf, target) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
- }
-
- /// Receives data from the socket.
- pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
- poll_fn(|cx| self.poll_recv_from_priv(cx, buf)).await
- }
-
- pub(crate) fn poll_recv_from_priv(
- &self,
- cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<Result<(usize, SocketAddr), io::Error>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().recv_from(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
- }
- x => Poll::Ready(x),
- }
- }
-
- /// Returns the local address that this socket is bound to.
- pub fn local_addr(&self) -> io::Result<SocketAddr> {
- self.io.get_ref().local_addr()
- }
-
- /// Returns the address of this socket's peer.
- ///
- /// The `connect` method will connect the socket to a peer.
- pub fn peer_addr(&self) -> io::Result<SocketAddr> {
- self.io.get_ref().peer_addr()
- }
-
- /// Returns the value of the `SO_ERROR` option.
- pub fn take_error(&self) -> io::Result<Option<io::Error>> {
- self.io.get_ref().take_error()
- }
-
- /// Shuts down the read, write, or both halves of this connection.
- ///
- /// This function will cause all pending and future I/O calls on the
- /// specified portions to immediately return with an appropriate value
- /// (see the documentation of `Shutdown`).
- pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
- self.io.get_ref().shutdown(how)
- }
-}
-
-impl TryFrom<UnixDatagram> for mio_uds::UnixDatagram {
- type Error = io::Error;
-
- /// Consumes value, returning the mio I/O object.
- ///
- /// See [`PollEvented::into_inner`] for more details about
- /// resource deregistration that happens during the call.
- ///
- /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
- fn try_from(value: UnixDatagram) -> Result<Self, Self::Error> {
- value.io.into_inner()
- }
-}
-
-impl TryFrom<net::UnixDatagram> for UnixDatagram {
- type Error = io::Error;
-
- /// Consumes stream, returning the tokio I/O object.
- ///
- /// This is equivalent to
- /// [`UnixDatagram::from_std(stream)`](UnixDatagram::from_std).
- fn try_from(stream: net::UnixDatagram) -> Result<Self, Self::Error> {
- Self::from_std(stream)
- }
-}
-
-impl fmt::Debug for UnixDatagram {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.io.get_ref().fmt(f)
- }
-}
-
-impl AsRawFd for UnixDatagram {
- fn as_raw_fd(&self) -> RawFd {
- self.io.get_ref().as_raw_fd()
- }
-}
diff --git a/src/net/unix/datagram/mod.rs b/src/net/unix/datagram/mod.rs
new file mode 100644
index 0000000..6268b4a
--- /dev/null
+++ b/src/net/unix/datagram/mod.rs
@@ -0,0 +1,3 @@
+//! Unix datagram types.
+
+pub(crate) mod socket;
diff --git a/src/net/unix/datagram/socket.rs b/src/net/unix/datagram/socket.rs
new file mode 100644
index 0000000..3ae66d1
--- /dev/null
+++ b/src/net/unix/datagram/socket.rs
@@ -0,0 +1,731 @@
+use crate::io::PollEvented;
+use crate::net::unix::SocketAddr;
+
+use std::convert::TryFrom;
+use std::fmt;
+use std::io;
+use std::net::Shutdown;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net;
+use std::path::Path;
+
+cfg_net_unix! {
+ /// An I/O object representing a Unix datagram socket.
+ ///
+ /// A socket can be either named (associated with a filesystem path) or
+ /// unnamed.
+ ///
+ /// **Note:** named sockets are persisted even after the object is dropped
+ /// and the program has exited, and cannot be reconnected. It is advised
+ /// that you either check for and unlink the existing socket if it exists,
+ /// or use a temporary file that is guaranteed to not already exist.
+ ///
+ /// # Examples
+ /// Using named sockets, associated with a filesystem path:
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir()?;
+ ///
+ /// // Bind each socket to a filesystem path
+ /// let tx_path = tmp.path().join("tx");
+ /// let tx = UnixDatagram::bind(&tx_path)?;
+ /// let rx_path = tmp.path().join("rx");
+ /// let rx = UnixDatagram::bind(&rx_path)?;
+ ///
+ /// let bytes = b"hello world";
+ /// tx.send_to(bytes, &rx_path).await?;
+ ///
+ /// let mut buf = vec![0u8; 24];
+ /// let (size, addr) = rx.recv_from(&mut buf).await?;
+ ///
+ /// let dgram = &buf[..size];
+ /// assert_eq!(dgram, bytes);
+ /// assert_eq!(addr.as_pathname().unwrap(), &tx_path);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// Using unnamed sockets, created as a pair
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create the pair of sockets
+ /// let (sock1, sock2) = UnixDatagram::pair()?;
+ ///
+ /// // Since the sockets are paired, the paired send/recv
+ /// // functions can be used
+ /// let bytes = b"hello world";
+ /// sock1.send(bytes).await?;
+ ///
+ /// let mut buff = vec![0u8; 24];
+ /// let size = sock2.recv(&mut buff).await?;
+ ///
+ /// let dgram = &buff[..size];
+ /// assert_eq!(dgram, bytes);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub struct UnixDatagram {
+ io: PollEvented<mio::net::UnixDatagram>,
+ }
+}
+
+impl UnixDatagram {
+ /// Creates a new `UnixDatagram` bound to the specified path.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir()?;
+ ///
+ /// // Bind the socket to a filesystem path
+ /// let socket_path = tmp.path().join("socket");
+ /// let socket = UnixDatagram::bind(&socket_path)?;
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn bind<P>(path: P) -> io::Result<UnixDatagram>
+ where
+ P: AsRef<Path>,
+ {
+ let socket = mio::net::UnixDatagram::bind(path)?;
+ UnixDatagram::new(socket)
+ }
+
+ /// Creates an unnamed pair of connected sockets.
+ ///
+ /// This function will create a pair of interconnected Unix sockets for
+ /// communicating back and forth between one another.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create the pair of sockets
+ /// let (sock1, sock2) = UnixDatagram::pair()?;
+ ///
+ /// // Since the sockets are paired, the paired send/recv
+ /// // functions can be used
+ /// let bytes = b"hail eris";
+ /// sock1.send(bytes).await?;
+ ///
+ /// let mut buff = vec![0u8; 24];
+ /// let size = sock2.recv(&mut buff).await?;
+ ///
+ /// let dgram = &buff[..size];
+ /// assert_eq!(dgram, bytes);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> {
+ let (a, b) = mio::net::UnixDatagram::pair()?;
+ let a = UnixDatagram::new(a)?;
+ let b = UnixDatagram::new(b)?;
+
+ Ok((a, b))
+ }
+
+ /// Consumes a `UnixDatagram` in the standard library and returns a
+ /// nonblocking `UnixDatagram` from this crate.
+ ///
+ /// The returned datagram will be associated with the given event loop
+ /// specified by `handle` and is ready to perform I/O.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if thread-local runtime is not set.
+ ///
+ /// The runtime is usually set implicitly when this function is called
+ /// from a future driven by a Tokio runtime, otherwise runtime can be set
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use std::os::unix::net::UnixDatagram as StdUDS;
+ /// use tempfile::tempdir;
+ ///
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir()?;
+ ///
+ /// // Bind the socket to a filesystem path
+ /// let socket_path = tmp.path().join("socket");
+ /// let std_socket = StdUDS::bind(&socket_path)?;
+ /// let tokio_socket = UnixDatagram::from_std(std_socket)?;
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn from_std(datagram: net::UnixDatagram) -> io::Result<UnixDatagram> {
+ let socket = mio::net::UnixDatagram::from_std(datagram);
+ let io = PollEvented::new(socket)?;
+ Ok(UnixDatagram { io })
+ }
+
+ fn new(socket: mio::net::UnixDatagram) -> io::Result<UnixDatagram> {
+ let io = PollEvented::new(socket)?;
+ Ok(UnixDatagram { io })
+ }
+
+ /// Creates a new `UnixDatagram` which is not bound to any address.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // Create an unbound socket
+ /// let tx = UnixDatagram::unbound()?;
+ ///
+ /// // Create another, bound socket
+ /// let tmp = tempdir()?;
+ /// let rx_path = tmp.path().join("rx");
+ /// let rx = UnixDatagram::bind(&rx_path)?;
+ ///
+ /// // Send to the bound socket
+ /// let bytes = b"hello world";
+ /// tx.send_to(bytes, &rx_path).await?;
+ ///
+ /// let mut buf = vec![0u8; 24];
+ /// let (size, addr) = rx.recv_from(&mut buf).await?;
+ ///
+ /// let dgram = &buf[..size];
+ /// assert_eq!(dgram, bytes);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn unbound() -> io::Result<UnixDatagram> {
+ let socket = mio::net::UnixDatagram::unbound()?;
+ UnixDatagram::new(socket)
+ }
+
+ /// Connects the socket to the specified address.
+ ///
+ /// The `send` method may be used to send data to the specified address.
+ /// `recv` and `recv_from` will only receive data from that address.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // Create an unbound socket
+ /// let tx = UnixDatagram::unbound()?;
+ ///
+ /// // Create another, bound socket
+ /// let tmp = tempdir()?;
+ /// let rx_path = tmp.path().join("rx");
+ /// let rx = UnixDatagram::bind(&rx_path)?;
+ ///
+ /// // Connect to the bound socket
+ /// tx.connect(&rx_path)?;
+ ///
+ /// // Send to the bound socket
+ /// let bytes = b"hello world";
+ /// tx.send(bytes).await?;
+ ///
+ /// let mut buf = vec![0u8; 24];
+ /// let (size, addr) = rx.recv_from(&mut buf).await?;
+ ///
+ /// let dgram = &buf[..size];
+ /// assert_eq!(dgram, bytes);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> {
+ self.io.get_ref().connect(path)
+ }
+
+ /// Sends data on the socket to the socket's peer.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create the pair of sockets
+ /// let (sock1, sock2) = UnixDatagram::pair()?;
+ ///
+ /// // Since the sockets are paired, the paired send/recv
+ /// // functions can be used
+ /// let bytes = b"hello world";
+ /// sock1.send(bytes).await?;
+ ///
+ /// let mut buff = vec![0u8; 24];
+ /// let size = sock2.recv(&mut buff).await?;
+ ///
+ /// let dgram = &buff[..size];
+ /// assert_eq!(dgram, bytes);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
+ self.io
+ .async_io(mio::Interest::WRITABLE, |sock| sock.send(buf))
+ .await
+ }
+
+ /// Try to send a datagram to the peer without waiting.
+ ///
+ /// # Examples
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// let bytes = b"bytes";
+ /// // We use a socket pair so that they are assigned
+ /// // each other as a peer.
+ /// let (first, second) = UnixDatagram::pair()?;
+ ///
+ /// let size = first.try_send(bytes)?;
+ /// assert_eq!(size, bytes.len());
+ ///
+ /// let mut buffer = vec![0u8; 24];
+ /// let size = second.try_recv(&mut buffer)?;
+ ///
+ /// let dgram = &buffer[..size];
+ /// assert_eq!(dgram, bytes);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> {
+ self.io.get_ref().send(buf)
+ }
+
+ /// Try to send a datagram to the peer without waiting.
+ ///
+ /// # Examples
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// let bytes = b"bytes";
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir().unwrap();
+ ///
+ /// let server_path = tmp.path().join("server");
+ /// let server = UnixDatagram::bind(&server_path)?;
+ ///
+ /// let client_path = tmp.path().join("client");
+ /// let client = UnixDatagram::bind(&client_path)?;
+ ///
+ /// let size = client.try_send_to(bytes, &server_path)?;
+ /// assert_eq!(size, bytes.len());
+ ///
+ /// let mut buffer = vec![0u8; 24];
+ /// let (size, addr) = server.try_recv_from(&mut buffer)?;
+ ///
+ /// let dgram = &buffer[..size];
+ /// assert_eq!(dgram, bytes);
+ /// assert_eq!(addr.as_pathname().unwrap(), &client_path);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn try_send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize>
+ where
+ P: AsRef<Path>,
+ {
+ self.io.get_ref().send_to(buf, target)
+ }
+
+ /// Receives data from the socket.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create the pair of sockets
+ /// let (sock1, sock2) = UnixDatagram::pair()?;
+ ///
+ /// // Since the sockets are paired, the paired send/recv
+ /// // functions can be used
+ /// let bytes = b"hello world";
+ /// sock1.send(bytes).await?;
+ ///
+ /// let mut buff = vec![0u8; 24];
+ /// let size = sock2.recv(&mut buff).await?;
+ ///
+ /// let dgram = &buff[..size];
+ /// assert_eq!(dgram, bytes);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
+ self.io
+ .async_io(mio::Interest::READABLE, |sock| sock.recv(buf))
+ .await
+ }
+
+ /// Try to receive a datagram from the peer without waiting.
+ ///
+ /// # Examples
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// let bytes = b"bytes";
+ /// // We use a socket pair so that they are assigned
+ /// // each other as a peer.
+ /// let (first, second) = UnixDatagram::pair()?;
+ ///
+ /// let size = first.try_send(bytes)?;
+ /// assert_eq!(size, bytes.len());
+ ///
+ /// let mut buffer = vec![0u8; 24];
+ /// let size = second.try_recv(&mut buffer)?;
+ ///
+ /// let dgram = &buffer[..size];
+ /// assert_eq!(dgram, bytes);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> {
+ self.io.get_ref().recv(buf)
+ }
+
+ /// Sends data on the socket to the specified address.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir()?;
+ ///
+ /// // Bind each socket to a filesystem path
+ /// let tx_path = tmp.path().join("tx");
+ /// let tx = UnixDatagram::bind(&tx_path)?;
+ /// let rx_path = tmp.path().join("rx");
+ /// let rx = UnixDatagram::bind(&rx_path)?;
+ ///
+ /// let bytes = b"hello world";
+ /// tx.send_to(bytes, &rx_path).await?;
+ ///
+ /// let mut buf = vec![0u8; 24];
+ /// let (size, addr) = rx.recv_from(&mut buf).await?;
+ ///
+ /// let dgram = &buf[..size];
+ /// assert_eq!(dgram, bytes);
+ /// assert_eq!(addr.as_pathname().unwrap(), &tx_path);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn send_to<P>(&self, buf: &[u8], target: P) -> io::Result<usize>
+ where
+ P: AsRef<Path>,
+ {
+ self.io
+ .async_io(mio::Interest::WRITABLE, |sock| {
+ sock.send_to(buf, target.as_ref())
+ })
+ .await
+ }
+
+ /// Receives data from the socket.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir()?;
+ ///
+ /// // Bind each socket to a filesystem path
+ /// let tx_path = tmp.path().join("tx");
+ /// let tx = UnixDatagram::bind(&tx_path)?;
+ /// let rx_path = tmp.path().join("rx");
+ /// let rx = UnixDatagram::bind(&rx_path)?;
+ ///
+ /// let bytes = b"hello world";
+ /// tx.send_to(bytes, &rx_path).await?;
+ ///
+ /// let mut buf = vec![0u8; 24];
+ /// let (size, addr) = rx.recv_from(&mut buf).await?;
+ ///
+ /// let dgram = &buf[..size];
+ /// assert_eq!(dgram, bytes);
+ /// assert_eq!(addr.as_pathname().unwrap(), &tx_path);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+ let (n, addr) = self
+ .io
+ .async_io(mio::Interest::READABLE, |sock| sock.recv_from(buf))
+ .await?;
+
+ Ok((n, SocketAddr(addr)))
+ }
+
+ /// Try to receive data from the socket without waiting.
+ ///
+ /// # Examples
+ /// ```
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// let bytes = b"bytes";
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir().unwrap();
+ ///
+ /// let server_path = tmp.path().join("server");
+ /// let server = UnixDatagram::bind(&server_path)?;
+ ///
+ /// let client_path = tmp.path().join("client");
+ /// let client = UnixDatagram::bind(&client_path)?;
+ ///
+ /// let size = client.try_send_to(bytes, &server_path)?;
+ /// assert_eq!(size, bytes.len());
+ ///
+ /// let mut buffer = vec![0u8; 24];
+ /// let (size, addr) = server.try_recv_from(&mut buffer)?;
+ ///
+ /// let dgram = &buffer[..size];
+ /// assert_eq!(dgram, bytes);
+ /// assert_eq!(addr.as_pathname().unwrap(), &client_path);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn try_recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+ let (n, addr) = self.io.get_ref().recv_from(buf)?;
+ Ok((n, SocketAddr(addr)))
+ }
+
+ /// Returns the local address that this socket is bound to.
+ ///
+ /// # Examples
+ /// For a socket bound to a local path
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // We use a temporary directory so that the socket
+ /// // files left by the bound sockets will get cleaned up.
+ /// let tmp = tempdir()?;
+ ///
+ /// // Bind socket to a filesystem path
+ /// let socket_path = tmp.path().join("socket");
+ /// let socket = UnixDatagram::bind(&socket_path)?;
+ ///
+ /// assert_eq!(socket.local_addr()?.as_pathname().unwrap(), &socket_path);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// For an unbound socket
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create an unbound socket
+ /// let socket = UnixDatagram::unbound()?;
+ ///
+ /// assert!(socket.local_addr()?.is_unnamed());
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn local_addr(&self) -> io::Result<SocketAddr> {
+ self.io.get_ref().local_addr().map(SocketAddr)
+ }
+
+ /// Returns the address of this socket's peer.
+ ///
+ /// The `connect` method will connect the socket to a peer.
+ ///
+ /// # Examples
+ /// For a peer with a local path
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use tempfile::tempdir;
+ ///
+ /// // Create an unbound socket
+ /// let tx = UnixDatagram::unbound()?;
+ ///
+ /// // Create another, bound socket
+ /// let tmp = tempdir()?;
+ /// let rx_path = tmp.path().join("rx");
+ /// let rx = UnixDatagram::bind(&rx_path)?;
+ ///
+ /// // Connect to the bound socket
+ /// tx.connect(&rx_path)?;
+ ///
+ /// assert_eq!(tx.peer_addr()?.as_pathname().unwrap(), &rx_path);
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ ///
+ /// For an unbound peer
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create the pair of sockets
+ /// let (sock1, sock2) = UnixDatagram::pair()?;
+ ///
+ /// assert!(sock1.peer_addr()?.is_unnamed());
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn peer_addr(&self) -> io::Result<SocketAddr> {
+ self.io.get_ref().peer_addr().map(SocketAddr)
+ }
+
+ /// Returns the value of the `SO_ERROR` option.
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ ///
+ /// // Create an unbound socket
+ /// let socket = UnixDatagram::unbound()?;
+ ///
+ /// if let Ok(Some(err)) = socket.take_error() {
+ /// println!("Got error: {:?}", err);
+ /// }
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn take_error(&self) -> io::Result<Option<io::Error>> {
+ self.io.get_ref().take_error()
+ }
+
+ /// Shuts down the read, write, or both halves of this connection.
+ ///
+ /// This function will cause all pending and future I/O calls on the
+ /// specified portions to immediately return with an appropriate value
+ /// (see the documentation of `Shutdown`).
+ ///
+ /// # Examples
+ /// ```
+ /// # use std::error::Error;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Box<dyn Error>> {
+ /// use tokio::net::UnixDatagram;
+ /// use std::net::Shutdown;
+ ///
+ /// // Create an unbound socket
+ /// let (socket, other) = UnixDatagram::pair()?;
+ ///
+ /// socket.shutdown(Shutdown::Both)?;
+ ///
+ /// // NOTE: the following commented out code does NOT work as expected.
+ /// // Due to an underlying issue, the recv call will block indefinitely.
+ /// // See: https://github.com/tokio-rs/tokio/issues/1679
+ /// //let mut buff = vec![0u8; 24];
+ /// //let size = socket.recv(&mut buff).await?;
+ /// //assert_eq!(size, 0);
+ ///
+ /// let send_result = socket.send(b"hello world").await;
+ /// assert!(send_result.is_err());
+ ///
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
+ self.io.get_ref().shutdown(how)
+ }
+}
+
+impl TryFrom<std::os::unix::net::UnixDatagram> for UnixDatagram {
+ type Error = io::Error;
+
+ /// Consumes stream, returning the Tokio I/O object.
+ ///
+ /// This is equivalent to
+ /// [`UnixDatagram::from_std(stream)`](UnixDatagram::from_std).
+ fn try_from(stream: std::os::unix::net::UnixDatagram) -> Result<Self, Self::Error> {
+ Self::from_std(stream)
+ }
+}
+
+impl fmt::Debug for UnixDatagram {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.io.get_ref().fmt(f)
+ }
+}
+
+impl AsRawFd for UnixDatagram {
+ fn as_raw_fd(&self) -> RawFd {
+ self.io.get_ref().as_raw_fd()
+ }
+}
diff --git a/src/net/unix/incoming.rs b/src/net/unix/incoming.rs
deleted file mode 100644
index af49360..0000000
--- a/src/net/unix/incoming.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use crate::net::unix::{UnixListener, UnixStream};
-
-use std::io;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-/// Stream of listeners
-#[derive(Debug)]
-#[must_use = "streams do nothing unless polled"]
-pub struct Incoming<'a> {
- inner: &'a mut UnixListener,
-}
-
-impl Incoming<'_> {
- pub(crate) fn new(listener: &mut UnixListener) -> Incoming<'_> {
- Incoming { inner: listener }
- }
-
- /// Attempts to poll `UnixStream` by polling inner `UnixListener` to accept
- /// connection.
- ///
- /// If `UnixListener` isn't ready yet, `Poll::Pending` is returned and
- /// current task will be notified by a waker. Otherwise `Poll::Ready` with
- /// `Result` containing `UnixStream` will be returned.
- pub fn poll_accept(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<UnixStream>> {
- let (socket, _) = ready!(self.inner.poll_accept(cx))?;
- Poll::Ready(Ok(socket))
- }
-}
-
-#[cfg(feature = "stream")]
-impl crate::stream::Stream for Incoming<'_> {
- type Item = io::Result<UnixStream>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- let (socket, _) = ready!(self.inner.poll_accept(cx))?;
- Poll::Ready(Some(Ok(socket)))
- }
-}
diff --git a/src/net/unix/listener.rs b/src/net/unix/listener.rs
index 9b76cb0..b272645 100644
--- a/src/net/unix/listener.rs
+++ b/src/net/unix/listener.rs
@@ -1,17 +1,15 @@
-use crate::future::poll_fn;
use crate::io::PollEvented;
-use crate::net::unix::{Incoming, UnixStream};
+use crate::net::unix::{SocketAddr, UnixStream};
-use mio::Ready;
use std::convert::TryFrom;
use std::fmt;
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
-use std::os::unix::net::{self, SocketAddr};
+use std::os::unix::net;
use std::path::Path;
use std::task::{Context, Poll};
-cfg_uds! {
+cfg_net_unix! {
/// A Unix socket which can accept connections from other Unix sockets.
///
/// You can accept a new connection by using the [`accept`](`UnixListener::accept`) method. Alternatively `UnixListener`
@@ -47,7 +45,7 @@ cfg_uds! {
/// }
/// ```
pub struct UnixListener {
- io: PollEvented<mio_uds::UnixListener>,
+ io: PollEvented<mio::net::UnixListener>,
}
}
@@ -60,12 +58,12 @@ impl UnixListener {
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
pub fn bind<P>(path: P) -> io::Result<UnixListener>
where
P: AsRef<Path>,
{
- let listener = mio_uds::UnixListener::bind(path)?;
+ let listener = mio::net::UnixListener::bind(path)?;
let io = PollEvented::new(listener)?;
Ok(UnixListener { io })
}
@@ -82,16 +80,16 @@ impl UnixListener {
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
pub fn from_std(listener: net::UnixListener) -> io::Result<UnixListener> {
- let listener = mio_uds::UnixListener::from_listener(listener)?;
+ let listener = mio::net::UnixListener::from_std(listener);
let io = PollEvented::new(listener)?;
Ok(UnixListener { io })
}
/// Returns the local socket address of this listener.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
- self.io.get_ref().local_addr()
+ self.io.get_ref().local_addr().map(SocketAddr)
}
/// Returns the value of the `SO_ERROR` option.
@@ -100,117 +98,62 @@ impl UnixListener {
}
/// Accepts a new incoming connection to this listener.
- pub async fn accept(&mut self) -> io::Result<(UnixStream, SocketAddr)> {
- poll_fn(|cx| self.poll_accept(cx)).await
+ pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
+ let (mio, addr) = self
+ .io
+ .async_io(mio::Interest::READABLE, |sock| sock.accept())
+ .await?;
+
+ let addr = SocketAddr(addr);
+ let stream = UnixStream::new(mio)?;
+ Ok((stream, addr))
}
- pub(crate) fn poll_accept(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<(UnixStream, SocketAddr)>> {
- let (io, addr) = ready!(self.poll_accept_std(cx))?;
-
- let io = mio_uds::UnixStream::from_stream(io)?;
- Ok((UnixStream::new(io)?, addr)).into()
- }
-
- fn poll_accept_std(
- &mut self,
- cx: &mut Context<'_>,
- ) -> Poll<io::Result<(net::UnixStream, SocketAddr)>> {
- ready!(self.io.poll_read_ready(cx, Ready::readable()))?;
-
- match self.io.get_ref().accept_std() {
- Ok(None) => {
- self.io.clear_read_ready(cx, Ready::readable())?;
- Poll::Pending
- }
- Ok(Some((sock, addr))) => Ok((sock, addr)).into(),
- Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, Ready::readable())?;
- Poll::Pending
+ /// Polls to accept a new incoming connection to this listener.
+ ///
+ /// If there is no connection to accept, `Poll::Pending` is returned and
+ /// the current task will be notified by a waker.
+ ///
+ /// When ready, the most recent task that called `poll_accept` is notified.
+ /// The caller is responsble to ensure that `poll_accept` is called from a
+ /// single task. Failing to do this could result in tasks hanging.
+ pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(UnixStream, SocketAddr)>> {
+ loop {
+ let ev = ready!(self.io.poll_read_ready(cx))?;
+
+ match self.io.get_ref().accept() {
+ Ok((sock, addr)) => {
+ let addr = SocketAddr(addr);
+ let sock = UnixStream::new(sock)?;
+ return Poll::Ready(Ok((sock, addr)));
+ }
+ Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ Err(err) => return Err(err).into(),
}
- Err(err) => Err(err).into(),
}
}
-
- /// Returns a stream over the connections being received on this listener.
- ///
- /// Note that `UnixListener` also directly implements `Stream`.
- ///
- /// The returned stream will never return `None` and will also not yield the
- /// peer's `SocketAddr` structure. Iterating over it is equivalent to
- /// calling accept in a loop.
- ///
- /// # Errors
- ///
- /// Note that accepting a connection can lead to various errors and not all
- /// of them are necessarily fatal ‒ for example having too many open file
- /// descriptors or the other side closing the connection while it waits in
- /// an accept queue. These would terminate the stream if not handled in any
- /// way.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::net::UnixListener;
- /// use tokio::stream::StreamExt;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let mut listener = UnixListener::bind("/path/to/the/socket").unwrap();
- /// let mut incoming = listener.incoming();
- ///
- /// while let Some(stream) = incoming.next().await {
- /// match stream {
- /// Ok(stream) => {
- /// println!("new client!");
- /// }
- /// Err(e) => { /* connection failed */ }
- /// }
- /// }
- /// }
- /// ```
- pub fn incoming(&mut self) -> Incoming<'_> {
- Incoming::new(self)
- }
}
#[cfg(feature = "stream")]
impl crate::stream::Stream for UnixListener {
type Item = io::Result<UnixStream>;
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Self::Item>> {
+ fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (socket, _) = ready!(self.poll_accept(cx))?;
Poll::Ready(Some(Ok(socket)))
}
}
-impl TryFrom<UnixListener> for mio_uds::UnixListener {
- type Error = io::Error;
-
- /// Consumes value, returning the mio I/O object.
- ///
- /// See [`PollEvented::into_inner`] for more details about
- /// resource deregistration that happens during the call.
- ///
- /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
- fn try_from(value: UnixListener) -> Result<Self, Self::Error> {
- value.io.into_inner()
- }
-}
-
-impl TryFrom<net::UnixListener> for UnixListener {
+impl TryFrom<std::os::unix::net::UnixListener> for UnixListener {
type Error = io::Error;
/// Consumes stream, returning the tokio I/O object.
///
/// This is equivalent to
/// [`UnixListener::from_std(stream)`](UnixListener::from_std).
- fn try_from(stream: net::UnixListener) -> io::Result<Self> {
+ fn try_from(stream: std::os::unix::net::UnixListener) -> io::Result<Self> {
Self::from_std(stream)
}
}
diff --git a/src/net/unix/mod.rs b/src/net/unix/mod.rs
index ddba60d..19ee34a 100644
--- a/src/net/unix/mod.rs
+++ b/src/net/unix/mod.rs
@@ -1,16 +1,18 @@
//! Unix domain socket utility types
-pub(crate) mod datagram;
-
-mod incoming;
-pub use incoming::Incoming;
+pub mod datagram;
pub(crate) mod listener;
-pub(crate) use listener::UnixListener;
mod split;
pub use split::{ReadHalf, WriteHalf};
+mod split_owned;
+pub use split_owned::{OwnedReadHalf, OwnedWriteHalf, ReuniteError};
+
+mod socketaddr;
+pub use socketaddr::SocketAddr;
+
pub(crate) mod stream;
pub(crate) use stream::UnixStream;
diff --git a/src/net/unix/socketaddr.rs b/src/net/unix/socketaddr.rs
new file mode 100644
index 0000000..48f7b96
--- /dev/null
+++ b/src/net/unix/socketaddr.rs
@@ -0,0 +1,31 @@
+use std::fmt;
+use std::path::Path;
+
+/// An address associated with a Tokio Unix socket.
+pub struct SocketAddr(pub(super) mio::net::SocketAddr);
+
+impl SocketAddr {
+ /// Returns `true` if the address is unnamed.
+ ///
+ /// Documentation reflected in [`SocketAddr`]
+ ///
+ /// [`SocketAddr`]: std::os::unix::net::SocketAddr
+ pub fn is_unnamed(&self) -> bool {
+ self.0.is_unnamed()
+ }
+
+ /// Returns the contents of this address if it is a `pathname` address.
+ ///
+ /// Documentation reflected in [`SocketAddr`]
+ ///
+ /// [`SocketAddr`]: std::os::unix::net::SocketAddr
+ pub fn as_pathname(&self) -> Option<&Path> {
+ self.0.as_pathname()
+ }
+}
+
+impl fmt::Debug for SocketAddr {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(fmt)
+ }
+}
diff --git a/src/net/unix/split.rs b/src/net/unix/split.rs
index 9b9fa5e..460bbc1 100644
--- a/src/net/unix/split.rs
+++ b/src/net/unix/split.rs
@@ -8,20 +8,40 @@
//! split has no associated overhead and enforces all invariants at the type
//! level.
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::UnixStream;
use std::io;
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::pin::Pin;
use std::task::{Context, Poll};
-/// Read half of a `UnixStream`.
+/// Borrowed read half of a [`UnixStream`], created by [`split`].
+///
+/// Reading from a `ReadHalf` is usually done using the convenience methods found on the
+/// [`AsyncReadExt`] trait. Examples import this trait through [the prelude].
+///
+/// [`UnixStream`]: UnixStream
+/// [`split`]: UnixStream::split()
+/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
+/// [the prelude]: crate::prelude
#[derive(Debug)]
pub struct ReadHalf<'a>(&'a UnixStream);
-/// Write half of a `UnixStream`.
+/// Borrowed write half of a [`UnixStream`], created by [`split`].
+///
+/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will
+/// shut down the UnixStream stream in the write direction.
+///
+/// Writing to an `WriteHalf` is usually done using the convenience methods found
+/// on the [`AsyncWriteExt`] trait. Examples import this trait through [the prelude].
+///
+/// [`UnixStream`]: UnixStream
+/// [`split`]: UnixStream::split()
+/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
+/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
+/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
+/// [the prelude]: crate::prelude
#[derive(Debug)]
pub struct WriteHalf<'a>(&'a UnixStream);
@@ -30,15 +50,11 @@ pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
}
impl AsyncRead for ReadHalf<'_> {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.0.poll_read_priv(cx, buf)
}
}
diff --git a/src/net/unix/split_owned.rs b/src/net/unix/split_owned.rs
new file mode 100644
index 0000000..ab23307
--- /dev/null
+++ b/src/net/unix/split_owned.rs
@@ -0,0 +1,182 @@
+//! `UnixStream` owned split support.
+//!
+//! A `UnixStream` can be split into an `OwnedReadHalf` and a `OwnedWriteHalf`
+//! with the `UnixStream::into_split` method. `OwnedReadHalf` implements
+//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
+//!
+//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
+//! split has no associated overhead and enforces all invariants at the type
+//! level.
+
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
+use crate::net::UnixStream;
+
+use std::error::Error;
+use std::net::Shutdown;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::{fmt, io};
+
+/// Owned read half of a [`UnixStream`], created by [`into_split`].
+///
+/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
+/// on the [`AsyncReadExt`] trait. Examples import this trait through [the prelude].
+///
+/// [`UnixStream`]: crate::net::UnixStream
+/// [`into_split`]: crate::net::UnixStream::into_split()
+/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
+/// [the prelude]: crate::prelude
+#[derive(Debug)]
+pub struct OwnedReadHalf {
+ inner: Arc<UnixStream>,
+}
+
+/// Owned write half of a [`UnixStream`], created by [`into_split`].
+///
+/// Note that in the [`AsyncWrite`] implementation of this type,
+/// [`poll_shutdown`] will shut down the stream in the write direction.
+/// Dropping the write half will also shut down the write half of the stream.
+///
+/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods
+/// found on the [`AsyncWriteExt`] trait. Examples import this trait through
+/// [the prelude].
+///
+/// [`UnixStream`]: crate::net::UnixStream
+/// [`into_split`]: crate::net::UnixStream::into_split()
+/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
+/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
+/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
+/// [the prelude]: crate::prelude
+#[derive(Debug)]
+pub struct OwnedWriteHalf {
+ inner: Arc<UnixStream>,
+ shutdown_on_drop: bool,
+}
+
+pub(crate) fn split_owned(stream: UnixStream) -> (OwnedReadHalf, OwnedWriteHalf) {
+ let arc = Arc::new(stream);
+ let read = OwnedReadHalf {
+ inner: Arc::clone(&arc),
+ };
+ let write = OwnedWriteHalf {
+ inner: arc,
+ shutdown_on_drop: true,
+ };
+ (read, write)
+}
+
+pub(crate) fn reunite(
+ read: OwnedReadHalf,
+ write: OwnedWriteHalf,
+) -> Result<UnixStream, ReuniteError> {
+ if Arc::ptr_eq(&read.inner, &write.inner) {
+ write.forget();
+ // This unwrap cannot fail as the api does not allow creating more than two Arcs,
+ // and we just dropped the other half.
+ Ok(Arc::try_unwrap(read.inner).expect("UnixStream: try_unwrap failed in reunite"))
+ } else {
+ Err(ReuniteError(read, write))
+ }
+}
+
+/// Error indicating that two halves were not from the same socket, and thus could
+/// not be reunited.
+#[derive(Debug)]
+pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
+
+impl fmt::Display for ReuniteError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "tried to reunite halves that are not from the same socket"
+ )
+ }
+}
+
+impl Error for ReuniteError {}
+
+impl OwnedReadHalf {
+ /// Attempts to put the two halves of a `UnixStream` back together and
+ /// recover the original socket. Succeeds only if the two halves
+ /// originated from the same call to [`into_split`].
+ ///
+ /// [`into_split`]: crate::net::UnixStream::into_split()
+ pub fn reunite(self, other: OwnedWriteHalf) -> Result<UnixStream, ReuniteError> {
+ reunite(self, other)
+ }
+}
+
+impl AsyncRead for OwnedReadHalf {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ self.inner.poll_read_priv(cx, buf)
+ }
+}
+
+impl OwnedWriteHalf {
+ /// Attempts to put the two halves of a `UnixStream` back together and
+ /// recover the original socket. Succeeds only if the two halves
+ /// originated from the same call to [`into_split`].
+ ///
+ /// [`into_split`]: crate::net::UnixStream::into_split()
+ pub fn reunite(self, other: OwnedReadHalf) -> Result<UnixStream, ReuniteError> {
+ reunite(other, self)
+ }
+
+ /// Destroy the write half, but don't close the write half of the stream
+ /// until the read half is dropped. If the read half has already been
+ /// dropped, this closes the stream.
+ pub fn forget(mut self) {
+ self.shutdown_on_drop = false;
+ drop(self);
+ }
+}
+
+impl Drop for OwnedWriteHalf {
+ fn drop(&mut self) {
+ if self.shutdown_on_drop {
+ let _ = self.inner.shutdown(Shutdown::Write);
+ }
+ }
+}
+
+impl AsyncWrite for OwnedWriteHalf {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self.inner.poll_write_priv(cx, buf)
+ }
+
+ #[inline]
+ fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+ // flush is a no-op
+ Poll::Ready(Ok(()))
+ }
+
+ // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
+ fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
+ let res = self.inner.shutdown(Shutdown::Write);
+ if res.is_ok() {
+ Pin::into_inner(self).shutdown_on_drop = false;
+ }
+ res.into()
+ }
+}
+
+impl AsRef<UnixStream> for OwnedReadHalf {
+ fn as_ref(&self) -> &UnixStream {
+ &*self.inner
+ }
+}
+
+impl AsRef<UnixStream> for OwnedWriteHalf {
+ fn as_ref(&self) -> &UnixStream {
+ &*self.inner
+ }
+}
diff --git a/src/net/unix/stream.rs b/src/net/unix/stream.rs
index beae699..5138077 100644
--- a/src/net/unix/stream.rs
+++ b/src/net/unix/stream.rs
@@ -1,27 +1,28 @@
use crate::future::poll_fn;
-use crate::io::{AsyncRead, AsyncWrite, PollEvented};
+use crate::io::{AsyncRead, AsyncWrite, PollEvented, ReadBuf};
use crate::net::unix::split::{split, ReadHalf, WriteHalf};
+use crate::net::unix::split_owned::{split_owned, OwnedReadHalf, OwnedWriteHalf};
use crate::net::unix::ucred::{self, UCred};
+use crate::net::unix::SocketAddr;
use std::convert::TryFrom;
use std::fmt;
use std::io::{self, Read, Write};
-use std::mem::MaybeUninit;
use std::net::Shutdown;
use std::os::unix::io::{AsRawFd, RawFd};
-use std::os::unix::net::{self, SocketAddr};
+use std::os::unix::net;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_uds! {
+cfg_net_unix! {
/// A structure representing a connected Unix socket.
///
/// This socket can be connected directly with `UnixStream::connect` or accepted
/// from a listener with `UnixListener::incoming`. Additionally, a pair of
/// anonymous Unix sockets can be created with `UnixStream::pair`.
pub struct UnixStream {
- io: PollEvented<mio_uds::UnixStream>,
+ io: PollEvented<mio::net::UnixStream>,
}
}
@@ -35,7 +36,7 @@ impl UnixStream {
where
P: AsRef<Path>,
{
- let stream = mio_uds::UnixStream::connect(path)?;
+ let stream = mio::net::UnixStream::connect(path)?;
let stream = UnixStream::new(stream)?;
poll_fn(|cx| stream.io.poll_write_ready(cx)).await?;
@@ -54,9 +55,9 @@ impl UnixStream {
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
pub fn from_std(stream: net::UnixStream) -> io::Result<UnixStream> {
- let stream = mio_uds::UnixStream::from_stream(stream)?;
+ let stream = mio::net::UnixStream::from_std(stream);
let io = PollEvented::new(stream)?;
Ok(UnixStream { io })
@@ -68,26 +69,26 @@ impl UnixStream {
/// communicating back and forth between one another. Each socket will
/// be associated with the default event loop's handle.
pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
- let (a, b) = mio_uds::UnixStream::pair()?;
+ let (a, b) = mio::net::UnixStream::pair()?;
let a = UnixStream::new(a)?;
let b = UnixStream::new(b)?;
Ok((a, b))
}
- pub(crate) fn new(stream: mio_uds::UnixStream) -> io::Result<UnixStream> {
+ pub(crate) fn new(stream: mio::net::UnixStream) -> io::Result<UnixStream> {
let io = PollEvented::new(stream)?;
Ok(UnixStream { io })
}
/// Returns the socket address of the local half of this connection.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
- self.io.get_ref().local_addr()
+ self.io.get_ref().local_addr().map(SocketAddr)
}
/// Returns the socket address of the remote half of this connection.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
- self.io.get_ref().peer_addr()
+ self.io.get_ref().peer_addr().map(SocketAddr)
}
/// Returns effective credentials of the process which called `connect` or `pair`.
@@ -109,24 +110,33 @@ impl UnixStream {
self.io.get_ref().shutdown(how)
}
+ // These lifetime markers also appear in the generated documentation, and make
+ // it more clear that this is a *borrowed* split.
+ #[allow(clippy::needless_lifetimes)]
/// Split a `UnixStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
- pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
+ ///
+ /// This method is more efficient than [`into_split`], but the halves cannot be
+ /// moved into independently spawned tasks.
+ ///
+ /// [`into_split`]: Self::into_split()
+ pub fn split<'a>(&'a mut self) -> (ReadHalf<'a>, WriteHalf<'a>) {
split(self)
}
-}
-impl TryFrom<UnixStream> for mio_uds::UnixStream {
- type Error = io::Error;
-
- /// Consumes value, returning the mio I/O object.
+ /// Splits a `UnixStream` into a read half and a write half, which can be used
+ /// to read and write the stream concurrently.
+ ///
+ /// Unlike [`split`], the owned halves can be moved to separate tasks, however
+ /// this comes at the cost of a heap allocation.
///
- /// See [`PollEvented::into_inner`] for more details about
- /// resource deregistration that happens during the call.
+ /// **Note:** Dropping the write half will shut down the write half of the
+ /// stream. This is equivalent to calling [`shutdown(Write)`] on the `UnixStream`.
///
- /// [`PollEvented::into_inner`]: crate::io::PollEvented::into_inner
- fn try_from(value: UnixStream) -> Result<Self, Self::Error> {
- value.io.into_inner()
+ /// [`split`]: Self::split()
+ /// [`shutdown(Write)`]: fn@Self::shutdown
+ pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
+ split_owned(self)
}
}
@@ -143,15 +153,11 @@ impl TryFrom<net::UnixStream> for UnixStream {
}
impl AsyncRead for UnixStream {
- unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.poll_read_priv(cx, buf)
}
}
@@ -190,16 +196,30 @@ impl UnixStream {
pub(crate) fn poll_read_priv(
&self,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
-
- match self.io.get_ref().read(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_read_ready(cx, mio::Ready::readable())?;
- Poll::Pending
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ loop {
+ let ev = ready!(self.io.poll_read_ready(cx))?;
+
+ // Safety: `UnixStream::read` will not peek at the maybe uinitialized bytes.
+ let b = unsafe {
+ &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8])
+ };
+ match self.io.get_ref().read(b) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ Ok(n) => {
+ // Safety: We trust `UnixStream::read` to have filled up `n` bytes
+ // in the buffer.
+ unsafe {
+ buf.assume_init(n);
+ }
+ buf.advance(n);
+ return Poll::Ready(Ok(()));
+ }
+ Err(e) => return Poll::Ready(Err(e)),
}
- x => Poll::Ready(x),
}
}
@@ -208,14 +228,15 @@ impl UnixStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
- ready!(self.io.poll_write_ready(cx))?;
-
- match self.io.get_ref().write(buf) {
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- self.io.clear_write_ready(cx)?;
- Poll::Pending
+ loop {
+ let ev = ready!(self.io.poll_write_ready(cx))?;
+
+ match self.io.get_ref().write(buf) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.clear_readiness(ev);
+ }
+ x => return Poll::Ready(x),
}
- x => Poll::Ready(x),
}
}
}
diff --git a/src/net/unix/ucred.rs b/src/net/unix/ucred.rs
index 466aedc..ef214a7 100644
--- a/src/net/unix/ucred.rs
+++ b/src/net/unix/ucred.rs
@@ -4,9 +4,21 @@ use libc::{gid_t, uid_t};
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct UCred {
/// UID (user ID) of the process
- pub uid: uid_t,
+ uid: uid_t,
/// GID (group ID) of the process
- pub gid: gid_t,
+ gid: gid_t,
+}
+
+impl UCred {
+ /// Gets UID (user ID) of the process.
+ pub fn uid(&self) -> uid_t {
+ self.uid
+ }
+
+ /// Gets GID (group ID) of the process.
+ pub fn gid(&self) -> gid_t {
+ self.gid
+ }
}
#[cfg(any(target_os = "linux", target_os = "android"))]
diff --git a/src/park/either.rs b/src/park/either.rs
index 67f1e17..ee02ec1 100644
--- a/src/park/either.rs
+++ b/src/park/either.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "full"), allow(dead_code))]
+
use crate::park::{Park, Unpark};
use std::fmt;
@@ -36,6 +38,13 @@ where
Either::B(b) => b.park_timeout(duration).map_err(Either::B),
}
}
+
+ fn shutdown(&mut self) {
+ match self {
+ Either::A(a) => a.shutdown(),
+ Either::B(b) => b.shutdown(),
+ }
+ }
}
impl<A, B> Unpark for Either<A, B>
diff --git a/src/park/mod.rs b/src/park/mod.rs
index 04d3051..5db26ce 100644
--- a/src/park/mod.rs
+++ b/src/park/mod.rs
@@ -34,17 +34,12 @@
//! * `park_timeout` does the same as `park` but allows specifying a maximum
//! time to block the thread for.
-cfg_resource_drivers! {
- mod either;
- pub(crate) use self::either::Either;
+cfg_rt! {
+ pub(crate) mod either;
}
-mod thread;
-pub(crate) use self::thread::ParkThread;
-
-cfg_block_on! {
- pub(crate) use self::thread::{CachedParkThread, ParkError};
-}
+#[cfg(any(feature = "rt", feature = "sync"))]
+pub(crate) mod thread;
use std::sync::Arc;
use std::time::Duration;
@@ -88,6 +83,9 @@ pub(crate) trait Park {
/// an implementation detail. Refer to the documentation for the specific
/// `Park` implementation
fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>;
+
+ /// Release all resources holded by the parker for proper leak-free shutdown
+ fn shutdown(&mut self);
}
/// Unblock a thread blocked by the associated `Park` instance.
diff --git a/src/park/thread.rs b/src/park/thread.rs
index 2e2397c..2725e45 100644
--- a/src/park/thread.rs
+++ b/src/park/thread.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "full"), allow(dead_code))]
+
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::park::{Park, Unpark};
@@ -65,6 +67,10 @@ impl Park for ParkThread {
self.inner.park_timeout(duration);
Ok(())
}
+
+ fn shutdown(&mut self) {
+ self.inner.shutdown();
+ }
}
// ==== impl Inner ====
@@ -83,7 +89,7 @@ impl Inner {
}
// Otherwise we need to coordinate going to sleep
- let mut m = self.mutex.lock().unwrap();
+ let mut m = self.mutex.lock();
match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
Ok(_) => {}
@@ -133,7 +139,7 @@ impl Inner {
return;
}
- let m = self.mutex.lock().unwrap();
+ let m = self.mutex.lock();
match self.state.compare_exchange(EMPTY, PARKED, SeqCst, SeqCst) {
Ok(_) => {}
@@ -184,10 +190,14 @@ impl Inner {
// Releasing `lock` before the call to `notify_one` means that when the
// parked thread wakes it doesn't get woken only to have to wait for us
// to release `lock`.
- drop(self.mutex.lock().unwrap());
+ drop(self.mutex.lock());
self.condvar.notify_one()
}
+
+ fn shutdown(&self) {
+ self.condvar.notify_all();
+ }
}
impl Default for ParkThread {
@@ -204,114 +214,133 @@ impl Unpark for UnparkThread {
}
}
-cfg_block_on! {
- use std::marker::PhantomData;
- use std::rc::Rc;
-
- use std::mem;
- use std::task::{RawWaker, RawWakerVTable, Waker};
+use std::future::Future;
+use std::marker::PhantomData;
+use std::mem;
+use std::rc::Rc;
+use std::task::{RawWaker, RawWakerVTable, Waker};
- /// Blocks the current thread using a condition variable.
- #[derive(Debug)]
- pub(crate) struct CachedParkThread {
- _anchor: PhantomData<Rc<()>>,
- }
-
- impl CachedParkThread {
- /// Create a new `ParkThread` handle for the current thread.
- ///
- /// This type cannot be moved to other threads, so it should be created on
- /// the thread that the caller intends to park.
- pub(crate) fn new() -> CachedParkThread {
- CachedParkThread {
- _anchor: PhantomData,
- }
- }
-
- pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> {
- self.with_current(|park_thread| park_thread.unpark())
- }
+/// Blocks the current thread using a condition variable.
+#[derive(Debug)]
+pub(crate) struct CachedParkThread {
+ _anchor: PhantomData<Rc<()>>,
+}
- /// Get a reference to the `ParkThread` handle for this thread.
- fn with_current<F, R>(&self, f: F) -> Result<R, ParkError>
- where
- F: FnOnce(&ParkThread) -> R,
- {
- CURRENT_PARKER.try_with(|inner| f(inner))
- .map_err(|_| ())
+impl CachedParkThread {
+ /// Create a new `ParkThread` handle for the current thread.
+ ///
+ /// This type cannot be moved to other threads, so it should be created on
+ /// the thread that the caller intends to park.
+ pub(crate) fn new() -> CachedParkThread {
+ CachedParkThread {
+ _anchor: PhantomData,
}
}
- impl Park for CachedParkThread {
- type Unpark = UnparkThread;
- type Error = ParkError;
+ pub(crate) fn get_unpark(&self) -> Result<UnparkThread, ParkError> {
+ self.with_current(|park_thread| park_thread.unpark())
+ }
- fn unpark(&self) -> Self::Unpark {
- self.get_unpark().unwrap()
- }
+ /// Get a reference to the `ParkThread` handle for this thread.
+ fn with_current<F, R>(&self, f: F) -> Result<R, ParkError>
+ where
+ F: FnOnce(&ParkThread) -> R,
+ {
+ CURRENT_PARKER.try_with(|inner| f(inner)).map_err(|_| ())
+ }
- fn park(&mut self) -> Result<(), Self::Error> {
- self.with_current(|park_thread| park_thread.inner.park())?;
- Ok(())
- }
+ pub(crate) fn block_on<F: Future>(&mut self, f: F) -> Result<F::Output, ParkError> {
+ use std::task::Context;
+ use std::task::Poll::Ready;
- fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
- self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?;
- Ok(())
- }
- }
+ // `get_unpark()` should not return a Result
+ let waker = self.get_unpark()?.into_waker();
+ let mut cx = Context::from_waker(&waker);
+ pin!(f);
- impl UnparkThread {
- pub(crate) fn into_waker(self) -> Waker {
- unsafe {
- let raw = unparker_to_raw_waker(self.inner);
- Waker::from_raw(raw)
+ loop {
+ if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
+ return Ok(v);
}
+
+ self.park()?;
}
}
+}
- impl Inner {
- #[allow(clippy::wrong_self_convention)]
- fn into_raw(this: Arc<Inner>) -> *const () {
- Arc::into_raw(this) as *const ()
- }
+impl Park for CachedParkThread {
+ type Unpark = UnparkThread;
+ type Error = ParkError;
- unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
- Arc::from_raw(ptr as *const Inner)
- }
+ fn unpark(&self) -> Self::Unpark {
+ self.get_unpark().unwrap()
}
- unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
- RawWaker::new(
- Inner::into_raw(unparker),
- &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
- )
+ fn park(&mut self) -> Result<(), Self::Error> {
+ self.with_current(|park_thread| park_thread.inner.park())?;
+ Ok(())
}
- unsafe fn clone(raw: *const ()) -> RawWaker {
- let unparker = Inner::from_raw(raw);
+ fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
+ self.with_current(|park_thread| park_thread.inner.park_timeout(duration))?;
+ Ok(())
+ }
- // Increment the ref count
- mem::forget(unparker.clone());
+ fn shutdown(&mut self) {
+ let _ = self.with_current(|park_thread| park_thread.inner.shutdown());
+ }
+}
- unparker_to_raw_waker(unparker)
+impl UnparkThread {
+ pub(crate) fn into_waker(self) -> Waker {
+ unsafe {
+ let raw = unparker_to_raw_waker(self.inner);
+ Waker::from_raw(raw)
+ }
}
+}
- unsafe fn drop_waker(raw: *const ()) {
- let _ = Inner::from_raw(raw);
+impl Inner {
+ #[allow(clippy::wrong_self_convention)]
+ fn into_raw(this: Arc<Inner>) -> *const () {
+ Arc::into_raw(this) as *const ()
}
- unsafe fn wake(raw: *const ()) {
- let unparker = Inner::from_raw(raw);
- unparker.unpark();
+ unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
+ Arc::from_raw(ptr as *const Inner)
}
+}
- unsafe fn wake_by_ref(raw: *const ()) {
- let unparker = Inner::from_raw(raw);
- unparker.unpark();
+unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
+ RawWaker::new(
+ Inner::into_raw(unparker),
+ &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
+ )
+}
- // We don't actually own a reference to the unparker
- mem::forget(unparker);
- }
+unsafe fn clone(raw: *const ()) -> RawWaker {
+ let unparker = Inner::from_raw(raw);
+
+ // Increment the ref count
+ mem::forget(unparker.clone());
+
+ unparker_to_raw_waker(unparker)
+}
+
+unsafe fn drop_waker(raw: *const ()) {
+ let _ = Inner::from_raw(raw);
+}
+
+unsafe fn wake(raw: *const ()) {
+ let unparker = Inner::from_raw(raw);
+ unparker.unpark();
+}
+
+unsafe fn wake_by_ref(raw: *const ()) {
+ let unparker = Inner::from_raw(raw);
+ unparker.unpark();
+
+ // We don't actually own a reference to the unparker
+ mem::forget(unparker);
}
diff --git a/src/process/mod.rs b/src/process/mod.rs
index e04a435..ad64371 100644
--- a/src/process/mod.rs
+++ b/src/process/mod.rs
@@ -18,16 +18,15 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! // The usage is the same as with the standard library's `Command` type, however the value
-//! // returned from `spawn` is a `Result` containing a `Future`.
-//! let child = Command::new("echo").arg("hello").arg("world")
-//! .spawn();
+//! // The usage is similar as with the standard library's `Command` type
+//! let mut child = Command::new("echo")
+//! .arg("hello")
+//! .arg("world")
+//! .spawn()
+//! .expect("failed to spawn");
//!
-//! // Make sure our child succeeded in spawning and process the result
-//! let future = child.expect("failed to spawn");
-//!
-//! // Await until the future (and the command) completes
-//! let status = future.await?;
+//! // Await until the command completes
+//! let status = child.wait().await?;
//! println!("the command exited with: {}", status);
//! Ok(())
//! }
@@ -83,8 +82,8 @@
//!
//! // Ensure the child process is spawned in the runtime so it can
//! // make progress on its own while we await for any output.
-//! tokio::spawn(async {
-//! let status = child.await
+//! tokio::spawn(async move {
+//! let status = child.wait().await
//! .expect("child process encountered an error");
//!
//! println!("child status was: {}", status);
@@ -100,27 +99,52 @@
//!
//! # Caveats
//!
+//! ## Dropping/Cancellation
+//!
//! Similar to the behavior to the standard library, and unlike the futures
//! paradigm of dropping-implies-cancellation, a spawned process will, by
//! default, continue to execute even after the `Child` handle has been dropped.
//!
-//! The `Command::kill_on_drop` method can be used to modify this behavior
+//! The [`Command::kill_on_drop`] method can be used to modify this behavior
//! and kill the child process if the `Child` wrapper is dropped before it
//! has exited.
//!
+//! ## Unix Processes
+//!
+//! On Unix platforms processes must be "reaped" by their parent process after
+//! they have exited in order to release all OS resources. A child process which
+//! has exited, but has not yet been reaped by its parent is considered a "zombie"
+//! process. Such processes continue to count against limits imposed by the system,
+//! and having too many zombie processes present can prevent additional processes
+//! from being spawned.
+//!
+//! The tokio runtime will, on a best-effort basis, attempt to reap and clean up
+//! any process which it has spawned. No additional guarantees are made with regards
+//! how quickly or how often this procedure will take place.
+//!
+//! It is recommended to avoid dropping a [`Child`] process handle before it has been
+//! fully `await`ed if stricter cleanup guarantees are required.
+//!
//! [`Command`]: crate::process::Command
+//! [`Command::kill_on_drop`]: crate::process::Command::kill_on_drop
+//! [`Child`]: crate::process::Child
#[path = "unix/mod.rs"]
#[cfg(unix)]
mod imp;
+#[cfg(unix)]
+pub(crate) mod unix {
+ pub(crate) use super::imp::*;
+}
+
#[path = "windows.rs"]
#[cfg(windows)]
mod imp;
mod kill;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::process::kill::Kill;
use std::ffi::OsStr;
@@ -450,6 +474,26 @@ impl Command {
/// By default, this value is assumed to be `false`, meaning the next spawned
/// process will not be killed on drop, similar to the behavior of the standard
/// library.
+ ///
+ /// # Caveats
+ ///
+ /// On Unix platforms processes must be "reaped" by their parent process after
+ /// they have exited in order to release all OS resources. A child process which
+ /// has exited, but has not yet been reaped by its parent is considered a "zombie"
+ /// process. Such processes continue to count against limits imposed by the system,
+ /// and having too many zombie processes present can prevent additional processes
+ /// from being spawned.
+ ///
+ /// Although issuing a `kill` signal to the child process is a synchronous
+ /// operation, the resulting zombie process cannot be `.await`ed inside of the
+ /// destructor to avoid blocking other tasks. The tokio runtime will, on a
+ /// best-effort basis, attempt to reap and clean up such processes in the
+ /// background, but makes no additional guarantees are made with regards
+ /// how quickly or how often this procedure will take place.
+ ///
+ /// If stronger guarantees are required, it is recommended to avoid dropping
+ /// a [`Child`] handle where possible, and instead utilize `child.wait().await`
+ /// or `child.kill().await` where possible.
pub fn kill_on_drop(&mut self, kill_on_drop: bool) -> &mut Command {
self.kill_on_drop = kill_on_drop;
self
@@ -534,16 +578,6 @@ impl Command {
/// All I/O this child does will be associated with the current default
/// event loop.
///
- /// # Caveats
- ///
- /// Similar to the behavior to the standard library, and unlike the futures
- /// paradigm of dropping-implies-cancellation, the spawned process will, by
- /// default, continue to execute even after the `Child` handle has been dropped.
- ///
- /// The `Command::kill_on_drop` method can be used to modify this behavior
- /// and kill the child process if the `Child` wrapper is dropped before it
- /// has exited.
- ///
/// # Examples
///
/// Basic usage:
@@ -555,16 +589,55 @@ impl Command {
/// Command::new("ls")
/// .spawn()
/// .expect("ls command failed to start")
+ /// .wait()
/// .await
/// .expect("ls command failed to run")
/// }
/// ```
+ ///
+ /// # Caveats
+ ///
+ /// ## Dropping/Cancellation
+ ///
+ /// Similar to the behavior to the standard library, and unlike the futures
+ /// paradigm of dropping-implies-cancellation, a spawned process will, by
+ /// default, continue to execute even after the `Child` handle has been dropped.
+ ///
+ /// The [`Command::kill_on_drop`] method can be used to modify this behavior
+ /// and kill the child process if the `Child` wrapper is dropped before it
+ /// has exited.
+ ///
+ /// ## Unix Processes
+ ///
+ /// On Unix platforms processes must be "reaped" by their parent process after
+ /// they have exited in order to release all OS resources. A child process which
+ /// has exited, but has not yet been reaped by its parent is considered a "zombie"
+ /// process. Such processes continue to count against limits imposed by the system,
+ /// and having too many zombie processes present can prevent additional processes
+ /// from being spawned.
+ ///
+ /// The tokio runtime will, on a best-effort basis, attempt to reap and clean up
+ /// any process which it has spawned. No additional guarantees are made with regards
+ /// how quickly or how often this procedure will take place.
+ ///
+ /// It is recommended to avoid dropping a [`Child`] process handle before it has been
+ /// fully `await`ed if stricter cleanup guarantees are required.
+ ///
+ /// [`Command`]: crate::process::Command
+ /// [`Command::kill_on_drop`]: crate::process::Command::kill_on_drop
+ /// [`Child`]: crate::process::Child
+ ///
+ /// # Errors
+ ///
+ /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock`
+ /// if the system process limit is reached (which includes other applications
+ /// running on the system).
pub fn spawn(&mut self) -> io::Result<Child> {
imp::spawn_child(&mut self.std).map(|spawned_child| Child {
- child: ChildDropGuard {
+ child: FusedChild::Child(ChildDropGuard {
inner: spawned_child.child,
kill_on_drop: self.kill_on_drop,
- },
+ }),
stdin: spawned_child.stdin.map(|inner| ChildStdin { inner }),
stdout: spawned_child.stdout.map(|inner| ChildStdout { inner }),
stderr: spawned_child.stderr.map(|inner| ChildStderr { inner }),
@@ -581,14 +654,20 @@ impl Command {
/// All I/O this child does will be associated with the current default
/// event loop.
///
- /// If this future is dropped before the future resolves, then
- /// the child will be killed, if it was spawned.
+ /// The destructor of the future returned by this function will kill
+ /// the child if [`kill_on_drop`] is set to true.
+ ///
+ /// [`kill_on_drop`]: fn@Self::kill_on_drop
///
/// # Errors
///
/// This future will return an error if the child process cannot be spawned
/// or if there is an error while awaiting its status.
///
+ /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock`
+ /// if the system process limit is reached (which includes other applications
+ /// running on the system).
+ ///
/// # Examples
///
/// Basic usage:
@@ -602,6 +681,7 @@ impl Command {
/// .await
/// .expect("ls command failed to run")
/// }
+ /// ```
pub fn status(&mut self) -> impl Future<Output = io::Result<ExitStatus>> {
let child = self.spawn();
@@ -615,7 +695,7 @@ impl Command {
child.stdout.take();
child.stderr.take();
- child.await
+ child.wait().await
}
}
@@ -637,9 +717,19 @@ impl Command {
/// All I/O this child does will be associated with the current default
/// event loop.
///
- /// If this future is dropped before the future resolves, then
- /// the child will be killed, if it was spawned.
+ /// The destructor of the future returned by this function will kill
+ /// the child if [`kill_on_drop`] is set to true.
+ ///
+ /// [`kill_on_drop`]: fn@Self::kill_on_drop
+ ///
+ /// # Errors
+ ///
+ /// This future will return an error if the child process cannot be spawned
+ /// or if there is an error while awaiting its status.
///
+ /// On Unix platforms this method will fail with `std::io::ErrorKind::WouldBlock`
+ /// if the system process limit is reached (which includes other applications
+ /// running on the system).
/// # Examples
///
/// Basic usage:
@@ -654,6 +744,7 @@ impl Command {
/// .expect("ls command failed to run");
/// println!("stderr of ls: {:?}", output.stderr);
/// }
+ /// ```
pub fn output(&mut self) -> impl Future<Output = io::Result<Output>> {
self.std.stdout(Stdio::piped());
self.std.stderr(Stdio::piped());
@@ -725,12 +816,16 @@ where
}
}
+/// Keeps track of the exit status of a child process without worrying about
+/// polling the underlying futures even after they have completed.
+#[derive(Debug)]
+enum FusedChild {
+ Child(ChildDropGuard<imp::Child>),
+ Done(ExitStatus),
+}
+
/// Representation of a child process spawned onto an event loop.
///
-/// This type is also a future which will yield the `ExitStatus` of the
-/// underlying child process. A `Child` here also provides access to information
-/// like the OS-assigned identifier and the stdio streams.
-///
/// # Caveats
/// Similar to the behavior to the standard library, and unlike the futures
/// paradigm of dropping-implies-cancellation, a spawned process will, by
@@ -739,10 +834,9 @@ where
/// The `Command::kill_on_drop` method can be used to modify this behavior
/// and kill the child process if the `Child` wrapper is dropped before it
/// has exited.
-#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Child {
- child: ChildDropGuard<imp::Child>,
+ child: FusedChild,
/// The handle for writing to the child's standard input (stdin), if it has
/// been captured.
@@ -758,34 +852,120 @@ pub struct Child {
}
impl Child {
- /// Returns the OS-assigned process identifier associated with this child.
- pub fn id(&self) -> u32 {
- self.child.inner.id()
+ /// Returns the OS-assigned process identifier associated with this child
+ /// while it is still running.
+ ///
+ /// Once the child has been polled to completion this will return `None`.
+ /// This is done to avoid confusion on platforms like Unix where the OS
+ /// identifier could be reused once the process has completed.
+ pub fn id(&self) -> Option<u32> {
+ match &self.child {
+ FusedChild::Child(child) => Some(child.inner.id()),
+ FusedChild::Done(_) => None,
+ }
+ }
+
+ /// Attempts to force the child to exit, but does not wait for the request
+ /// to take effect.
+ ///
+ /// On Unix platforms, this is the equivalent to sending a SIGKILL. Note
+ /// that on Unix platforms it is possible for a zombie process to remain
+ /// after a kill is sent; to avoid this, the caller should ensure that either
+ /// `child.wait().await` or `child.try_wait()` is invoked successfully.
+ pub fn start_kill(&mut self) -> io::Result<()> {
+ match &mut self.child {
+ FusedChild::Child(child) => child.kill(),
+ FusedChild::Done(_) => Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "invalid argument: can't kill an exited process",
+ )),
+ }
}
/// Forces the child to exit.
///
/// This is equivalent to sending a SIGKILL on unix platforms.
- pub fn kill(&mut self) -> io::Result<()> {
- self.child.kill()
- }
-
- #[doc(hidden)]
- #[deprecated(note = "please use `child.stdin` instead")]
- pub fn stdin(&mut self) -> &mut Option<ChildStdin> {
- &mut self.stdin
+ ///
+ /// If the child has to be killed remotely, it is possible to do it using
+ /// a combination of the select! macro and a oneshot channel. In the following
+ /// example, the child will run until completion unless a message is sent on
+ /// the oneshot channel. If that happens, the child is killed immediately
+ /// using the `.kill()` method.
+ ///
+ /// ```no_run
+ /// use tokio::process::Command;
+ /// use tokio::sync::oneshot::channel;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (send, recv) = channel::<()>();
+ /// let mut child = Command::new("sleep").arg("1").spawn().unwrap();
+ /// tokio::spawn(async move { send.send(()) });
+ /// tokio::select! {
+ /// _ = child.wait() => {}
+ /// _ = recv => child.kill().await.expect("kill failed"),
+ /// }
+ /// }
+ /// ```
+ pub async fn kill(&mut self) -> io::Result<()> {
+ self.start_kill()?;
+ self.wait().await?;
+ Ok(())
}
- #[doc(hidden)]
- #[deprecated(note = "please use `child.stdout` instead")]
- pub fn stdout(&mut self) -> &mut Option<ChildStdout> {
- &mut self.stdout
+ /// Waits for the child to exit completely, returning the status that it
+ /// exited with. This function will continue to have the same return value
+ /// after it has been called at least once.
+ ///
+ /// The stdin handle to the child process, if any, will be closed
+ /// before waiting. This helps avoid deadlock: it ensures that the
+ /// child does not block waiting for input from the parent, while
+ /// the parent waits for the child to exit.
+ pub async fn wait(&mut self) -> io::Result<ExitStatus> {
+ match &mut self.child {
+ FusedChild::Done(exit) => Ok(*exit),
+ FusedChild::Child(child) => {
+ let ret = child.await;
+
+ if let Ok(exit) = ret {
+ self.child = FusedChild::Done(exit);
+ }
+
+ ret
+ }
+ }
}
- #[doc(hidden)]
- #[deprecated(note = "please use `child.stderr` instead")]
- pub fn stderr(&mut self) -> &mut Option<ChildStderr> {
- &mut self.stderr
+ /// Attempts to collect the exit status of the child if it has already
+ /// exited.
+ ///
+ /// This function will not block the calling thread and will only
+ /// check to see if the child process has exited or not. If the child has
+ /// exited then on Unix the process ID is reaped. This function is
+ /// guaranteed to repeatedly return a successful exit status so long as the
+ /// child has already exited.
+ ///
+ /// If the child has exited, then `Ok(Some(status))` is returned. If the
+ /// exit status is not available at this time then `Ok(None)` is returned.
+ /// If an error occurs, then that error is returned.
+ ///
+ /// Note that unlike `wait`, this function will not attempt to drop stdin,
+ /// nor will it wake the current task if the child exits.
+ pub fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
+ match &mut self.child {
+ FusedChild::Done(exit) => Ok(Some(*exit)),
+ FusedChild::Child(guard) => {
+ let ret = guard.inner.try_wait();
+
+ if let Ok(Some(exit)) = ret {
+ // Avoid the overhead of trying to kill a reaped process
+ guard.kill_on_drop = false;
+ self.child = FusedChild::Done(exit);
+ }
+
+ ret
+ }
+ }
}
/// Returns a future that will resolve to an `Output`, containing the exit
@@ -819,7 +999,7 @@ impl Child {
let stdout_fut = read_to_end(self.stdout.take());
let stderr_fut = read_to_end(self.stderr.take());
- let (status, stdout, stderr) = try_join3(self, stdout_fut, stderr_fut).await?;
+ let (status, stdout, stderr) = try_join3(self.wait(), stdout_fut, stderr_fut).await?;
Ok(Output {
status,
@@ -829,14 +1009,6 @@ impl Child {
}
}
-impl Future for Child {
- type Output = io::Result<ExitStatus>;
-
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- Pin::new(&mut self.child).poll(cx)
- }
-}
-
/// The standard input stream for spawned children.
///
/// This type implements the `AsyncWrite` trait to pass data to the stdin handle of
@@ -883,31 +1055,21 @@ impl AsyncWrite for ChildStdin {
}
impl AsyncRead for ChildStdout {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L314
- false
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl AsyncRead for ChildStderr {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/process.rs#L375
- false
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
diff --git a/src/process/unix/driver.rs b/src/process/unix/driver.rs
new file mode 100644
index 0000000..9a16cad
--- /dev/null
+++ b/src/process/unix/driver.rs
@@ -0,0 +1,156 @@
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
+//! Process driver
+
+use crate::park::Park;
+use crate::process::unix::orphan::ReapOrphanQueue;
+use crate::process::unix::GlobalOrphanQueue;
+use crate::signal::unix::driver::Driver as SignalDriver;
+use crate::signal::unix::{signal_with_handle, InternalStream, Signal, SignalKind};
+use crate::sync::mpsc::error::TryRecvError;
+
+use std::io;
+use std::time::Duration;
+
+/// Responsible for cleaning up orphaned child processes on Unix platforms.
+#[derive(Debug)]
+pub(crate) struct Driver {
+ park: SignalDriver,
+ inner: CoreDriver<Signal, GlobalOrphanQueue>,
+}
+
+#[derive(Debug)]
+struct CoreDriver<S, Q> {
+ sigchild: S,
+ orphan_queue: Q,
+}
+
+// ===== impl CoreDriver =====
+
+impl<S, Q> CoreDriver<S, Q>
+where
+ S: InternalStream,
+ Q: ReapOrphanQueue,
+{
+ fn got_signal(&mut self) -> bool {
+ match self.sigchild.try_recv() {
+ Ok(()) => true,
+ Err(TryRecvError::Empty) => false,
+ Err(TryRecvError::Closed) => panic!("signal was deregistered"),
+ }
+ }
+
+ fn process(&mut self) {
+ if self.got_signal() {
+ // Drain all notifications which may have been buffered
+ // so we can try to reap all orphans in one batch
+ while self.got_signal() {}
+
+ self.orphan_queue.reap_orphans();
+ }
+ }
+}
+
+// ===== impl Driver =====
+
+impl Driver {
+ /// Creates a new signal `Driver` instance that delegates wakeups to `park`.
+ pub(crate) fn new(park: SignalDriver) -> io::Result<Self> {
+ let sigchild = signal_with_handle(SignalKind::child(), park.handle())?;
+ let inner = CoreDriver {
+ sigchild,
+ orphan_queue: GlobalOrphanQueue,
+ };
+
+ Ok(Self { park, inner })
+ }
+}
+
+// ===== impl Park for Driver =====
+
+impl Park for Driver {
+ type Unpark = <SignalDriver as Park>::Unpark;
+ type Error = io::Error;
+
+ fn unpark(&self) -> Self::Unpark {
+ self.park.unpark()
+ }
+
+ fn park(&mut self) -> Result<(), Self::Error> {
+ self.park.park()?;
+ self.inner.process();
+ Ok(())
+ }
+
+ fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
+ self.park.park_timeout(duration)?;
+ self.inner.process();
+ Ok(())
+ }
+
+ fn shutdown(&mut self) {
+ self.park.shutdown()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::process::unix::orphan::test::MockQueue;
+ use crate::sync::mpsc::error::TryRecvError;
+ use std::task::{Context, Poll};
+
+ struct MockStream {
+ total_try_recv: usize,
+ values: Vec<Option<()>>,
+ }
+
+ impl MockStream {
+ fn new(values: Vec<Option<()>>) -> Self {
+ Self {
+ total_try_recv: 0,
+ values,
+ }
+ }
+ }
+
+ impl InternalStream for MockStream {
+ fn poll_recv(&mut self, _cx: &mut Context<'_>) -> Poll<Option<()>> {
+ unimplemented!();
+ }
+
+ fn try_recv(&mut self) -> Result<(), TryRecvError> {
+ self.total_try_recv += 1;
+ match self.values.remove(0) {
+ Some(()) => Ok(()),
+ None => Err(TryRecvError::Empty),
+ }
+ }
+ }
+
+ #[test]
+ fn no_reap_if_no_signal() {
+ let mut driver = CoreDriver {
+ sigchild: MockStream::new(vec![None]),
+ orphan_queue: MockQueue::<()>::new(),
+ };
+
+ driver.process();
+
+ assert_eq!(1, driver.sigchild.total_try_recv);
+ assert_eq!(0, driver.orphan_queue.total_reaps.get());
+ }
+
+ #[test]
+ fn coalesce_signals_before_reaping() {
+ let mut driver = CoreDriver {
+ sigchild: MockStream::new(vec![Some(()), Some(()), None]),
+ orphan_queue: MockQueue::<()>::new(),
+ };
+
+ driver.process();
+
+ assert_eq!(3, driver.sigchild.total_try_recv);
+ assert_eq!(1, driver.orphan_queue.total_reaps.get());
+ }
+}
diff --git a/src/process/unix/mod.rs b/src/process/unix/mod.rs
index c25d989..db9d592 100644
--- a/src/process/unix/mod.rs
+++ b/src/process/unix/mod.rs
@@ -21,8 +21,10 @@
//! processes in general aren't scalable (e.g. millions) so it shouldn't be that
//! bad in theory...
-mod orphan;
-use orphan::{OrphanQueue, OrphanQueueImpl, Wait};
+pub(crate) mod driver;
+
+pub(crate) mod orphan;
+use orphan::{OrphanQueue, OrphanQueueImpl, ReapOrphanQueue, Wait};
mod reap;
use reap::Reaper;
@@ -32,19 +34,18 @@ use crate::process::kill::Kill;
use crate::process::SpawnedChild;
use crate::signal::unix::{signal, Signal, SignalKind};
-use mio::event::Evented;
-use mio::unix::{EventedFd, UnixReady};
-use mio::{Poll as MioPoll, PollOpt, Ready, Token};
+use mio::event::Source;
+use mio::unix::SourceFd;
use std::fmt;
use std::future::Future;
use std::io;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
-use std::process::ExitStatus;
+use std::process::{Child as StdChild, ExitStatus};
use std::task::Context;
use std::task::Poll;
-impl Wait for std::process::Child {
+impl Wait for StdChild {
fn id(&self) -> u32 {
self.id()
}
@@ -54,17 +55,17 @@ impl Wait for std::process::Child {
}
}
-impl Kill for std::process::Child {
+impl Kill for StdChild {
fn kill(&mut self) -> io::Result<()> {
self.kill()
}
}
lazy_static::lazy_static! {
- static ref ORPHAN_QUEUE: OrphanQueueImpl<std::process::Child> = OrphanQueueImpl::new();
+ static ref ORPHAN_QUEUE: OrphanQueueImpl<StdChild> = OrphanQueueImpl::new();
}
-struct GlobalOrphanQueue;
+pub(crate) struct GlobalOrphanQueue;
impl fmt::Debug for GlobalOrphanQueue {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -72,19 +73,21 @@ impl fmt::Debug for GlobalOrphanQueue {
}
}
-impl OrphanQueue<std::process::Child> for GlobalOrphanQueue {
- fn push_orphan(&self, orphan: std::process::Child) {
- ORPHAN_QUEUE.push_orphan(orphan)
- }
-
+impl ReapOrphanQueue for GlobalOrphanQueue {
fn reap_orphans(&self) {
ORPHAN_QUEUE.reap_orphans()
}
}
+impl OrphanQueue<StdChild> for GlobalOrphanQueue {
+ fn push_orphan(&self, orphan: StdChild) {
+ ORPHAN_QUEUE.push_orphan(orphan)
+ }
+}
+
#[must_use = "futures do nothing unless polled"]
pub(crate) struct Child {
- inner: Reaper<std::process::Child, GlobalOrphanQueue, Signal>,
+ inner: Reaper<StdChild, GlobalOrphanQueue, Signal>,
}
impl fmt::Debug for Child {
@@ -117,6 +120,10 @@ impl Child {
pub(crate) fn id(&self) -> u32 {
self.inner.id()
}
+
+ pub(crate) fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
+ self.inner.inner_mut().try_wait()
+ }
}
impl Kill for Child {
@@ -169,32 +176,30 @@ where
}
}
-impl<T> Evented for Fd<T>
+impl<T> Source for Fd<T>
where
T: AsRawFd,
{
fn register(
- &self,
- poll: &MioPoll,
- token: Token,
- interest: Ready,
- opts: PollOpt,
+ &mut self,
+ registry: &mio::Registry,
+ token: mio::Token,
+ interest: mio::Interest,
) -> io::Result<()> {
- EventedFd(&self.as_raw_fd()).register(poll, token, interest | UnixReady::hup(), opts)
+ SourceFd(&self.as_raw_fd()).register(registry, token, interest)
}
fn reregister(
- &self,
- poll: &MioPoll,
- token: Token,
- interest: Ready,
- opts: PollOpt,
+ &mut self,
+ registry: &mio::Registry,
+ token: mio::Token,
+ interest: mio::Interest,
) -> io::Result<()> {
- EventedFd(&self.as_raw_fd()).reregister(poll, token, interest | UnixReady::hup(), opts)
+ SourceFd(&self.as_raw_fd()).reregister(registry, token, interest)
}
- fn deregister(&self, poll: &MioPoll) -> io::Result<()> {
- EventedFd(&self.as_raw_fd()).deregister(poll)
+ fn deregister(&mut self, registry: &mio::Registry) -> io::Result<()> {
+ SourceFd(&self.as_raw_fd()).deregister(registry)
}
}
diff --git a/src/process/unix/orphan.rs b/src/process/unix/orphan.rs
index 6c449a9..8a1e127 100644
--- a/src/process/unix/orphan.rs
+++ b/src/process/unix/orphan.rs
@@ -20,23 +20,29 @@ impl<T: Wait> Wait for &mut T {
}
}
-/// An interface for queueing up an orphaned process so that it can be reaped.
-pub(crate) trait OrphanQueue<T> {
- /// Adds an orphan to the queue.
- fn push_orphan(&self, orphan: T);
+/// An interface for reaping a set of orphaned processes.
+pub(crate) trait ReapOrphanQueue {
/// Attempts to reap every process in the queue, ignoring any errors and
/// enqueueing any orphans which have not yet exited.
fn reap_orphans(&self);
}
+impl<T: ReapOrphanQueue> ReapOrphanQueue for &T {
+ fn reap_orphans(&self) {
+ (**self).reap_orphans()
+ }
+}
+
+/// An interface for queueing up an orphaned process so that it can be reaped.
+pub(crate) trait OrphanQueue<T>: ReapOrphanQueue {
+ /// Adds an orphan to the queue.
+ fn push_orphan(&self, orphan: T);
+}
+
impl<T, O: OrphanQueue<T>> OrphanQueue<T> for &O {
fn push_orphan(&self, orphan: T) {
(**self).push_orphan(orphan);
}
-
- fn reap_orphans(&self) {
- (**self).reap_orphans()
- }
}
/// An implementation of `OrphanQueue`.
@@ -62,42 +68,62 @@ impl<T: Wait> OrphanQueue<T> for OrphanQueueImpl<T> {
fn push_orphan(&self, orphan: T) {
self.queue.lock().unwrap().push(orphan)
}
+}
+impl<T: Wait> ReapOrphanQueue for OrphanQueueImpl<T> {
fn reap_orphans(&self) {
let mut queue = self.queue.lock().unwrap();
let queue = &mut *queue;
- let mut i = 0;
- while i < queue.len() {
+ for i in (0..queue.len()).rev() {
match queue[i].try_wait() {
- Ok(Some(_)) => {}
- Err(_) => {
- // TODO: bubble up error some how. Is this an internal bug?
- // Shoudl we panic? Is it OK for this to be silently
- // dropped?
- }
- // Still not done yet
- Ok(None) => {
- i += 1;
- continue;
+ Ok(None) => {}
+ Ok(Some(_)) | Err(_) => {
+ // The stdlib handles interruption errors (EINTR) when polling a child process.
+ // All other errors represent invalid inputs or pids that have already been
+ // reaped, so we can drop the orphan in case an error is raised.
+ queue.swap_remove(i);
}
}
-
- queue.remove(i);
}
}
}
#[cfg(all(test, not(loom)))]
-mod test {
- use super::Wait;
- use super::{OrphanQueue, OrphanQueueImpl};
- use std::cell::Cell;
+pub(crate) mod test {
+ use super::*;
+ use std::cell::{Cell, RefCell};
use std::io;
use std::os::unix::process::ExitStatusExt;
use std::process::ExitStatus;
use std::rc::Rc;
+ pub(crate) struct MockQueue<W> {
+ pub(crate) all_enqueued: RefCell<Vec<W>>,
+ pub(crate) total_reaps: Cell<usize>,
+ }
+
+ impl<W> MockQueue<W> {
+ pub(crate) fn new() -> Self {
+ Self {
+ all_enqueued: RefCell::new(Vec::new()),
+ total_reaps: Cell::new(0),
+ }
+ }
+ }
+
+ impl<W> OrphanQueue<W> for MockQueue<W> {
+ fn push_orphan(&self, orphan: W) {
+ self.all_enqueued.borrow_mut().push(orphan);
+ }
+ }
+
+ impl<W> ReapOrphanQueue for MockQueue<W> {
+ fn reap_orphans(&self) {
+ self.total_reaps.set(self.total_reaps.get() + 1);
+ }
+ }
+
struct MockWait {
total_waits: Rc<Cell<usize>>,
num_wait_until_status: usize,
diff --git a/src/process/unix/reap.rs b/src/process/unix/reap.rs
index 8963805..de483c4 100644
--- a/src/process/unix/reap.rs
+++ b/src/process/unix/reap.rs
@@ -1,6 +1,6 @@
use crate::process::imp::orphan::{OrphanQueue, Wait};
use crate::process::kill::Kill;
-use crate::signal::unix::Signal;
+use crate::signal::unix::InternalStream;
use std::future::Future;
use std::io;
@@ -23,17 +23,6 @@ where
signal: S,
}
-// Work around removal of `futures_core` dependency
-pub(crate) trait Stream: Unpin {
- fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>;
-}
-
-impl Stream for Signal {
- fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
- Signal::poll_recv(self, cx)
- }
-}
-
impl<W, Q, S> Deref for Reaper<W, Q, S>
where
W: Wait + Unpin,
@@ -63,7 +52,7 @@ where
self.inner.as_ref().expect("inner has gone away")
}
- fn inner_mut(&mut self) -> &mut W {
+ pub(crate) fn inner_mut(&mut self) -> &mut W {
self.inner.as_mut().expect("inner has gone away")
}
}
@@ -72,7 +61,7 @@ impl<W, Q, S> Future for Reaper<W, Q, S>
where
W: Wait + Unpin,
Q: OrphanQueue<W> + Unpin,
- S: Stream,
+ S: InternalStream,
{
type Output = io::Result<ExitStatus>;
@@ -80,10 +69,8 @@ where
loop {
// If the child hasn't exited yet, then it's our responsibility to
// ensure the current task gets notified when it might be able to
- // make progress.
- //
- // As described in `spawn` above, we just indicate that we can
- // next make progress once a SIGCHLD is received.
+ // make progress. We can use the delivery of a SIGCHLD signal as a
+ // sign that we can potentially make progress.
//
// However, we will register for a notification on the next signal
// BEFORE we poll the child. Otherwise it is possible that the child
@@ -99,7 +86,6 @@ where
// should not cause significant issues with parent futures.
let registered_interest = self.signal.poll_recv(cx).is_pending();
- self.orphan_queue.reap_orphans();
if let Some(status) = self.inner_mut().try_wait()? {
return Poll::Ready(Ok(status));
}
@@ -147,8 +133,9 @@ where
mod test {
use super::*;
+ use crate::process::unix::orphan::test::MockQueue;
+ use crate::sync::mpsc::error::TryRecvError;
use futures::future::FutureExt;
- use std::cell::{Cell, RefCell};
use std::os::unix::process::ExitStatusExt;
use std::process::ExitStatus;
use std::task::Context;
@@ -211,7 +198,7 @@ mod test {
}
}
- impl Stream for MockStream {
+ impl InternalStream for MockStream {
fn poll_recv(&mut self, _cx: &mut Context<'_>) -> Poll<Option<()>> {
self.total_polls += 1;
match self.values.remove(0) {
@@ -219,29 +206,9 @@ mod test {
None => Poll::Pending,
}
}
- }
- struct MockQueue<W> {
- all_enqueued: RefCell<Vec<W>>,
- total_reaps: Cell<usize>,
- }
-
- impl<W> MockQueue<W> {
- fn new() -> Self {
- Self {
- all_enqueued: RefCell::new(Vec::new()),
- total_reaps: Cell::new(0),
- }
- }
- }
-
- impl<W: Wait> OrphanQueue<W> for MockQueue<W> {
- fn push_orphan(&self, orphan: W) {
- self.all_enqueued.borrow_mut().push(orphan);
- }
-
- fn reap_orphans(&self) {
- self.total_reaps.set(self.total_reaps.get() + 1);
+ fn try_recv(&mut self) -> Result<(), TryRecvError> {
+ unimplemented!();
}
}
@@ -262,7 +229,7 @@ mod test {
assert!(grim.poll_unpin(&mut context).is_pending());
assert_eq!(1, grim.signal.total_polls);
assert_eq!(1, grim.total_waits);
- assert_eq!(1, grim.orphan_queue.total_reaps.get());
+ assert_eq!(0, grim.orphan_queue.total_reaps.get());
assert!(grim.orphan_queue.all_enqueued.borrow().is_empty());
// Not yet exited, couldn't register interest the first time
@@ -270,7 +237,7 @@ mod test {
assert!(grim.poll_unpin(&mut context).is_pending());
assert_eq!(3, grim.signal.total_polls);
assert_eq!(3, grim.total_waits);
- assert_eq!(3, grim.orphan_queue.total_reaps.get());
+ assert_eq!(0, grim.orphan_queue.total_reaps.get());
assert!(grim.orphan_queue.all_enqueued.borrow().is_empty());
// Exited
@@ -283,7 +250,7 @@ mod test {
}
assert_eq!(4, grim.signal.total_polls);
assert_eq!(4, grim.total_waits);
- assert_eq!(4, grim.orphan_queue.total_reaps.get());
+ assert_eq!(0, grim.orphan_queue.total_reaps.get());
assert!(grim.orphan_queue.all_enqueued.borrow().is_empty());
}
diff --git a/src/process/windows.rs b/src/process/windows.rs
index cbe2fa7..1aa6c89 100644
--- a/src/process/windows.rs
+++ b/src/process/windows.rs
@@ -20,24 +20,19 @@ use crate::process::kill::Kill;
use crate::process::SpawnedChild;
use crate::sync::oneshot;
-use mio_named_pipes::NamedPipe;
+use mio::windows::NamedPipe;
use std::fmt;
use std::future::Future;
use std::io;
-use std::os::windows::prelude::*;
-use std::os::windows::process::ExitStatusExt;
+use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle};
use std::pin::Pin;
use std::process::{Child as StdChild, Command as StdCommand, ExitStatus};
use std::ptr;
use std::task::Context;
use std::task::Poll;
-use winapi::shared::minwindef::FALSE;
-use winapi::shared::winerror::WAIT_TIMEOUT;
use winapi::um::handleapi::INVALID_HANDLE_VALUE;
-use winapi::um::processthreadsapi::GetExitCodeProcess;
-use winapi::um::synchapi::WaitForSingleObject;
use winapi::um::threadpoollegacyapiset::UnregisterWaitEx;
-use winapi::um::winbase::{RegisterWaitForSingleObject, INFINITE, WAIT_OBJECT_0};
+use winapi::um::winbase::{RegisterWaitForSingleObject, INFINITE};
use winapi::um::winnt::{BOOLEAN, HANDLE, PVOID, WT_EXECUTEINWAITTHREAD, WT_EXECUTEONLYONCE};
#[must_use = "futures do nothing unless polled"]
@@ -86,6 +81,10 @@ impl Child {
pub(crate) fn id(&self) -> u32 {
self.child.id()
}
+
+ pub(crate) fn try_wait(&mut self) -> io::Result<Option<ExitStatus>> {
+ self.child.try_wait()
+ }
}
impl Kill for Child {
@@ -106,11 +105,11 @@ impl Future for Child {
Poll::Ready(Err(_)) => panic!("should not be canceled"),
Poll::Pending => return Poll::Pending,
}
- let status = try_wait(&inner.child)?.expect("not ready yet");
+ let status = inner.try_wait()?.expect("not ready yet");
return Poll::Ready(Ok(status));
}
- if let Some(e) = try_wait(&inner.child)? {
+ if let Some(e) = inner.try_wait()? {
return Poll::Ready(Ok(e));
}
let (tx, rx) = oneshot::channel();
@@ -157,23 +156,6 @@ unsafe extern "system" fn callback(ptr: PVOID, _timer_fired: BOOLEAN) {
let _ = complete.take().unwrap().send(());
}
-pub(crate) fn try_wait(child: &StdChild) -> io::Result<Option<ExitStatus>> {
- unsafe {
- match WaitForSingleObject(child.as_raw_handle(), 0) {
- WAIT_OBJECT_0 => {}
- WAIT_TIMEOUT => return Ok(None),
- _ => return Err(io::Error::last_os_error()),
- }
- let mut status = 0;
- let rc = GetExitCodeProcess(child.as_raw_handle(), &mut status);
- if rc == FALSE {
- Err(io::Error::last_os_error())
- } else {
- Ok(Some(ExitStatus::from_raw(status)))
- }
- }
-}
-
pub(crate) type ChildStdin = PollEvented<NamedPipe>;
pub(crate) type ChildStdout = PollEvented<NamedPipe>;
pub(crate) type ChildStderr = PollEvented<NamedPipe>;
diff --git a/src/runtime/basic_scheduler.rs b/src/runtime/basic_scheduler.rs
index 7e1c257..5ca8467 100644
--- a/src/runtime/basic_scheduler.rs
+++ b/src/runtime/basic_scheduler.rs
@@ -1,22 +1,35 @@
+use crate::future::poll_fn;
+use crate::loom::sync::Mutex;
use crate::park::{Park, Unpark};
-use crate::runtime;
use crate::runtime::task::{self, JoinHandle, Schedule, Task};
-use crate::util::linked_list::LinkedList;
-use crate::util::{waker_ref, Wake};
+use crate::sync::notify::Notify;
+use crate::util::linked_list::{Link, LinkedList};
+use crate::util::{waker_ref, Wake, WakerRef};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
-use std::sync::{Arc, Mutex};
-use std::task::Poll::Ready;
+use std::sync::Arc;
+use std::task::Poll::{Pending, Ready};
use std::time::Duration;
/// Executes tasks on the current thread
-pub(crate) struct BasicScheduler<P>
-where
- P: Park,
-{
+pub(crate) struct BasicScheduler<P: Park> {
+ /// Inner state guarded by a mutex that is shared
+ /// between all `block_on` calls.
+ inner: Mutex<Option<Inner<P>>>,
+
+ /// Notifier for waking up other threads to steal the
+ /// parker.
+ notify: Notify,
+
+ /// Sendable task spawner
+ spawner: Spawner,
+}
+
+/// The inner scheduler that owns the task queue and the main parker P.
+struct Inner<P: Park> {
/// Scheduler run queue
///
/// When the scheduler is executed, the queue is removed from `self` and
@@ -42,7 +55,7 @@ pub(crate) struct Spawner {
struct Tasks {
/// Collection of all active tasks spawned onto this executor.
- owned: LinkedList<Task<Arc<Shared>>>,
+ owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>,
/// Local run queue.
///
@@ -59,7 +72,7 @@ struct Shared {
unpark: Box<dyn Unpark>,
}
-/// Thread-local context
+/// Thread-local context.
struct Context {
/// Shared scheduler state
shared: Arc<Shared>,
@@ -68,38 +81,43 @@ struct Context {
tasks: RefCell<Tasks>,
}
-/// Initial queue capacity
+/// Initial queue capacity.
const INITIAL_CAPACITY: usize = 64;
/// Max number of tasks to poll per tick.
const MAX_TASKS_PER_TICK: usize = 61;
-/// How often ot check the remote queue first
+/// How often to check the remote queue first.
const REMOTE_FIRST_INTERVAL: u8 = 31;
-// Tracks the current BasicScheduler
+// Tracks the current BasicScheduler.
scoped_thread_local!(static CURRENT: Context);
-impl<P> BasicScheduler<P>
-where
- P: Park,
-{
+impl<P: Park> BasicScheduler<P> {
pub(crate) fn new(park: P) -> BasicScheduler<P> {
let unpark = Box::new(park.unpark());
- BasicScheduler {
+ let spawner = Spawner {
+ shared: Arc::new(Shared {
+ queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
+ unpark: unpark as Box<dyn Unpark>,
+ }),
+ };
+
+ let inner = Mutex::new(Some(Inner {
tasks: Some(Tasks {
owned: LinkedList::new(),
queue: VecDeque::with_capacity(INITIAL_CAPACITY),
}),
- spawner: Spawner {
- shared: Arc::new(Shared {
- queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
- unpark: unpark as Box<dyn Unpark>,
- }),
- },
+ spawner: spawner.clone(),
tick: 0,
park,
+ }));
+
+ BasicScheduler {
+ inner,
+ notify: Notify::new(),
+ spawner,
}
}
@@ -116,13 +134,57 @@ where
self.spawner.spawn(future)
}
- pub(crate) fn block_on<F>(&mut self, future: F) -> F::Output
- where
- F: Future,
- {
+ pub(crate) fn block_on<F: Future>(&self, future: F) -> F::Output {
+ pin!(future);
+
+ // Attempt to steal the dedicated parker and block_on the future if we can there,
+ // othwerwise, lets select on a notification that the parker is available
+ // or the future is complete.
+ loop {
+ if let Some(inner) = &mut self.take_inner() {
+ return inner.block_on(future);
+ } else {
+ let mut enter = crate::runtime::enter(false);
+
+ let notified = self.notify.notified();
+ pin!(notified);
+
+ if let Some(out) = enter
+ .block_on(poll_fn(|cx| {
+ if notified.as_mut().poll(cx).is_ready() {
+ return Ready(None);
+ }
+
+ if let Ready(out) = future.as_mut().poll(cx) {
+ return Ready(Some(out));
+ }
+
+ Pending
+ }))
+ .expect("Failed to `Enter::block_on`")
+ {
+ return out;
+ }
+ }
+ }
+ }
+
+ fn take_inner(&self) -> Option<InnerGuard<'_, P>> {
+ let inner = self.inner.lock().take()?;
+
+ Some(InnerGuard {
+ inner: Some(inner),
+ basic_scheduler: &self,
+ })
+ }
+}
+
+impl<P: Park> Inner<P> {
+ /// Block on the future provided and drive the runtime's driver.
+ fn block_on<F: Future>(&mut self, future: F) -> F::Output {
enter(self, |scheduler, context| {
- let _enter = runtime::enter(false);
- let waker = waker_ref(&scheduler.spawner.shared);
+ let _enter = crate::runtime::enter(false);
+ let waker = scheduler.spawner.waker_ref();
let mut cx = std::task::Context::from_waker(&waker);
pin!(future);
@@ -177,16 +239,16 @@ where
/// Enter the scheduler context. This sets the queue and other necessary
/// scheduler state in the thread-local
-fn enter<F, R, P>(scheduler: &mut BasicScheduler<P>, f: F) -> R
+fn enter<F, R, P>(scheduler: &mut Inner<P>, f: F) -> R
where
- F: FnOnce(&mut BasicScheduler<P>, &Context) -> R,
+ F: FnOnce(&mut Inner<P>, &Context) -> R,
P: Park,
{
// Ensures the run queue is placed back in the `BasicScheduler` instance
// once `block_on` returns.`
struct Guard<'a, P: Park> {
context: Option<Context>,
- scheduler: &'a mut BasicScheduler<P>,
+ scheduler: &'a mut Inner<P>,
}
impl<P: Park> Drop for Guard<'_, P> {
@@ -213,12 +275,18 @@ where
CURRENT.set(context, || f(scheduler, context))
}
-impl<P> Drop for BasicScheduler<P>
-where
- P: Park,
-{
+impl<P: Park> Drop for BasicScheduler<P> {
fn drop(&mut self) {
- enter(self, |scheduler, context| {
+ // Avoid a double panic if we are currently panicking and
+ // the lock may be poisoned.
+
+ let mut inner = match self.inner.lock().take() {
+ Some(inner) => inner,
+ None if std::thread::panicking() => return,
+ None => panic!("Oh no! We never placed the Inner state back, this is a bug!"),
+ };
+
+ enter(&mut inner, |scheduler, context| {
// Loop required here to ensure borrow is dropped between iterations
#[allow(clippy::while_let_loop)]
loop {
@@ -236,7 +304,7 @@ where
}
// Drain remote queue
- for task in scheduler.spawner.shared.queue.lock().unwrap().drain(..) {
+ for task in scheduler.spawner.shared.queue.lock().drain(..) {
task.shutdown();
}
@@ -266,7 +334,11 @@ impl Spawner {
}
fn pop(&self) -> Option<task::Notified<Arc<Shared>>> {
- self.shared.queue.lock().unwrap().pop_front()
+ self.shared.queue.lock().pop_front()
+ }
+
+ fn waker_ref(&self) -> WakerRef<'_> {
+ waker_ref(&self.shared)
}
}
@@ -307,7 +379,7 @@ impl Schedule for Arc<Shared> {
cx.tasks.borrow_mut().queue.push_back(task);
}
_ => {
- self.queue.lock().unwrap().push_back(task);
+ self.queue.lock().push_back(task);
self.unpark.unpark();
}
});
@@ -324,3 +396,37 @@ impl Wake for Shared {
arc_self.unpark.unpark();
}
}
+
+// ===== InnerGuard =====
+
+/// Used to ensure we always place the Inner value
+/// back into its slot in `BasicScheduler`, even if the
+/// future panics.
+struct InnerGuard<'a, P: Park> {
+ inner: Option<Inner<P>>,
+ basic_scheduler: &'a BasicScheduler<P>,
+}
+
+impl<P: Park> InnerGuard<'_, P> {
+ fn block_on<F: Future>(&mut self, future: F) -> F::Output {
+ // The only time inner gets set to `None` is if we have dropped
+ // already so this unwrap is safe.
+ self.inner.as_mut().unwrap().block_on(future)
+ }
+}
+
+impl<P: Park> Drop for InnerGuard<'_, P> {
+ fn drop(&mut self) {
+ if let Some(scheduler) = self.inner.take() {
+ let mut lock = self.basic_scheduler.inner.lock();
+
+ // Replace old scheduler back into the state to allow
+ // other threads to pick it up and drive it.
+ lock.replace(scheduler);
+
+ // Wake up other possible threads that could steal
+ // the dedicated parker P.
+ self.basic_scheduler.notify.notify_one()
+ }
+ }
+}
diff --git a/src/runtime/blocking/mod.rs b/src/runtime/blocking/mod.rs
index 0b36a75..fece3c2 100644
--- a/src/runtime/blocking/mod.rs
+++ b/src/runtime/blocking/mod.rs
@@ -3,22 +3,20 @@
//! shells. This isolates the complexity of dealing with conditional
//! compilation.
-cfg_blocking_impl! {
- mod pool;
- pub(crate) use pool::{spawn_blocking, try_spawn_blocking, BlockingPool, Spawner};
+mod pool;
+pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner};
- mod schedule;
- mod shutdown;
- pub(crate) mod task;
+mod schedule;
+mod shutdown;
+pub(crate) mod task;
- use crate::runtime::Builder;
-
- pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool {
- BlockingPool::new(builder, thread_cap)
+use crate::runtime::Builder;
- }
+pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool {
+ BlockingPool::new(builder, thread_cap)
}
+/*
cfg_not_blocking_impl! {
use crate::runtime::Builder;
use std::time::Duration;
@@ -41,3 +39,4 @@ cfg_not_blocking_impl! {
}
}
}
+*/
diff --git a/src/runtime/blocking/pool.rs b/src/runtime/blocking/pool.rs
index 40d417b..2967a10 100644
--- a/src/runtime/blocking/pool.rs
+++ b/src/runtime/blocking/pool.rs
@@ -5,9 +5,13 @@ use crate::loom::thread;
use crate::runtime::blocking::schedule::NoopSchedule;
use crate::runtime::blocking::shutdown;
use crate::runtime::blocking::task::BlockingTask;
+use crate::runtime::builder::ThreadNameFn;
+use crate::runtime::context;
use crate::runtime::task::{self, JoinHandle};
use crate::runtime::{Builder, Callback, Handle};
+use slab::Slab;
+
use std::collections::VecDeque;
use std::fmt;
use std::time::Duration;
@@ -30,7 +34,7 @@ struct Inner {
condvar: Condvar,
/// Spawned threads use this name
- thread_name: String,
+ thread_name: ThreadNameFn,
/// Spawned thread stack size
stack_size: Option<usize>,
@@ -41,7 +45,11 @@ struct Inner {
/// Call before a thread stops
before_stop: Option<Callback>,
+ // Maximum number of threads
thread_cap: usize,
+
+ // Customizable wait timeout
+ keep_alive: Duration,
}
struct Shared {
@@ -51,6 +59,7 @@ struct Shared {
num_notify: u32,
shutdown: bool,
shutdown_tx: Option<shutdown::Sender>,
+ worker_threads: Slab<thread::JoinHandle<()>>,
}
type Task = task::Notified<NoopSchedule>;
@@ -62,11 +71,8 @@ pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
{
- let rt = Handle::current();
-
- let (task, handle) = task::joinable(BlockingTask::new(func));
- let _ = rt.blocking_spawner.spawn(task, &rt);
- handle
+ let rt = context::current().expect("not currently running on the Tokio runtime.");
+ rt.spawn_blocking(func)
}
#[allow(dead_code)]
@@ -74,7 +80,7 @@ pub(crate) fn try_spawn_blocking<F, R>(func: F) -> Result<(), ()>
where
F: FnOnce() -> R + Send + 'static,
{
- let rt = Handle::current();
+ let rt = context::current().expect("not currently running on the Tokio runtime.");
let (task, _handle) = task::joinable(BlockingTask::new(func));
rt.blocking_spawner.spawn(task, &rt)
@@ -85,6 +91,7 @@ where
impl BlockingPool {
pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
let (shutdown_tx, shutdown_rx) = shutdown::channel();
+ let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
BlockingPool {
spawner: Spawner {
@@ -96,6 +103,7 @@ impl BlockingPool {
num_notify: 0,
shutdown: false,
shutdown_tx: Some(shutdown_tx),
+ worker_threads: Slab::new(),
}),
condvar: Condvar::new(),
thread_name: builder.thread_name.clone(),
@@ -103,6 +111,7 @@ impl BlockingPool {
after_start: builder.after_start.clone(),
before_stop: builder.before_stop.clone(),
thread_cap,
+ keep_alive,
}),
},
shutdown_rx,
@@ -114,7 +123,7 @@ impl BlockingPool {
}
pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
- let mut shared = self.spawner.inner.shared.lock().unwrap();
+ let mut shared = self.spawner.inner.shared.lock();
// The function can be called multiple times. First, by explicitly
// calling `shutdown` then by the drop handler calling `shutdown`. This
@@ -126,10 +135,15 @@ impl BlockingPool {
shared.shutdown = true;
shared.shutdown_tx = None;
self.spawner.inner.condvar.notify_all();
+ let mut workers = std::mem::replace(&mut shared.worker_threads, Slab::new());
drop(shared);
- self.shutdown_rx.wait(timeout);
+ if self.shutdown_rx.wait(timeout) {
+ for handle in workers.drain() {
+ let _ = handle.join();
+ }
+ }
}
}
@@ -150,7 +164,7 @@ impl fmt::Debug for BlockingPool {
impl Spawner {
pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> {
let shutdown_tx = {
- let mut shared = self.inner.shared.lock().unwrap();
+ let mut shared = self.inner.shared.lock();
if shared.shutdown {
// Shutdown the task
@@ -187,14 +201,24 @@ impl Spawner {
};
if let Some(shutdown_tx) = shutdown_tx {
- self.spawn_thread(shutdown_tx, rt);
+ let mut shared = self.inner.shared.lock();
+ let entry = shared.worker_threads.vacant_entry();
+
+ let handle = self.spawn_thread(shutdown_tx, rt, entry.key());
+
+ entry.insert(handle);
}
Ok(())
}
- fn spawn_thread(&self, shutdown_tx: shutdown::Sender, rt: &Handle) {
- let mut builder = thread::Builder::new().name(self.inner.thread_name.clone());
+ fn spawn_thread(
+ &self,
+ shutdown_tx: shutdown::Sender,
+ rt: &Handle,
+ worker_id: usize,
+ ) -> thread::JoinHandle<()> {
+ let mut builder = thread::Builder::new().name((self.inner.thread_name)());
if let Some(stack_size) = self.inner.stack_size {
builder = builder.stack_size(stack_size);
@@ -205,23 +229,21 @@ impl Spawner {
builder
.spawn(move || {
// Only the reference should be moved into the closure
- let rt = &rt;
- rt.enter(move || {
- rt.blocking_spawner.inner.run();
- drop(shutdown_tx);
- })
+ let _enter = crate::runtime::context::enter(rt.clone());
+ rt.blocking_spawner.inner.run(worker_id);
+ drop(shutdown_tx);
})
- .unwrap();
+ .unwrap()
}
}
impl Inner {
- fn run(&self) {
+ fn run(&self, worker_id: usize) {
if let Some(f) = &self.after_start {
f()
}
- let mut shared = self.shared.lock().unwrap();
+ let mut shared = self.shared.lock();
'main: loop {
// BUSY
@@ -229,14 +251,14 @@ impl Inner {
drop(shared);
task.run();
- shared = self.shared.lock().unwrap();
+ shared = self.shared.lock();
}
// IDLE
shared.num_idle += 1;
while !shared.shutdown {
- let lock_result = self.condvar.wait_timeout(shared, KEEP_ALIVE).unwrap();
+ let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();
shared = lock_result.0;
let timeout_result = lock_result.1;
@@ -252,6 +274,8 @@ impl Inner {
// Even if the condvar "timed out", if the pool is entering the
// shutdown phase, we want to perform the cleanup logic.
if !shared.shutdown && timeout_result.timed_out() {
+ shared.worker_threads.remove(worker_id);
+
break 'main;
}
@@ -264,7 +288,7 @@ impl Inner {
drop(shared);
task.shutdown();
- shared = self.shared.lock().unwrap();
+ shared = self.shared.lock();
}
// Work was produced, and we "took" it (by decrementing num_notify).
diff --git a/src/runtime/blocking/shutdown.rs b/src/runtime/blocking/shutdown.rs
index e76a701..3b6cc59 100644
--- a/src/runtime/blocking/shutdown.rs
+++ b/src/runtime/blocking/shutdown.rs
@@ -32,11 +32,13 @@ impl Receiver {
/// If `timeout` is `Some`, the thread is blocked for **at most** `timeout`
/// duration. If `timeout` is `None`, then the thread is blocked until the
/// shutdown signal is received.
- pub(crate) fn wait(&mut self, timeout: Option<Duration>) {
+ ///
+ /// If the timeout has elapsed, it returns `false`, otherwise it returns `true`.
+ pub(crate) fn wait(&mut self, timeout: Option<Duration>) -> bool {
use crate::runtime::enter::try_enter;
if timeout == Some(Duration::from_nanos(0)) {
- return;
+ return true;
}
let mut e = match try_enter(false) {
@@ -44,7 +46,7 @@ impl Receiver {
_ => {
if std::thread::panicking() {
// Don't panic in a panic
- return;
+ return false;
} else {
panic!(
"Cannot drop a runtime in a context where blocking is not allowed. \
@@ -60,9 +62,10 @@ impl Receiver {
// current thread (usually, shutting down a runtime stored in a
// thread-local).
if let Some(timeout) = timeout {
- let _ = e.block_on_timeout(&mut self.rx, timeout);
+ e.block_on_timeout(&mut self.rx, timeout).is_ok()
} else {
let _ = e.block_on(&mut self.rx);
+ true
}
}
}
diff --git a/src/runtime/builder.rs b/src/runtime/builder.rs
index fad72c7..e792c7d 100644
--- a/src/runtime/builder.rs
+++ b/src/runtime/builder.rs
@@ -1,23 +1,24 @@
use crate::runtime::handle::Handle;
-use crate::runtime::shell::Shell;
-use crate::runtime::{blocking, io, time, Callback, Runtime, Spawner};
+use crate::runtime::{blocking, driver, Callback, Runtime, Spawner};
use std::fmt;
-#[cfg(not(loom))]
-use std::sync::Arc;
+use std::io;
+use std::time::Duration;
/// Builds Tokio Runtime with custom configuration values.
///
/// Methods can be chained in order to set the configuration values. The
/// Runtime is constructed by calling [`build`].
///
-/// New instances of `Builder` are obtained via [`Builder::new`].
+/// New instances of `Builder` are obtained via [`Builder::new_multi_thread`]
+/// or [`Builder::new_current_thread`].
///
/// See function level documentation for details on the various configuration
/// settings.
///
/// [`build`]: method@Self::build
-/// [`Builder::new`]: method@Self::new
+/// [`Builder::new_multi_thread`]: method@Self::new_multi_thread
+/// [`Builder::new_current_thread`]: method@Self::new_current_thread
///
/// # Examples
///
@@ -26,9 +27,8 @@ use std::sync::Arc;
///
/// fn main() {
/// // build runtime
-/// let runtime = Builder::new()
-/// .threaded_scheduler()
-/// .core_threads(4)
+/// let runtime = Builder::new_multi_thread()
+/// .worker_threads(4)
/// .thread_name("my-custom-name")
/// .thread_stack_size(3 * 1024 * 1024)
/// .build()
@@ -38,7 +38,7 @@ use std::sync::Arc;
/// }
/// ```
pub struct Builder {
- /// The task execution model to use.
+ /// Runtime type
kind: Kind,
/// Whether or not to enable the I/O driver
@@ -50,13 +50,13 @@ pub struct Builder {
/// The number of worker threads, used by Runtime.
///
/// Only used when not using the current-thread executor.
- core_threads: Option<usize>,
+ worker_threads: Option<usize>,
/// Cap on thread usage.
max_threads: usize,
- /// Name used for threads spawned by the runtime.
- pub(super) thread_name: String,
+ /// Name fn used for threads spawned by the runtime.
+ pub(super) thread_name: ThreadNameFn,
/// Stack size used for threads spawned by the runtime.
pub(super) thread_stack_size: Option<usize>,
@@ -66,26 +66,43 @@ pub struct Builder {
/// To run before each worker thread stops
pub(super) before_stop: Option<Callback>,
+
+ /// Customizable keep alive timeout for BlockingPool
+ pub(super) keep_alive: Option<Duration>,
}
-#[derive(Debug, Clone, Copy)]
-enum Kind {
- Shell,
- #[cfg(feature = "rt-core")]
- Basic,
- #[cfg(feature = "rt-threaded")]
- ThreadPool,
+pub(crate) type ThreadNameFn = std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>;
+
+pub(crate) enum Kind {
+ CurrentThread,
+ #[cfg(feature = "rt-multi-thread")]
+ MultiThread,
}
impl Builder {
+ /// Returns a new builder with the current thread scheduler selected.
+ ///
+ /// Configuration methods can be chained on the return value.
+ pub fn new_current_thread() -> Builder {
+ Builder::new(Kind::CurrentThread)
+ }
+
+ /// Returns a new builder with the multi thread scheduler selected.
+ ///
+ /// Configuration methods can be chained on the return value.
+ #[cfg(feature = "rt-multi-thread")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))]
+ pub fn new_multi_thread() -> Builder {
+ Builder::new(Kind::MultiThread)
+ }
+
/// Returns a new runtime builder initialized with default configuration
/// values.
///
/// Configuration methods can be chained on the return value.
- pub fn new() -> Builder {
+ pub(crate) fn new(kind: Kind) -> Builder {
Builder {
- // No task execution by default
- kind: Kind::Shell,
+ kind,
// I/O defaults to "off"
enable_io: false,
@@ -94,12 +111,12 @@ impl Builder {
enable_time: false,
// Default to lazy auto-detection (one thread per CPU core)
- core_threads: None,
+ worker_threads: None,
max_threads: 512,
// Default thread name
- thread_name: "tokio-runtime-worker".into(),
+ thread_name: std::sync::Arc::new(|| "tokio-runtime-worker".into()),
// Do not set a stack size by default
thread_stack_size: None,
@@ -107,6 +124,8 @@ impl Builder {
// No worker thread callbacks
after_start: None,
before_stop: None,
+
+ keep_alive: None,
}
}
@@ -121,14 +140,13 @@ impl Builder {
/// ```
/// use tokio::runtime;
///
- /// let rt = runtime::Builder::new()
- /// .threaded_scheduler()
+ /// let rt = runtime::Builder::new_multi_thread()
/// .enable_all()
/// .build()
/// .unwrap();
/// ```
pub fn enable_all(&mut self) -> &mut Self {
- #[cfg(feature = "io-driver")]
+ #[cfg(any(feature = "net", feature = "process", all(unix, feature = "signal")))]
self.enable_io();
#[cfg(feature = "time")]
self.enable_time();
@@ -136,51 +154,68 @@ impl Builder {
self
}
- #[deprecated(note = "In future will be replaced by core_threads method")]
- /// Sets the maximum number of worker threads for the `Runtime`'s thread pool.
+ /// Sets the number of worker threads the `Runtime` will use.
+ ///
+ /// This should be a number between 0 and 32,768 though it is advised to
+ /// keep this value on the smaller side.
///
- /// This must be a number between 1 and 32,768 though it is advised to keep
- /// this value on the smaller side.
+ /// # Default
///
/// The default value is the number of cores available to the system.
- pub fn num_threads(&mut self, val: usize) -> &mut Self {
- self.core_threads = Some(val);
- self
- }
-
- /// Sets the core number of worker threads for the `Runtime`'s thread pool.
///
- /// This should be a number between 1 and 32,768 though it is advised to keep
- /// this value on the smaller side.
+ /// # Panic
///
- /// The default value is the number of cores available to the system.
+ /// When using the `current_thread` runtime this method will panic, since
+ /// those variants do not allow setting worker thread counts.
///
- /// These threads will be always active and running.
///
/// # Examples
///
+ /// ## Multi threaded runtime with 4 threads
+ ///
+ /// ```
+ /// use tokio::runtime;
+ ///
+ /// // This will spawn a work-stealing runtime with 4 worker threads.
+ /// let rt = runtime::Builder::new_multi_thread()
+ /// .worker_threads(4)
+ /// .build()
+ /// .unwrap();
+ ///
+ /// rt.spawn(async move {});
+ /// ```
+ ///
+ /// ## Current thread runtime (will only run on the current thread via `Runtime::block_on`)
+ ///
/// ```
/// use tokio::runtime;
///
- /// let rt = runtime::Builder::new()
- /// .threaded_scheduler()
- /// .core_threads(4)
+ /// // Create a runtime that _must_ be driven from a call
+ /// // to `Runtime::block_on`.
+ /// let rt = runtime::Builder::new_current_thread()
/// .build()
/// .unwrap();
+ ///
+ /// // This will run the runtime and future on the current thread
+ /// rt.block_on(async move {});
/// ```
- pub fn core_threads(&mut self, val: usize) -> &mut Self {
- assert_ne!(val, 0, "Core threads cannot be zero");
- self.core_threads = Some(val);
+ ///
+ /// # Panic
+ ///
+ /// This will panic if `val` is not larger than `0`.
+ pub fn worker_threads(&mut self, val: usize) -> &mut Self {
+ assert!(val > 0, "Worker threads cannot be set to 0");
+ self.worker_threads = Some(val);
self
}
/// Specifies limit for threads, spawned by the Runtime.
///
/// This is number of threads to be used by Runtime, including `core_threads`
- /// Having `max_threads` less than `core_threads` results in invalid configuration
+ /// Having `max_threads` less than `worker_threads` results in invalid configuration
/// when building multi-threaded `Runtime`, which would cause a panic.
///
- /// Similarly to the `core_threads`, this number should be between 1 and 32,768.
+ /// Similarly to the `worker_threads`, this number should be between 0 and 32,768.
///
/// The default value is 512.
///
@@ -189,7 +224,6 @@ impl Builder {
/// Otherwise as `core_threads` are always active, it limits additional threads (e.g. for
/// blocking annotations) as `max_threads - core_threads`.
pub fn max_threads(&mut self, val: usize) -> &mut Self {
- assert_ne!(val, 0, "Thread limit cannot be zero");
self.max_threads = val;
self
}
@@ -204,13 +238,42 @@ impl Builder {
/// # use tokio::runtime;
///
/// # pub fn main() {
- /// let rt = runtime::Builder::new()
+ /// let rt = runtime::Builder::new_multi_thread()
/// .thread_name("my-pool")
/// .build();
/// # }
/// ```
pub fn thread_name(&mut self, val: impl Into<String>) -> &mut Self {
- self.thread_name = val.into();
+ let val = val.into();
+ self.thread_name = std::sync::Arc::new(move || val.clone());
+ self
+ }
+
+ /// Sets a function used to generate the name of threads spawned by the `Runtime`'s thread pool.
+ ///
+ /// The default name fn is `|| "tokio-runtime-worker".into()`.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::runtime;
+ /// # use std::sync::atomic::{AtomicUsize, Ordering};
+ ///
+ /// # pub fn main() {
+ /// let rt = runtime::Builder::new_multi_thread()
+ /// .thread_name_fn(|| {
+ /// static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
+ /// let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
+ /// format!("my-pool-{}", id)
+ /// })
+ /// .build();
+ /// # }
+ /// ```
+ pub fn thread_name_fn<F>(&mut self, f: F) -> &mut Self
+ where
+ F: Fn() -> String + Send + Sync + 'static,
+ {
+ self.thread_name = std::sync::Arc::new(f);
self
}
@@ -228,8 +291,7 @@ impl Builder {
/// # use tokio::runtime;
///
/// # pub fn main() {
- /// let rt = runtime::Builder::new()
- /// .threaded_scheduler()
+ /// let rt = runtime::Builder::new_multi_thread()
/// .thread_stack_size(32 * 1024)
/// .build();
/// # }
@@ -250,8 +312,7 @@ impl Builder {
/// # use tokio::runtime;
///
/// # pub fn main() {
- /// let runtime = runtime::Builder::new()
- /// .threaded_scheduler()
+ /// let runtime = runtime::Builder::new_multi_thread()
/// .on_thread_start(|| {
/// println!("thread started");
/// })
@@ -263,7 +324,7 @@ impl Builder {
where
F: Fn() + Send + Sync + 'static,
{
- self.after_start = Some(Arc::new(f));
+ self.after_start = Some(std::sync::Arc::new(f));
self
}
@@ -277,8 +338,7 @@ impl Builder {
/// # use tokio::runtime;
///
/// # pub fn main() {
- /// let runtime = runtime::Builder::new()
- /// .threaded_scheduler()
+ /// let runtime = runtime::Builder::new_multi_thread()
/// .on_thread_stop(|| {
/// println!("thread stopping");
/// })
@@ -290,56 +350,86 @@ impl Builder {
where
F: Fn() + Send + Sync + 'static,
{
- self.before_stop = Some(Arc::new(f));
+ self.before_stop = Some(std::sync::Arc::new(f));
self
}
/// Creates the configured `Runtime`.
///
- /// The returned `ThreadPool` instance is ready to spawn tasks.
+ /// The returned `Runtime` instance is ready to spawn tasks.
///
/// # Examples
///
/// ```
/// use tokio::runtime::Builder;
///
- /// let mut rt = Builder::new().build().unwrap();
+ /// let rt = Builder::new_multi_thread().build().unwrap();
///
/// rt.block_on(async {
/// println!("Hello from the Tokio runtime");
/// });
/// ```
pub fn build(&mut self) -> io::Result<Runtime> {
- match self.kind {
- Kind::Shell => self.build_shell_runtime(),
- #[cfg(feature = "rt-core")]
- Kind::Basic => self.build_basic_runtime(),
- #[cfg(feature = "rt-threaded")]
- Kind::ThreadPool => self.build_threaded_runtime(),
+ match &self.kind {
+ Kind::CurrentThread => self.build_basic_runtime(),
+ #[cfg(feature = "rt-multi-thread")]
+ Kind::MultiThread => self.build_threaded_runtime(),
}
}
- fn build_shell_runtime(&mut self) -> io::Result<Runtime> {
- use crate::runtime::Kind;
+ fn get_cfg(&self) -> driver::Cfg {
+ driver::Cfg {
+ enable_io: self.enable_io,
+ enable_time: self.enable_time,
+ }
+ }
- let clock = time::create_clock();
+ /// Sets a custom timeout for a thread in the blocking pool.
+ ///
+ /// By default, the timeout for a thread is set to 10 seconds. This can
+ /// be overriden using .thread_keep_alive().
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// # use tokio::runtime;
+ /// # use std::time::Duration;
+ ///
+ /// # pub fn main() {
+ /// let rt = runtime::Builder::new_multi_thread()
+ /// .thread_keep_alive(Duration::from_millis(100))
+ /// .build();
+ /// # }
+ /// ```
+ pub fn thread_keep_alive(&mut self, duration: Duration) -> &mut Self {
+ self.keep_alive = Some(duration);
+ self
+ }
+
+ fn build_basic_runtime(&mut self) -> io::Result<Runtime> {
+ use crate::runtime::{BasicScheduler, Kind};
- // Create I/O driver
- let (io_driver, io_handle) = io::create_driver(self.enable_io)?;
- let (driver, time_handle) = time::create_driver(self.enable_time, io_driver, clock.clone());
+ let (driver, resources) = driver::Driver::new(self.get_cfg())?;
- let spawner = Spawner::Shell;
+ // And now put a single-threaded scheduler on top of the timer. When
+ // there are no futures ready to do something, it'll let the timer or
+ // the reactor to generate some new stimuli for the futures to continue
+ // in their life.
+ let scheduler = BasicScheduler::new(driver);
+ let spawner = Spawner::Basic(scheduler.spawner().clone());
+ // Blocking pool
let blocking_pool = blocking::create_blocking_pool(self, self.max_threads);
let blocking_spawner = blocking_pool.spawner().clone();
Ok(Runtime {
- kind: Kind::Shell(Shell::new(driver)),
+ kind: Kind::CurrentThread(scheduler),
handle: Handle {
spawner,
- io_handle,
- time_handle,
- clock,
+ io_handle: resources.io_handle,
+ time_handle: resources.time_handle,
+ signal_handle: resources.signal_handle,
+ clock: resources.clock,
blocking_spawner,
},
blocking_pool,
@@ -359,7 +449,7 @@ cfg_io_driver! {
/// ```
/// use tokio::runtime;
///
- /// let rt = runtime::Builder::new()
+ /// let rt = runtime::Builder::new_multi_thread()
/// .enable_io()
/// .build()
/// .unwrap();
@@ -382,7 +472,7 @@ cfg_time! {
/// ```
/// use tokio::runtime;
///
- /// let rt = runtime::Builder::new()
+ /// let rt = runtime::Builder::new_multi_thread()
/// .enable_time()
/// .build()
/// .unwrap();
@@ -394,85 +484,19 @@ cfg_time! {
}
}
-cfg_rt_core! {
+cfg_rt_multi_thread! {
impl Builder {
- /// Sets runtime to use a simpler scheduler that runs all tasks on the current-thread.
- ///
- /// The executor and all necessary drivers will all be run on the current
- /// thread during [`block_on`] calls.
- ///
- /// See also [the module level documentation][1], which has a section on scheduler
- /// types.
- ///
- /// [1]: index.html#runtime-configurations
- /// [`block_on`]: Runtime::block_on
- pub fn basic_scheduler(&mut self) -> &mut Self {
- self.kind = Kind::Basic;
- self
- }
-
- fn build_basic_runtime(&mut self) -> io::Result<Runtime> {
- use crate::runtime::{BasicScheduler, Kind};
-
- let clock = time::create_clock();
-
- // Create I/O driver
- let (io_driver, io_handle) = io::create_driver(self.enable_io)?;
-
- let (driver, time_handle) = time::create_driver(self.enable_time, io_driver, clock.clone());
-
- // And now put a single-threaded scheduler on top of the timer. When
- // there are no futures ready to do something, it'll let the timer or
- // the reactor to generate some new stimuli for the futures to continue
- // in their life.
- let scheduler = BasicScheduler::new(driver);
- let spawner = Spawner::Basic(scheduler.spawner().clone());
-
- // Blocking pool
- let blocking_pool = blocking::create_blocking_pool(self, self.max_threads);
- let blocking_spawner = blocking_pool.spawner().clone();
-
- Ok(Runtime {
- kind: Kind::Basic(scheduler),
- handle: Handle {
- spawner,
- io_handle,
- time_handle,
- clock,
- blocking_spawner,
- },
- blocking_pool,
- })
- }
- }
-}
-
-cfg_rt_threaded! {
- impl Builder {
- /// Sets runtime to use a multi-threaded scheduler for executing tasks.
- ///
- /// See also [the module level documentation][1], which has a section on scheduler
- /// types.
- ///
- /// [1]: index.html#runtime-configurations
- pub fn threaded_scheduler(&mut self) -> &mut Self {
- self.kind = Kind::ThreadPool;
- self
- }
-
fn build_threaded_runtime(&mut self) -> io::Result<Runtime> {
use crate::loom::sys::num_cpus;
use crate::runtime::{Kind, ThreadPool};
use crate::runtime::park::Parker;
use std::cmp;
- let core_threads = self.core_threads.unwrap_or_else(|| cmp::min(self.max_threads, num_cpus()));
+ let core_threads = self.worker_threads.unwrap_or_else(|| cmp::min(self.max_threads, num_cpus()));
assert!(core_threads <= self.max_threads, "Core threads number cannot be above max limit");
- let clock = time::create_clock();
+ let (driver, resources) = driver::Driver::new(self.get_cfg())?;
- let (io_driver, io_handle) = io::create_driver(self.enable_io)?;
- let (driver, time_handle) = time::create_driver(self.enable_time, io_driver, clock.clone());
let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver));
let spawner = Spawner::ThreadPool(scheduler.spawner().clone());
@@ -483,14 +507,16 @@ cfg_rt_threaded! {
// Create the runtime handle
let handle = Handle {
spawner,
- io_handle,
- time_handle,
- clock,
+ io_handle: resources.io_handle,
+ time_handle: resources.time_handle,
+ signal_handle: resources.signal_handle,
+ clock: resources.clock,
blocking_spawner,
};
// Spawn the thread pool workers
- handle.enter(|| launch.launch());
+ let _enter = crate::runtime::context::enter(handle.clone());
+ launch.launch();
Ok(Runtime {
kind: Kind::ThreadPool(scheduler),
@@ -501,19 +527,15 @@ cfg_rt_threaded! {
}
}
-impl Default for Builder {
- fn default() -> Self {
- Self::new()
- }
-}
-
impl fmt::Debug for Builder {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Builder")
- .field("kind", &self.kind)
- .field("core_threads", &self.core_threads)
+ .field("worker_threads", &self.worker_threads)
.field("max_threads", &self.max_threads)
- .field("thread_name", &self.thread_name)
+ .field(
+ "thread_name",
+ &"<dyn Fn() -> String + Send + Sync + 'static>",
+ )
.field("thread_stack_size", &self.thread_stack_size)
.field("after_start", &self.after_start.as_ref().map(|_| "..."))
.field("before_stop", &self.after_start.as_ref().map(|_| "..."))
diff --git a/src/runtime/context.rs b/src/runtime/context.rs
index 1b267f4..0817019 100644
--- a/src/runtime/context.rs
+++ b/src/runtime/context.rs
@@ -12,7 +12,7 @@ pub(crate) fn current() -> Option<Handle> {
}
cfg_io_driver! {
- pub(crate) fn io_handle() -> crate::runtime::io::Handle {
+ pub(crate) fn io_handle() -> crate::runtime::driver::IoHandle {
CONTEXT.with(|ctx| match *ctx.borrow() {
Some(ref ctx) => ctx.io_handle.clone(),
None => Default::default(),
@@ -20,8 +20,18 @@ cfg_io_driver! {
}
}
+cfg_signal_internal! {
+ #[cfg(unix)]
+ pub(crate) fn signal_handle() -> crate::runtime::driver::SignalHandle {
+ CONTEXT.with(|ctx| match *ctx.borrow() {
+ Some(ref ctx) => ctx.signal_handle.clone(),
+ None => Default::default(),
+ })
+ }
+}
+
cfg_time! {
- pub(crate) fn time_handle() -> crate::runtime::time::Handle {
+ pub(crate) fn time_handle() -> crate::runtime::driver::TimeHandle {
CONTEXT.with(|ctx| match *ctx.borrow() {
Some(ref ctx) => ctx.time_handle.clone(),
None => Default::default(),
@@ -29,7 +39,7 @@ cfg_time! {
}
cfg_test_util! {
- pub(crate) fn clock() -> Option<crate::runtime::time::Clock> {
+ pub(crate) fn clock() -> Option<crate::runtime::driver::Clock> {
CONTEXT.with(|ctx| match *ctx.borrow() {
Some(ref ctx) => Some(ctx.clock.clone()),
None => None,
@@ -38,7 +48,7 @@ cfg_time! {
}
}
-cfg_rt_core! {
+cfg_rt! {
pub(crate) fn spawn_handle() -> Option<crate::runtime::Spawner> {
CONTEXT.with(|ctx| match *ctx.borrow() {
Some(ref ctx) => Some(ctx.spawner.clone()),
@@ -50,24 +60,20 @@ cfg_rt_core! {
/// Set this [`Handle`] as the current active [`Handle`].
///
/// [`Handle`]: Handle
-pub(crate) fn enter<F, R>(new: Handle, f: F) -> R
-where
- F: FnOnce() -> R,
-{
- struct DropGuard(Option<Handle>);
-
- impl Drop for DropGuard {
- fn drop(&mut self) {
- CONTEXT.with(|ctx| {
- *ctx.borrow_mut() = self.0.take();
- });
- }
- }
-
- let _guard = CONTEXT.with(|ctx| {
+pub(crate) fn enter(new: Handle) -> EnterGuard {
+ CONTEXT.with(|ctx| {
let old = ctx.borrow_mut().replace(new);
- DropGuard(old)
- });
+ EnterGuard(old)
+ })
+}
+
+#[derive(Debug)]
+pub(crate) struct EnterGuard(Option<Handle>);
- f()
+impl Drop for EnterGuard {
+ fn drop(&mut self) {
+ CONTEXT.with(|ctx| {
+ *ctx.borrow_mut() = self.0.take();
+ });
+ }
}
diff --git a/src/runtime/driver.rs b/src/runtime/driver.rs
new file mode 100644
index 0000000..e89de9d
--- /dev/null
+++ b/src/runtime/driver.rs
@@ -0,0 +1,205 @@
+//! Abstracts out the entire chain of runtime sub-drivers into common types.
+use crate::park::thread::ParkThread;
+use crate::park::Park;
+
+use std::io;
+use std::time::Duration;
+
+// ===== io driver =====
+
+cfg_io_driver! {
+ type IoDriver = crate::io::driver::Driver;
+ type IoStack = crate::park::either::Either<ProcessDriver, ParkThread>;
+ pub(crate) type IoHandle = Option<crate::io::driver::Handle>;
+
+ fn create_io_stack(enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> {
+ use crate::park::either::Either;
+
+ #[cfg(loom)]
+ assert!(!enabled);
+
+ let ret = if enabled {
+ let io_driver = crate::io::driver::Driver::new()?;
+ let io_handle = io_driver.handle();
+
+ let (signal_driver, signal_handle) = create_signal_driver(io_driver)?;
+ let process_driver = create_process_driver(signal_driver)?;
+
+ (Either::A(process_driver), Some(io_handle), signal_handle)
+ } else {
+ (Either::B(ParkThread::new()), Default::default(), Default::default())
+ };
+
+ Ok(ret)
+ }
+}
+
+cfg_not_io_driver! {
+ pub(crate) type IoHandle = ();
+ type IoStack = ParkThread;
+
+ fn create_io_stack(_enabled: bool) -> io::Result<(IoStack, IoHandle, SignalHandle)> {
+ Ok((ParkThread::new(), Default::default(), Default::default()))
+ }
+}
+
+// ===== signal driver =====
+
+macro_rules! cfg_signal_internal_and_unix {
+ ($($item:item)*) => {
+ #[cfg(unix)]
+ cfg_signal_internal! { $($item)* }
+ }
+}
+
+cfg_signal_internal_and_unix! {
+ type SignalDriver = crate::signal::unix::driver::Driver;
+ pub(crate) type SignalHandle = Option<crate::signal::unix::driver::Handle>;
+
+ fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> {
+ let driver = crate::signal::unix::driver::Driver::new(io_driver)?;
+ let handle = driver.handle();
+ Ok((driver, Some(handle)))
+ }
+}
+
+cfg_not_signal_internal! {
+ pub(crate) type SignalHandle = ();
+
+ cfg_io_driver! {
+ type SignalDriver = IoDriver;
+
+ fn create_signal_driver(io_driver: IoDriver) -> io::Result<(SignalDriver, SignalHandle)> {
+ Ok((io_driver, ()))
+ }
+ }
+}
+
+// ===== process driver =====
+
+cfg_process_driver! {
+ type ProcessDriver = crate::process::unix::driver::Driver;
+
+ fn create_process_driver(signal_driver: SignalDriver) -> io::Result<ProcessDriver> {
+ crate::process::unix::driver::Driver::new(signal_driver)
+ }
+}
+
+cfg_not_process_driver! {
+ cfg_io_driver! {
+ type ProcessDriver = SignalDriver;
+
+ fn create_process_driver(signal_driver: SignalDriver) -> io::Result<ProcessDriver> {
+ Ok(signal_driver)
+ }
+ }
+}
+
+// ===== time driver =====
+
+cfg_time! {
+ type TimeDriver = crate::park::either::Either<crate::time::driver::Driver<IoStack>, IoStack>;
+
+ pub(crate) type Clock = crate::time::Clock;
+ pub(crate) type TimeHandle = Option<crate::time::driver::Handle>;
+
+ fn create_clock() -> Clock {
+ crate::time::Clock::new()
+ }
+
+ fn create_time_driver(
+ enable: bool,
+ io_stack: IoStack,
+ clock: Clock,
+ ) -> (TimeDriver, TimeHandle) {
+ use crate::park::either::Either;
+
+ if enable {
+ let driver = crate::time::driver::Driver::new(io_stack, clock);
+ let handle = driver.handle();
+
+ (Either::A(driver), Some(handle))
+ } else {
+ (Either::B(io_stack), None)
+ }
+ }
+}
+
+cfg_not_time! {
+ type TimeDriver = IoStack;
+
+ pub(crate) type Clock = ();
+ pub(crate) type TimeHandle = ();
+
+ fn create_clock() -> Clock {
+ ()
+ }
+
+ fn create_time_driver(
+ _enable: bool,
+ io_stack: IoStack,
+ _clock: Clock,
+ ) -> (TimeDriver, TimeHandle) {
+ (io_stack, ())
+ }
+}
+
+// ===== runtime driver =====
+
+#[derive(Debug)]
+pub(crate) struct Driver {
+ inner: TimeDriver,
+}
+
+pub(crate) struct Resources {
+ pub(crate) io_handle: IoHandle,
+ pub(crate) signal_handle: SignalHandle,
+ pub(crate) time_handle: TimeHandle,
+ pub(crate) clock: Clock,
+}
+
+pub(crate) struct Cfg {
+ pub(crate) enable_io: bool,
+ pub(crate) enable_time: bool,
+}
+
+impl Driver {
+ pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Resources)> {
+ let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io)?;
+
+ let clock = create_clock();
+ let (time_driver, time_handle) =
+ create_time_driver(cfg.enable_time, io_stack, clock.clone());
+
+ Ok((
+ Self { inner: time_driver },
+ Resources {
+ io_handle,
+ signal_handle,
+ time_handle,
+ clock,
+ },
+ ))
+ }
+}
+
+impl Park for Driver {
+ type Unpark = <TimeDriver as Park>::Unpark;
+ type Error = <TimeDriver as Park>::Error;
+
+ fn unpark(&self) -> Self::Unpark {
+ self.inner.unpark()
+ }
+
+ fn park(&mut self) -> Result<(), Self::Error> {
+ self.inner.park()
+ }
+
+ fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
+ self.inner.park_timeout(duration)
+ }
+
+ fn shutdown(&mut self) {
+ self.inner.shutdown()
+ }
+}
diff --git a/src/runtime/enter.rs b/src/runtime/enter.rs
index 56a7c57..4dd8dd0 100644
--- a/src/runtime/enter.rs
+++ b/src/runtime/enter.rs
@@ -4,8 +4,8 @@ use std::marker::PhantomData;
#[derive(Debug, Clone, Copy)]
pub(crate) enum EnterContext {
+ #[cfg_attr(not(feature = "rt"), allow(dead_code))]
Entered {
- #[allow(dead_code)]
allow_blocking: bool,
},
NotEntered,
@@ -13,11 +13,7 @@ pub(crate) enum EnterContext {
impl EnterContext {
pub(crate) fn is_entered(self) -> bool {
- if let EnterContext::Entered { .. } = self {
- true
- } else {
- false
- }
+ matches!(self, EnterContext::Entered { .. })
}
}
@@ -28,32 +24,38 @@ pub(crate) struct Enter {
_p: PhantomData<RefCell<()>>,
}
-/// Marks the current thread as being within the dynamic extent of an
-/// executor.
-pub(crate) fn enter(allow_blocking: bool) -> Enter {
- if let Some(enter) = try_enter(allow_blocking) {
- return enter;
- }
+cfg_rt! {
+ use crate::park::thread::ParkError;
- panic!(
- "Cannot start a runtime from within a runtime. This happens \
- because a function (like `block_on`) attempted to block the \
- current thread while the thread is being used to drive \
- asynchronous tasks."
- );
-}
+ use std::time::Duration;
-/// Tries to enter a runtime context, returns `None` if already in a runtime
-/// context.
-pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> {
- ENTERED.with(|c| {
- if c.get().is_entered() {
- None
- } else {
- c.set(EnterContext::Entered { allow_blocking });
- Some(Enter { _p: PhantomData })
+ /// Marks the current thread as being within the dynamic extent of an
+ /// executor.
+ pub(crate) fn enter(allow_blocking: bool) -> Enter {
+ if let Some(enter) = try_enter(allow_blocking) {
+ return enter;
}
- })
+
+ panic!(
+ "Cannot start a runtime from within a runtime. This happens \
+ because a function (like `block_on`) attempted to block the \
+ current thread while the thread is being used to drive \
+ asynchronous tasks."
+ );
+ }
+
+ /// Tries to enter a runtime context, returns `None` if already in a runtime
+ /// context.
+ pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> {
+ ENTERED.with(|c| {
+ if c.get().is_entered() {
+ None
+ } else {
+ c.set(EnterContext::Entered { allow_blocking });
+ Some(Enter { _p: PhantomData })
+ }
+ })
+ }
}
// Forces the current "entered" state to be cleared while the closure
@@ -63,115 +65,92 @@ pub(crate) fn try_enter(allow_blocking: bool) -> Option<Enter> {
//
// This is hidden for a reason. Do not use without fully understanding
// executors. Misuing can easily cause your program to deadlock.
-#[cfg(all(feature = "rt-threaded", feature = "blocking"))]
-pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R {
- // Reset in case the closure panics
- struct Reset(EnterContext);
- impl Drop for Reset {
- fn drop(&mut self) {
- ENTERED.with(|c| {
- assert!(!c.get().is_entered(), "closure claimed permanent executor");
- c.set(self.0);
- });
+cfg_rt_multi_thread! {
+ pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R {
+ // Reset in case the closure panics
+ struct Reset(EnterContext);
+ impl Drop for Reset {
+ fn drop(&mut self) {
+ ENTERED.with(|c| {
+ assert!(!c.get().is_entered(), "closure claimed permanent executor");
+ c.set(self.0);
+ });
+ }
}
- }
- let was = ENTERED.with(|c| {
- let e = c.get();
- assert!(e.is_entered(), "asked to exit when not entered");
- c.set(EnterContext::NotEntered);
- e
- });
+ let was = ENTERED.with(|c| {
+ let e = c.get();
+ assert!(e.is_entered(), "asked to exit when not entered");
+ c.set(EnterContext::NotEntered);
+ e
+ });
- let _reset = Reset(was);
- // dropping _reset after f() will reset ENTERED
- f()
+ let _reset = Reset(was);
+ // dropping _reset after f() will reset ENTERED
+ f()
+ }
}
-cfg_rt_core! {
- cfg_rt_util! {
- /// Disallow blocking in the current runtime context until the guard is dropped.
- pub(crate) fn disallow_blocking() -> DisallowBlockingGuard {
- let reset = ENTERED.with(|c| {
- if let EnterContext::Entered {
- allow_blocking: true,
- } = c.get()
- {
- c.set(EnterContext::Entered {
- allow_blocking: false,
- });
- true
- } else {
- false
- }
- });
- DisallowBlockingGuard(reset)
- }
+cfg_rt! {
+ /// Disallow blocking in the current runtime context until the guard is dropped.
+ pub(crate) fn disallow_blocking() -> DisallowBlockingGuard {
+ let reset = ENTERED.with(|c| {
+ if let EnterContext::Entered {
+ allow_blocking: true,
+ } = c.get()
+ {
+ c.set(EnterContext::Entered {
+ allow_blocking: false,
+ });
+ true
+ } else {
+ false
+ }
+ });
+ DisallowBlockingGuard(reset)
+ }
- pub(crate) struct DisallowBlockingGuard(bool);
- impl Drop for DisallowBlockingGuard {
- fn drop(&mut self) {
- if self.0 {
- // XXX: Do we want some kind of assertion here, or is "best effort" okay?
- ENTERED.with(|c| {
- if let EnterContext::Entered {
- allow_blocking: false,
- } = c.get()
- {
- c.set(EnterContext::Entered {
- allow_blocking: true,
- });
- }
- })
- }
+ pub(crate) struct DisallowBlockingGuard(bool);
+ impl Drop for DisallowBlockingGuard {
+ fn drop(&mut self) {
+ if self.0 {
+ // XXX: Do we want some kind of assertion here, or is "best effort" okay?
+ ENTERED.with(|c| {
+ if let EnterContext::Entered {
+ allow_blocking: false,
+ } = c.get()
+ {
+ c.set(EnterContext::Entered {
+ allow_blocking: true,
+ });
+ }
+ })
}
}
}
}
-cfg_rt_threaded! {
- cfg_blocking! {
- /// Returns true if in a runtime context.
- pub(crate) fn context() -> EnterContext {
- ENTERED.with(|c| c.get())
- }
+cfg_rt_multi_thread! {
+ /// Returns true if in a runtime context.
+ pub(crate) fn context() -> EnterContext {
+ ENTERED.with(|c| c.get())
}
}
-cfg_block_on! {
+cfg_rt! {
impl Enter {
/// Blocks the thread on the specified future, returning the value with
/// which that future completes.
- pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, crate::park::ParkError>
+ pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, ParkError>
where
F: std::future::Future,
{
- use crate::park::{CachedParkThread, Park};
- use std::task::Context;
- use std::task::Poll::Ready;
+ use crate::park::thread::CachedParkThread;
let mut park = CachedParkThread::new();
- let waker = park.get_unpark()?.into_waker();
- let mut cx = Context::from_waker(&waker);
-
- pin!(f);
-
- loop {
- if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
- return Ok(v);
- }
-
- park.park()?;
- }
+ park.block_on(f)
}
- }
-}
-cfg_blocking_impl! {
- use crate::park::ParkError;
- use std::time::Duration;
-
- impl Enter {
/// Blocks the thread on the specified future for **at most** `timeout`
///
/// If the future completes before `timeout`, the result is returned. If
@@ -180,7 +159,8 @@ cfg_blocking_impl! {
where
F: std::future::Future,
{
- use crate::park::{CachedParkThread, Park};
+ use crate::park::Park;
+ use crate::park::thread::CachedParkThread;
use std::task::Context;
use std::task::Poll::Ready;
use std::time::Instant;
diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs
index 0716a7f..b1e8d8f 100644
--- a/src/runtime/handle.rs
+++ b/src/runtime/handle.rs
@@ -1,16 +1,6 @@
-use crate::runtime::{blocking, context, io, time, Spawner};
-use std::{error, fmt};
-
-cfg_blocking! {
- use crate::runtime::task;
- use crate::runtime::blocking::task::BlockingTask;
-}
-
-cfg_rt_core! {
- use crate::task::JoinHandle;
-
- use std::future::Future;
-}
+use crate::runtime::blocking::task::BlockingTask;
+use crate::runtime::task::{self, JoinHandle};
+use crate::runtime::{blocking, driver, Spawner};
/// Handle to the runtime.
///
@@ -19,353 +9,56 @@ cfg_rt_core! {
///
/// [`Runtime::handle`]: crate::runtime::Runtime::handle()
#[derive(Debug, Clone)]
-pub struct Handle {
+pub(crate) struct Handle {
pub(super) spawner: Spawner,
/// Handles to the I/O drivers
- pub(super) io_handle: io::Handle,
+ pub(super) io_handle: driver::IoHandle,
+
+ /// Handles to the signal drivers
+ pub(super) signal_handle: driver::SignalHandle,
/// Handles to the time drivers
- pub(super) time_handle: time::Handle,
+ pub(super) time_handle: driver::TimeHandle,
/// Source of `Instant::now()`
- pub(super) clock: time::Clock,
+ pub(super) clock: driver::Clock,
/// Blocking pool spawner
pub(super) blocking_spawner: blocking::Spawner,
}
impl Handle {
- /// Enter the runtime context. This allows you to construct types that must
- /// have an executor available on creation such as [`Delay`] or [`TcpStream`].
- /// It will also allow you to call methods such as [`tokio::spawn`].
- ///
- /// This function is also available as [`Runtime::enter`].
- ///
- /// [`Delay`]: struct@crate::time::Delay
- /// [`TcpStream`]: struct@crate::net::TcpStream
- /// [`Runtime::enter`]: fn@crate::runtime::Runtime::enter
- /// [`tokio::spawn`]: fn@crate::spawn
- ///
- /// # Example
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// fn function_that_spawns(msg: String) {
- /// // Had we not used `handle.enter` below, this would panic.
- /// tokio::spawn(async move {
- /// println!("{}", msg);
- /// });
- /// }
- ///
- /// fn main() {
- /// let rt = Runtime::new().unwrap();
- /// let handle = rt.handle().clone();
- ///
- /// let s = "Hello World!".to_string();
- ///
- /// // By entering the context, we tie `tokio::spawn` to this executor.
- /// handle.enter(|| function_that_spawns(s));
- /// }
- /// ```
- pub fn enter<F, R>(&self, f: F) -> R
+ // /// Enter the runtime context. This allows you to construct types that must
+ // /// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
+ // /// It will also allow you to call methods such as [`tokio::spawn`].
+ // pub(crate) fn enter<F, R>(&self, f: F) -> R
+ // where
+ // F: FnOnce() -> R,
+ // {
+ // context::enter(self.clone(), f)
+ // }
+
+ /// Run the provided function on an executor dedicated to blocking operations.
+ pub(crate) fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
where
- F: FnOnce() -> R,
+ F: FnOnce() -> R + Send + 'static,
{
- context::enter(self.clone(), f)
- }
-
- /// Returns a `Handle` view over the currently running `Runtime`
- ///
- /// # Panic
- ///
- /// This will panic if called outside the context of a Tokio runtime. That means that you must
- /// call this on one of the threads **being run by the runtime**. Calling this from within a
- /// thread created by `std::thread::spawn` (for example) will cause a panic.
- ///
- /// # Examples
- ///
- /// This can be used to obtain the handle of the surrounding runtime from an async
- /// block or function running on that runtime.
- ///
- /// ```
- /// # use std::thread;
- /// # use tokio::runtime::Runtime;
- /// # fn dox() {
- /// # let rt = Runtime::new().unwrap();
- /// # rt.spawn(async {
- /// use tokio::runtime::Handle;
- ///
- /// // Inside an async block or function.
- /// let handle = Handle::current();
- /// handle.spawn(async {
- /// println!("now running in the existing Runtime");
- /// });
- ///
- /// # let handle =
- /// thread::spawn(move || {
- /// // Notice that the handle is created outside of this thread and then moved in
- /// handle.block_on(async { /* ... */ })
- /// // This next line would cause a panic
- /// // let handle2 = Handle::current();
- /// });
- /// # handle.join().unwrap();
- /// # });
- /// # }
- /// ```
- pub fn current() -> Self {
- context::current().expect("not currently running on the Tokio runtime.")
- }
-
- /// Returns a Handle view over the currently running Runtime
- ///
- /// Returns an error if no Runtime has been started
- ///
- /// Contrary to `current`, this never panics
- pub fn try_current() -> Result<Self, TryCurrentError> {
- context::current().ok_or(TryCurrentError(()))
+ #[cfg(feature = "tracing")]
+ let func = {
+ let span = tracing::trace_span!(
+ target: "tokio::task",
+ "task",
+ kind = %"blocking",
+ function = %std::any::type_name::<F>(),
+ );
+ move || {
+ let _g = span.enter();
+ func()
+ }
+ };
+ let (task, handle) = task::joinable(BlockingTask::new(func));
+ let _ = self.blocking_spawner.spawn(task, &self);
+ handle
}
}
-
-cfg_rt_core! {
- impl Handle {
- /// Spawns a future onto the Tokio runtime.
- ///
- /// This spawns the given future onto the runtime's executor, usually a
- /// thread pool. The thread pool is then responsible for polling the future
- /// until it completes.
- ///
- /// See [module level][mod] documentation for more details.
- ///
- /// [mod]: index.html
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// # fn dox() {
- /// // Create the runtime
- /// let rt = Runtime::new().unwrap();
- /// let handle = rt.handle();
- ///
- /// // Spawn a future onto the runtime
- /// handle.spawn(async {
- /// println!("now running on a worker thread");
- /// });
- /// # }
- /// ```
- ///
- /// # Panics
- ///
- /// This function will not panic unless task execution is disabled on the
- /// executor. This can only happen if the runtime was built using
- /// [`Builder`] without picking either [`basic_scheduler`] or
- /// [`threaded_scheduler`].
- ///
- /// [`Builder`]: struct@crate::runtime::Builder
- /// [`threaded_scheduler`]: fn@crate::runtime::Builder::threaded_scheduler
- /// [`basic_scheduler`]: fn@crate::runtime::Builder::basic_scheduler
- pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
- where
- F: Future + Send + 'static,
- F::Output: Send + 'static,
- {
- self.spawner.spawn(future)
- }
-
- /// Run a future to completion on the Tokio runtime from a synchronous
- /// context.
- ///
- /// This runs the given future on the runtime, blocking until it is
- /// complete, and yielding its resolved result. Any tasks or timers which
- /// the future spawns internally will be executed on the runtime.
- ///
- /// If the provided executor currently has no active core thread, this
- /// function might hang until a core thread is added. This is not a
- /// concern when using the [threaded scheduler], as it always has active
- /// core threads, but if you use the [basic scheduler], some other
- /// thread must currently be inside a call to [`Runtime::block_on`].
- /// See also [the module level documentation][1], which has a section on
- /// scheduler types.
- ///
- /// This method may not be called from an asynchronous context.
- ///
- /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler
- /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler
- /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on
- /// [1]: index.html#runtime-configurations
- ///
- /// # Panics
- ///
- /// This function panics if the provided future panics, or if called
- /// within an asynchronous execution context.
- ///
- /// # Examples
- ///
- /// Using `block_on` with the [threaded scheduler].
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- /// use std::thread;
- ///
- /// // Create the runtime.
- /// //
- /// // If the rt-threaded feature is enabled, this creates a threaded
- /// // scheduler by default.
- /// let rt = Runtime::new().unwrap();
- /// let handle = rt.handle().clone();
- ///
- /// // Use the runtime from another thread.
- /// let th = thread::spawn(move || {
- /// // Execute the future, blocking the current thread until completion.
- /// //
- /// // This example uses the threaded scheduler, so no concurrent call to
- /// // `rt.block_on` is required.
- /// handle.block_on(async {
- /// println!("hello");
- /// });
- /// });
- ///
- /// th.join().unwrap();
- /// ```
- ///
- /// Using the [basic scheduler] requires a concurrent call to
- /// [`Runtime::block_on`]:
- ///
- /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler
- /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler
- /// [`Runtime::block_on`]: fn@crate::runtime::Runtime::block_on
- ///
- /// ```
- /// use tokio::runtime::Builder;
- /// use tokio::sync::oneshot;
- /// use std::thread;
- ///
- /// // Create the runtime.
- /// let mut rt = Builder::new()
- /// .enable_all()
- /// .basic_scheduler()
- /// .build()
- /// .unwrap();
- ///
- /// let handle = rt.handle().clone();
- ///
- /// // Signal main thread when task has finished.
- /// let (send, recv) = oneshot::channel();
- ///
- /// // Use the runtime from another thread.
- /// let th = thread::spawn(move || {
- /// // Execute the future, blocking the current thread until completion.
- /// handle.block_on(async {
- /// send.send("done").unwrap();
- /// });
- /// });
- ///
- /// // The basic scheduler is used, so the thread above might hang if we
- /// // didn't call block_on on the rt too.
- /// rt.block_on(async {
- /// assert_eq!(recv.await.unwrap(), "done");
- /// });
- /// # th.join().unwrap();
- /// ```
- ///
- pub fn block_on<F: Future>(&self, future: F) -> F::Output {
- self.enter(|| {
- let mut enter = crate::runtime::enter(true);
- enter.block_on(future).expect("failed to park thread")
- })
- }
- }
-}
-
-cfg_blocking! {
- impl Handle {
- /// Runs the provided closure on a thread where blocking is acceptable.
- ///
- /// In general, issuing a blocking call or performing a lot of compute in a
- /// future without yielding is not okay, as it may prevent the executor from
- /// driving other futures forward. This function runs the provided closure
- /// on a thread dedicated to blocking operations. See the [CPU-bound tasks
- /// and blocking code][blocking] section for more information.
- ///
- /// Tokio will spawn more blocking threads when they are requested through
- /// this function until the upper limit configured on the [`Builder`] is
- /// reached. This limit is very large by default, because `spawn_blocking` is
- /// often used for various kinds of IO operations that cannot be performed
- /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you
- /// should keep this large upper limit in mind; to run your CPU-bound
- /// computations on only a few threads, you should use a separate thread
- /// pool such as [rayon] rather than configuring the number of blocking
- /// threads.
- ///
- /// This function is intended for non-async operations that eventually
- /// finish on their own. If you want to spawn an ordinary thread, you should
- /// use [`thread::spawn`] instead.
- ///
- /// Closures spawned using `spawn_blocking` cannot be cancelled. When you
- /// shut down the executor, it will wait indefinitely for all blocking
- /// operations to finish. You can use [`shutdown_timeout`] to stop waiting
- /// for them after a certain timeout. Be aware that this will still not
- /// cancel the tasks — they are simply allowed to keep running after the
- /// method returns.
- ///
- /// Note that if you are using the [basic scheduler], this function will
- /// still spawn additional threads for blocking operations. The basic
- /// scheduler's single thread is only used for asynchronous code.
- ///
- /// [`Builder`]: struct@crate::runtime::Builder
- /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code
- /// [rayon]: https://docs.rs/rayon
- /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler
- /// [`thread::spawn`]: fn@std::thread::spawn
- /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{
- /// // Create the runtime
- /// let rt = Runtime::new().unwrap();
- /// let handle = rt.handle();
- ///
- /// let res = handle.spawn_blocking(move || {
- /// // do some compute-heavy work or call synchronous code
- /// "done computing"
- /// }).await?;
- ///
- /// assert_eq!(res, "done computing");
- /// # Ok(())
- /// # }
- /// ```
- pub fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
- where
- F: FnOnce() -> R + Send + 'static,
- R: Send + 'static,
- {
- let (task, handle) = task::joinable(BlockingTask::new(f));
- let _ = self.blocking_spawner.spawn(task, self);
- handle
- }
- }
-}
-
-/// Error returned by `try_current` when no Runtime has been started
-pub struct TryCurrentError(());
-
-impl fmt::Debug for TryCurrentError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("TryCurrentError").finish()
- }
-}
-
-impl fmt::Display for TryCurrentError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.write_str("no tokio Runtime has been initialized")
- }
-}
-
-impl error::Error for TryCurrentError {}
diff --git a/src/runtime/io.rs b/src/runtime/io.rs
deleted file mode 100644
index 6a0953a..0000000
--- a/src/runtime/io.rs
+++ /dev/null
@@ -1,63 +0,0 @@
-//! Abstracts out the APIs necessary to `Runtime` for integrating the I/O
-//! driver. When the `time` feature flag is **not** enabled. These APIs are
-//! shells. This isolates the complexity of dealing with conditional
-//! compilation.
-
-/// Re-exported for convenience.
-pub(crate) use std::io::Result;
-
-pub(crate) use variant::*;
-
-#[cfg(feature = "io-driver")]
-mod variant {
- use crate::io::driver;
- use crate::park::{Either, ParkThread};
-
- use std::io;
-
- /// The driver value the runtime passes to the `timer` layer.
- ///
- /// When the `io-driver` feature is enabled, this is the "real" I/O driver
- /// backed by Mio. Without the `io-driver` feature, this is a thread parker
- /// backed by a condition variable.
- pub(crate) type Driver = Either<driver::Driver, ParkThread>;
-
- /// The handle the runtime stores for future use.
- ///
- /// When the `io-driver` feature is **not** enabled, this is `()`.
- pub(crate) type Handle = Option<driver::Handle>;
-
- pub(crate) fn create_driver(enable: bool) -> io::Result<(Driver, Handle)> {
- #[cfg(loom)]
- assert!(!enable);
-
- if enable {
- let driver = driver::Driver::new()?;
- let handle = driver.handle();
-
- Ok((Either::A(driver), Some(handle)))
- } else {
- let driver = ParkThread::new();
- Ok((Either::B(driver), None))
- }
- }
-}
-
-#[cfg(not(feature = "io-driver"))]
-mod variant {
- use crate::park::ParkThread;
-
- use std::io;
-
- /// I/O is not enabled, use a condition variable based parker
- pub(crate) type Driver = ParkThread;
-
- /// There is no handle
- pub(crate) type Handle = ();
-
- pub(crate) fn create_driver(_enable: bool) -> io::Result<(Driver, Handle)> {
- let driver = ParkThread::new();
-
- Ok((driver, ()))
- }
-}
diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs
index 300a146..be4aa38 100644
--- a/src/runtime/mod.rs
+++ b/src/runtime/mod.rs
@@ -1,8 +1,7 @@
//! The Tokio runtime.
//!
-//! Unlike other Rust programs, asynchronous applications require
-//! runtime support. In particular, the following runtime services are
-//! necessary:
+//! Unlike other Rust programs, asynchronous applications require runtime
+//! support. In particular, the following runtime services are necessary:
//!
//! * An **I/O event loop**, called the driver, which drives I/O resources and
//! dispatches I/O events to tasks that depend on them.
@@ -10,14 +9,14 @@
//! * A **timer** for scheduling work to run after a set period of time.
//!
//! Tokio's [`Runtime`] bundles all of these services as a single type, allowing
-//! them to be started, shut down, and configured together. However, most
-//! applications won't need to use [`Runtime`] directly. Instead, they can
-//! use the [`tokio::main`] attribute macro, which creates a [`Runtime`] under
-//! the hood.
+//! them to be started, shut down, and configured together. However, often it is
+//! not required to configure a [`Runtime`] manually, and user may just use the
+//! [`tokio::main`] attribute macro, which creates a [`Runtime`] under the hood.
//!
//! # Usage
//!
-//! Most applications will use the [`tokio::main`] attribute macro.
+//! When no fine tuning is required, the [`tokio::main`] attribute macro can be
+//! used.
//!
//! ```no_run
//! use tokio::net::TcpListener;
@@ -25,7 +24,7 @@
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
@@ -69,11 +68,11 @@
//!
//! fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create the runtime
-//! let mut rt = Runtime::new()?;
+//! let rt = Runtime::new()?;
//!
//! // Spawn the root task
//! rt.block_on(async {
-//! let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+//! let listener = TcpListener::bind("127.0.0.1:8080").await?;
//!
//! loop {
//! let (mut socket, _) = listener.accept().await?;
@@ -111,48 +110,38 @@
//! applications. The [runtime builder] or `#[tokio::main]` attribute may be
//! used to select which scheduler to use.
//!
-//! #### Basic Scheduler
+//! #### Multi-Thread Scheduler
//!
-//! The basic scheduler provides a _single-threaded_ future executor. All tasks
-//! will be created and executed on the current thread. The basic scheduler
-//! requires the `rt-core` feature flag, and can be selected using the
-//! [`Builder::basic_scheduler`] method:
+//! The multi-thread scheduler executes futures on a _thread pool_, using a
+//! work-stealing strategy. By default, it will start a worker thread for each
+//! CPU core available on the system. This tends to be the ideal configurations
+//! for most applications. The multi-thread scheduler requires the `rt-multi-thread`
+//! feature flag, and is selected by default:
//! ```
//! use tokio::runtime;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let basic_rt = runtime::Builder::new()
-//! .basic_scheduler()
-//! .build()?;
+//! let threaded_rt = runtime::Runtime::new()?;
//! # Ok(()) }
//! ```
//!
-//! If the `rt-core` feature is enabled and `rt-threaded` is not,
-//! [`Runtime::new`] will return a basic scheduler runtime by default.
+//! Most applications should use the multi-thread scheduler, except in some
+//! niche use-cases, such as when running only a single thread is required.
//!
-//! #### Threaded Scheduler
+//! #### Current-Thread Scheduler
//!
-//! The threaded scheduler executes futures on a _thread pool_, using a
-//! work-stealing strategy. By default, it will start a worker thread for each
-//! CPU core available on the system. This tends to be the ideal configurations
-//! for most applications. The threaded scheduler requires the `rt-threaded` feature
-//! flag, and can be selected using the [`Builder::threaded_scheduler`] method:
+//! The current-thread scheduler provides a _single-threaded_ future executor.
+//! All tasks will be created and executed on the current thread. This requires
+//! the `rt` feature flag.
//! ```
//! use tokio::runtime;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
-//! let threaded_rt = runtime::Builder::new()
-//! .threaded_scheduler()
+//! let basic_rt = runtime::Builder::new_current_thread()
//! .build()?;
//! # Ok(()) }
//! ```
//!
-//! If the `rt-threaded` feature flag is enabled, [`Runtime::new`] will return a
-//! threaded scheduler runtime by default.
-//!
-//! Most applications should use the threaded scheduler, except in some niche
-//! use-cases, such as when running only a single thread is required.
-//!
//! #### Resource drivers
//!
//! When configuring a runtime by hand, no resource drivers are enabled by
@@ -164,8 +153,8 @@
//! ## Lifetime of spawned threads
//!
//! The runtime may spawn threads depending on its configuration and usage. The
-//! threaded scheduler spawns threads to schedule tasks and calls to
-//! `spawn_blocking` spawn threads to run blocking operations.
+//! multi-thread scheduler spawns threads to schedule tasks and for `spawn_blocking`
+//! calls.
//!
//! While the `Runtime` is active, threads may shutdown after periods of being
//! idle. Once `Runtime` is dropped, all runtime threads are forcibly shutdown.
@@ -188,394 +177,380 @@
#[macro_use]
mod tests;
-pub(crate) mod context;
+pub(crate) mod enter;
+
+pub(crate) mod task;
-cfg_rt_core! {
+cfg_rt! {
mod basic_scheduler;
use basic_scheduler::BasicScheduler;
- pub(crate) mod task;
-}
-
-mod blocking;
-use blocking::BlockingPool;
+ mod blocking;
+ use blocking::BlockingPool;
+ pub(crate) use blocking::spawn_blocking;
-cfg_blocking_impl! {
- #[allow(unused_imports)]
- pub(crate) use blocking::{spawn_blocking, try_spawn_blocking};
-}
+ mod builder;
+ pub use self::builder::Builder;
-mod builder;
-pub use self::builder::Builder;
+ pub(crate) mod context;
+ pub(crate) mod driver;
-pub(crate) mod enter;
-use self::enter::enter;
+ use self::enter::enter;
-mod handle;
-pub use self::handle::{Handle, TryCurrentError};
+ mod handle;
+ use handle::Handle;
-mod io;
+ mod spawner;
+ use self::spawner::Spawner;
+}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
mod park;
use park::Parker;
}
-mod shell;
-use self::shell::Shell;
-
-mod spawner;
-use self::spawner::Spawner;
-
-mod time;
-
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
mod queue;
pub(crate) mod thread_pool;
use self::thread_pool::ThreadPool;
}
-cfg_rt_core! {
+cfg_rt! {
use crate::task::JoinHandle;
-}
-
-use std::future::Future;
-use std::time::Duration;
-
-/// The Tokio runtime.
-///
-/// The runtime provides an I/O driver, task scheduler, [timer], and blocking
-/// pool, necessary for running asynchronous tasks.
-///
-/// Instances of `Runtime` can be created using [`new`] or [`Builder`]. However,
-/// most users will use the `#[tokio::main]` annotation on their entry point instead.
-///
-/// See [module level][mod] documentation for more details.
-///
-/// # Shutdown
-///
-/// Shutting down the runtime is done by dropping the value. The current thread
-/// will block until the shut down operation has completed.
-///
-/// * Drain any scheduled work queues.
-/// * Drop any futures that have not yet completed.
-/// * Drop the reactor.
-///
-/// Once the reactor has dropped, any outstanding I/O resources bound to
-/// that reactor will no longer function. Calling any method on them will
-/// result in an error.
-///
-/// [timer]: crate::time
-/// [mod]: index.html
-/// [`new`]: method@Self::new
-/// [`Builder`]: struct@Builder
-/// [`tokio::run`]: fn@run
-#[derive(Debug)]
-pub struct Runtime {
- /// Task executor
- kind: Kind,
-
- /// Handle to runtime, also contains driver handles
- handle: Handle,
-
- /// Blocking pool handle, used to signal shutdown
- blocking_pool: BlockingPool,
-}
-
-/// The runtime executor is either a thread-pool or a current-thread executor.
-#[derive(Debug)]
-enum Kind {
- /// Not able to execute concurrent tasks. This variant is mostly used to get
- /// access to the driver handles.
- Shell(Shell),
- /// Execute all tasks on the current-thread.
- #[cfg(feature = "rt-core")]
- Basic(BasicScheduler<time::Driver>),
+ use std::future::Future;
+ use std::time::Duration;
- /// Execute tasks across multiple threads.
- #[cfg(feature = "rt-threaded")]
- ThreadPool(ThreadPool),
-}
-
-/// After thread starts / before thread stops
-type Callback = std::sync::Arc<dyn Fn() + Send + Sync>;
-
-impl Runtime {
- /// Create a new runtime instance with default configuration values.
+ /// The Tokio runtime.
///
- /// This results in a scheduler, I/O driver, and time driver being
- /// initialized. The type of scheduler used depends on what feature flags
- /// are enabled: if the `rt-threaded` feature is enabled, the [threaded
- /// scheduler] is used, while if only the `rt-core` feature is enabled, the
- /// [basic scheduler] is used instead.
+ /// The runtime provides an I/O driver, task scheduler, [timer], and
+ /// blocking pool, necessary for running asynchronous tasks.
///
- /// If the threaded scheduler is selected, it will not spawn
- /// any worker threads until it needs to, i.e. tasks are scheduled to run.
- ///
- /// Most applications will not need to call this function directly. Instead,
- /// they will use the [`#[tokio::main]` attribute][main]. When more complex
- /// configuration is necessary, the [runtime builder] may be used.
+ /// Instances of `Runtime` can be created using [`new`], or [`Builder`].
+ /// However, most users will use the `#[tokio::main]` annotation on their
+ /// entry point instead.
///
/// See [module level][mod] documentation for more details.
///
- /// # Examples
- ///
- /// Creating a new `Runtime` with default configuration values.
+ /// # Shutdown
///
- /// ```
- /// use tokio::runtime::Runtime;
+ /// Shutting down the runtime is done by dropping the value. The current
+ /// thread will block until the shut down operation has completed.
///
- /// let rt = Runtime::new()
- /// .unwrap();
+ /// * Drain any scheduled work queues.
+ /// * Drop any futures that have not yet completed.
+ /// * Drop the reactor.
///
- /// // Use the runtime...
- /// ```
+ /// Once the reactor has dropped, any outstanding I/O resources bound to
+ /// that reactor will no longer function. Calling any method on them will
+ /// result in an error.
///
- /// [mod]: index.html
- /// [main]: ../attr.main.html
- /// [threaded scheduler]: index.html#threaded-scheduler
- /// [basic scheduler]: index.html#basic-scheduler
- /// [runtime builder]: crate::runtime::Builder
- pub fn new() -> io::Result<Runtime> {
- #[cfg(feature = "rt-threaded")]
- let ret = Builder::new().threaded_scheduler().enable_all().build();
-
- #[cfg(all(not(feature = "rt-threaded"), feature = "rt-core"))]
- let ret = Builder::new().basic_scheduler().enable_all().build();
-
- #[cfg(not(feature = "rt-core"))]
- let ret = Builder::new().enable_all().build();
-
- ret
- }
-
- /// Spawn a future onto the Tokio runtime.
+ /// # Sharing
///
- /// This spawns the given future onto the runtime's executor, usually a
- /// thread pool. The thread pool is then responsible for polling the future
- /// until it completes.
+ /// The Tokio runtime implements `Sync` and `Send` to allow you to wrap it
+ /// in a `Arc`. Most fn take `&self` to allow you to call them concurrently
+ /// accross multiple threads.
///
- /// See [module level][mod] documentation for more details.
+ /// Calls to `shutdown` and `shutdown_timeout` require exclusive ownership of
+ /// the runtime type and this can be achieved via `Arc::try_unwrap` when only
+ /// one strong count reference is left over.
///
+ /// [timer]: crate::time
/// [mod]: index.html
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// # fn dox() {
- /// // Create the runtime
- /// let rt = Runtime::new().unwrap();
- ///
- /// // Spawn a future onto the runtime
- /// rt.spawn(async {
- /// println!("now running on a worker thread");
- /// });
- /// # }
- /// ```
- ///
- /// # Panics
- ///
- /// This function will not panic unless task execution is disabled on the
- /// executor. This can only happen if the runtime was built using
- /// [`Builder`] without picking either [`basic_scheduler`] or
- /// [`threaded_scheduler`].
- ///
+ /// [`new`]: method@Self::new
/// [`Builder`]: struct@Builder
- /// [`threaded_scheduler`]: fn@Builder::threaded_scheduler
- /// [`basic_scheduler`]: fn@Builder::basic_scheduler
- #[cfg(feature = "rt-core")]
- pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
- where
- F: Future + Send + 'static,
- F::Output: Send + 'static,
- {
- match &self.kind {
- Kind::Shell(_) => panic!("task execution disabled"),
- #[cfg(feature = "rt-threaded")]
- Kind::ThreadPool(exec) => exec.spawn(future),
- Kind::Basic(exec) => exec.spawn(future),
- }
- }
+ #[derive(Debug)]
+ pub struct Runtime {
+ /// Task executor
+ kind: Kind,
- /// Run a future to completion on the Tokio runtime. This is the runtime's
- /// entry point.
- ///
- /// This runs the given future on the runtime, blocking until it is
- /// complete, and yielding its resolved result. Any tasks or timers which
- /// the future spawns internally will be executed on the runtime.
- ///
- /// `&mut` is required as calling `block_on` **may** result in advancing the
- /// state of the runtime. The details depend on how the runtime is
- /// configured. [`runtime::Handle::block_on`][handle] provides a version
- /// that takes `&self`.
- ///
- /// This method may not be called from an asynchronous context.
- ///
- /// # Panics
- ///
- /// This function panics if the provided future panics, or if called within an
- /// asynchronous execution context.
- ///
- /// # Examples
- ///
- /// ```no_run
- /// use tokio::runtime::Runtime;
- ///
- /// // Create the runtime
- /// let mut rt = Runtime::new().unwrap();
- ///
- /// // Execute the future, blocking the current thread until completion
- /// rt.block_on(async {
- /// println!("hello");
- /// });
- /// ```
- ///
- /// [handle]: fn@Handle::block_on
- pub fn block_on<F: Future>(&mut self, future: F) -> F::Output {
- let kind = &mut self.kind;
+ /// Handle to runtime, also contains driver handles
+ handle: Handle,
- self.handle.enter(|| match kind {
- Kind::Shell(exec) => exec.block_on(future),
- #[cfg(feature = "rt-core")]
- Kind::Basic(exec) => exec.block_on(future),
- #[cfg(feature = "rt-threaded")]
- Kind::ThreadPool(exec) => exec.block_on(future),
- })
+ /// Blocking pool handle, used to signal shutdown
+ blocking_pool: BlockingPool,
}
- /// Enter the runtime context. This allows you to construct types that must
- /// have an executor available on creation such as [`Delay`] or [`TcpStream`].
- /// It will also allow you to call methods such as [`tokio::spawn`].
- ///
- /// This function is also available as [`Handle::enter`].
- ///
- /// [`Delay`]: struct@crate::time::Delay
- /// [`TcpStream`]: struct@crate::net::TcpStream
- /// [`Handle::enter`]: fn@crate::runtime::Handle::enter
- /// [`tokio::spawn`]: fn@crate::spawn
- ///
- /// # Example
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// fn function_that_spawns(msg: String) {
- /// // Had we not used `rt.enter` below, this would panic.
- /// tokio::spawn(async move {
- /// println!("{}", msg);
- /// });
- /// }
+ /// Runtime context guard.
///
- /// fn main() {
- /// let rt = Runtime::new().unwrap();
- ///
- /// let s = "Hello World!".to_string();
- ///
- /// // By entering the context, we tie `tokio::spawn` to this executor.
- /// rt.enter(|| function_that_spawns(s));
- /// }
- /// ```
- pub fn enter<F, R>(&self, f: F) -> R
- where
- F: FnOnce() -> R,
- {
- self.handle.enter(f)
+ /// Returned by [`Runtime::enter`], the context guard exits the runtime
+ /// context on drop.
+ #[derive(Debug)]
+ pub struct EnterGuard<'a> {
+ rt: &'a Runtime,
+ guard: context::EnterGuard,
}
- /// Return a handle to the runtime's spawner.
- ///
- /// The returned handle can be used to spawn tasks that run on this runtime, and can
- /// be cloned to allow moving the `Handle` to other threads.
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// let rt = Runtime::new()
- /// .unwrap();
- ///
- /// let handle = rt.handle();
- ///
- /// handle.spawn(async { println!("hello"); });
- /// ```
- pub fn handle(&self) -> &Handle {
- &self.handle
- }
+ /// The runtime executor is either a thread-pool or a current-thread executor.
+ #[derive(Debug)]
+ enum Kind {
+ /// Execute all tasks on the current-thread.
+ CurrentThread(BasicScheduler<driver::Driver>),
- /// Shutdown the runtime, waiting for at most `duration` for all spawned
- /// task to shutdown.
- ///
- /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to
- /// shutdown in a timely fashion. However, dropping a `Runtime` will wait
- /// indefinitely for all tasks to terminate, and there are cases where a long
- /// blocking task has been spawned, which can block dropping `Runtime`.
- ///
- /// In this case, calling `shutdown_timeout` with an explicit wait timeout
- /// can work. The `shutdown_timeout` will signal all tasks to shutdown and
- /// will wait for at most `duration` for all spawned tasks to terminate. If
- /// `timeout` elapses before all tasks are dropped, the function returns and
- /// outstanding tasks are potentially leaked.
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- /// use tokio::task;
- ///
- /// use std::thread;
- /// use std::time::Duration;
- ///
- /// fn main() {
- /// let mut runtime = Runtime::new().unwrap();
- ///
- /// runtime.block_on(async move {
- /// task::spawn_blocking(move || {
- /// thread::sleep(Duration::from_secs(10_000));
- /// });
- /// });
- ///
- /// runtime.shutdown_timeout(Duration::from_millis(100));
- /// }
- /// ```
- pub fn shutdown_timeout(self, duration: Duration) {
- let Runtime {
- mut blocking_pool, ..
- } = self;
- blocking_pool.shutdown(Some(duration));
+ /// Execute tasks across multiple threads.
+ #[cfg(feature = "rt-multi-thread")]
+ ThreadPool(ThreadPool),
}
- /// Shutdown the runtime, without waiting for any spawned tasks to shutdown.
- ///
- /// This can be useful if you want to drop a runtime from within another runtime.
- /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks
- /// to complete, which would normally not be permitted within an asynchronous context.
- /// By calling `shutdown_background()`, you can drop the runtime from such a context.
- ///
- /// Note however, that because we do not wait for any blocking tasks to complete, this
- /// may result in a resource leak (in that any blocking tasks are still running until they
- /// return.
- ///
- /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`.
- ///
- /// ```
- /// use tokio::runtime::Runtime;
- ///
- /// fn main() {
- /// let mut runtime = Runtime::new().unwrap();
- ///
- /// runtime.block_on(async move {
- /// let inner_runtime = Runtime::new().unwrap();
- /// // ...
- /// inner_runtime.shutdown_background();
- /// });
- /// }
- /// ```
- pub fn shutdown_background(self) {
- self.shutdown_timeout(Duration::from_nanos(0))
+ /// After thread starts / before thread stops
+ type Callback = std::sync::Arc<dyn Fn() + Send + Sync>;
+
+ impl Runtime {
+ /// Create a new runtime instance with default configuration values.
+ ///
+ /// This results in the multi threaded scheduler, I/O driver, and time driver being
+ /// initialized.
+ ///
+ /// Most applications will not need to call this function directly. Instead,
+ /// they will use the [`#[tokio::main]` attribute][main]. When a more complex
+ /// configuration is necessary, the [runtime builder] may be used.
+ ///
+ /// See [module level][mod] documentation for more details.
+ ///
+ /// # Examples
+ ///
+ /// Creating a new `Runtime` with default configuration values.
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ ///
+ /// let rt = Runtime::new()
+ /// .unwrap();
+ ///
+ /// // Use the runtime...
+ /// ```
+ ///
+ /// [mod]: index.html
+ /// [main]: ../attr.main.html
+ /// [threaded scheduler]: index.html#threaded-scheduler
+ /// [basic scheduler]: index.html#basic-scheduler
+ /// [runtime builder]: crate::runtime::Builder
+ #[cfg(feature = "rt-multi-thread")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt-multi-thread")))]
+ pub fn new() -> std::io::Result<Runtime> {
+ Builder::new_multi_thread().enable_all().build()
+ }
+
+ /// Spawn a future onto the Tokio runtime.
+ ///
+ /// This spawns the given future onto the runtime's executor, usually a
+ /// thread pool. The thread pool is then responsible for polling the future
+ /// until it completes.
+ ///
+ /// See [module level][mod] documentation for more details.
+ ///
+ /// [mod]: index.html
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ ///
+ /// # fn dox() {
+ /// // Create the runtime
+ /// let rt = Runtime::new().unwrap();
+ ///
+ /// // Spawn a future onto the runtime
+ /// rt.spawn(async {
+ /// println!("now running on a worker thread");
+ /// });
+ /// # }
+ /// ```
+ pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
+ where
+ F: Future + Send + 'static,
+ F::Output: Send + 'static,
+ {
+ match &self.kind {
+ #[cfg(feature = "rt-multi-thread")]
+ Kind::ThreadPool(exec) => exec.spawn(future),
+ Kind::CurrentThread(exec) => exec.spawn(future),
+ }
+ }
+
+ /// Run the provided function on an executor dedicated to blocking operations.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ ///
+ /// # fn dox() {
+ /// // Create the runtime
+ /// let rt = Runtime::new().unwrap();
+ ///
+ /// // Spawn a blocking function onto the runtime
+ /// rt.spawn_blocking(|| {
+ /// println!("now running on a worker thread");
+ /// });
+ /// # }
+ pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
+ where
+ F: FnOnce() -> R + Send + 'static,
+ {
+ self.handle.spawn_blocking(func)
+ }
+
+ /// Run a future to completion on the Tokio runtime. This is the
+ /// runtime's entry point.
+ ///
+ /// This runs the given future on the runtime, blocking until it is
+ /// complete, and yielding its resolved result. Any tasks or timers
+ /// which the future spawns internally will be executed on the runtime.
+ ///
+ /// # Multi thread scheduler
+ ///
+ /// When the multi thread scheduler is used this will allow futures
+ /// to run within the io driver and timer context of the overall runtime.
+ ///
+ /// # Current thread scheduler
+ ///
+ /// When the current thread scheduler is enabled `block_on`
+ /// can be called concurrently from multiple threads. The first call
+ /// will take ownership of the io and timer drivers. This means
+ /// other threads which do not own the drivers will hook into that one.
+ /// When the first `block_on` completes, other threads will be able to
+ /// "steal" the driver to allow continued execution of their futures.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if the provided future panics, or if not called within an
+ /// asynchronous execution context.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::runtime::Runtime;
+ ///
+ /// // Create the runtime
+ /// let rt = Runtime::new().unwrap();
+ ///
+ /// // Execute the future, blocking the current thread until completion
+ /// rt.block_on(async {
+ /// println!("hello");
+ /// });
+ /// ```
+ ///
+ /// [handle]: fn@Handle::block_on
+ pub fn block_on<F: Future>(&self, future: F) -> F::Output {
+ let _enter = self.enter();
+
+ match &self.kind {
+ Kind::CurrentThread(exec) => exec.block_on(future),
+ #[cfg(feature = "rt-multi-thread")]
+ Kind::ThreadPool(exec) => exec.block_on(future),
+ }
+ }
+
+ /// Enter the runtime context.
+ ///
+ /// This allows you to construct types that must have an executor
+ /// available on creation such as [`Sleep`] or [`TcpStream`]. It will
+ /// also allow you to call methods such as [`tokio::spawn`].
+ ///
+ /// [`Sleep`]: struct@crate::time::Sleep
+ /// [`TcpStream`]: struct@crate::net::TcpStream
+ /// [`tokio::spawn`]: fn@crate::spawn
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ ///
+ /// fn function_that_spawns(msg: String) {
+ /// // Had we not used `rt.enter` below, this would panic.
+ /// tokio::spawn(async move {
+ /// println!("{}", msg);
+ /// });
+ /// }
+ ///
+ /// fn main() {
+ /// let rt = Runtime::new().unwrap();
+ ///
+ /// let s = "Hello World!".to_string();
+ ///
+ /// // By entering the context, we tie `tokio::spawn` to this executor.
+ /// let _guard = rt.enter();
+ /// function_that_spawns(s);
+ /// }
+ /// ```
+ pub fn enter(&self) -> EnterGuard<'_> {
+ EnterGuard {
+ rt: self,
+ guard: context::enter(self.handle.clone()),
+ }
+ }
+
+ /// Shutdown the runtime, waiting for at most `duration` for all spawned
+ /// task to shutdown.
+ ///
+ /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to
+ /// shutdown in a timely fashion. However, dropping a `Runtime` will wait
+ /// indefinitely for all tasks to terminate, and there are cases where a long
+ /// blocking task has been spawned, which can block dropping `Runtime`.
+ ///
+ /// In this case, calling `shutdown_timeout` with an explicit wait timeout
+ /// can work. The `shutdown_timeout` will signal all tasks to shutdown and
+ /// will wait for at most `duration` for all spawned tasks to terminate. If
+ /// `timeout` elapses before all tasks are dropped, the function returns and
+ /// outstanding tasks are potentially leaked.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ /// use tokio::task;
+ ///
+ /// use std::thread;
+ /// use std::time::Duration;
+ ///
+ /// fn main() {
+ /// let runtime = Runtime::new().unwrap();
+ ///
+ /// runtime.block_on(async move {
+ /// task::spawn_blocking(move || {
+ /// thread::sleep(Duration::from_secs(10_000));
+ /// });
+ /// });
+ ///
+ /// runtime.shutdown_timeout(Duration::from_millis(100));
+ /// }
+ /// ```
+ pub fn shutdown_timeout(mut self, duration: Duration) {
+ // Wakeup and shutdown all the worker threads
+ self.handle.spawner.shutdown();
+ self.blocking_pool.shutdown(Some(duration));
+ }
+
+ /// Shutdown the runtime, without waiting for any spawned tasks to shutdown.
+ ///
+ /// This can be useful if you want to drop a runtime from within another runtime.
+ /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks
+ /// to complete, which would normally not be permitted within an asynchronous context.
+ /// By calling `shutdown_background()`, you can drop the runtime from such a context.
+ ///
+ /// Note however, that because we do not wait for any blocking tasks to complete, this
+ /// may result in a resource leak (in that any blocking tasks are still running until they
+ /// return.
+ ///
+ /// This function is equivalent to calling `shutdown_timeout(Duration::of_nanos(0))`.
+ ///
+ /// ```
+ /// use tokio::runtime::Runtime;
+ ///
+ /// fn main() {
+ /// let runtime = Runtime::new().unwrap();
+ ///
+ /// runtime.block_on(async move {
+ /// let inner_runtime = Runtime::new().unwrap();
+ /// // ...
+ /// inner_runtime.shutdown_background();
+ /// });
+ /// }
+ /// ```
+ pub fn shutdown_background(self) {
+ self.shutdown_timeout(Duration::from_nanos(0))
+ }
}
}
diff --git a/src/runtime/park.rs b/src/runtime/park.rs
index ee437d1..033b9f2 100644
--- a/src/runtime/park.rs
+++ b/src/runtime/park.rs
@@ -6,7 +6,7 @@ use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::loom::thread;
use crate::park::{Park, Unpark};
-use crate::runtime::time;
+use crate::runtime::driver::Driver;
use crate::util::TryLock;
use std::sync::atomic::Ordering::SeqCst;
@@ -42,14 +42,14 @@ const NOTIFIED: usize = 3;
/// Shared across multiple Parker handles
struct Shared {
/// Shared driver. Only one thread at a time can use this
- driver: TryLock<time::Driver>,
+ driver: TryLock<Driver>,
/// Unpark handle
- handle: <time::Driver as Park>::Unpark,
+ handle: <Driver as Park>::Unpark,
}
impl Parker {
- pub(crate) fn new(driver: time::Driver) -> Parker {
+ pub(crate) fn new(driver: Driver) -> Parker {
let handle = driver.unpark();
Parker {
@@ -104,6 +104,10 @@ impl Park for Parker {
Ok(())
}
}
+
+ fn shutdown(&mut self) {
+ self.inner.shutdown();
+ }
}
impl Unpark for Unparker {
@@ -138,7 +142,7 @@ impl Inner {
fn park_condvar(&self) {
// Otherwise we need to coordinate going to sleep
- let mut m = self.mutex.lock().unwrap();
+ let mut m = self.mutex.lock();
match self
.state
@@ -176,7 +180,7 @@ impl Inner {
}
}
- fn park_driver(&self, driver: &mut time::Driver) {
+ fn park_driver(&self, driver: &mut Driver) {
match self
.state
.compare_exchange(EMPTY, PARKED_DRIVER, SeqCst, SeqCst)
@@ -234,7 +238,7 @@ impl Inner {
// Releasing `lock` before the call to `notify_one` means that when the
// parked thread wakes it doesn't get woken only to have to wait for us
// to release `lock`.
- drop(self.mutex.lock().unwrap());
+ drop(self.mutex.lock());
self.condvar.notify_one()
}
@@ -242,4 +246,12 @@ impl Inner {
fn unpark_driver(&self) {
self.shared.handle.unpark();
}
+
+ fn shutdown(&self) {
+ if let Some(mut driver) = self.shared.driver.try_lock() {
+ driver.shutdown();
+ }
+
+ self.condvar.notify_all();
+ }
}
diff --git a/src/runtime/queue.rs b/src/runtime/queue.rs
index c654514..cdf4009 100644
--- a/src/runtime/queue.rs
+++ b/src/runtime/queue.rs
@@ -481,7 +481,7 @@ impl<T: 'static> Inject<T> {
/// Close the injection queue, returns `true` if the queue is open when the
/// transition is made.
pub(super) fn close(&self) -> bool {
- let mut p = self.pointers.lock().unwrap();
+ let mut p = self.pointers.lock();
if p.is_closed {
return false;
@@ -492,7 +492,7 @@ impl<T: 'static> Inject<T> {
}
pub(super) fn is_closed(&self) -> bool {
- self.pointers.lock().unwrap().is_closed
+ self.pointers.lock().is_closed
}
pub(super) fn len(&self) -> usize {
@@ -502,7 +502,7 @@ impl<T: 'static> Inject<T> {
/// Pushes a value into the queue.
pub(super) fn push(&self, task: task::Notified<T>) {
// Acquire queue lock
- let mut p = self.pointers.lock().unwrap();
+ let mut p = self.pointers.lock();
if p.is_closed {
// Drop the mutex to avoid a potential deadlock when
@@ -541,7 +541,7 @@ impl<T: 'static> Inject<T> {
debug_assert!(get_next(batch_tail).is_none());
- let mut p = self.pointers.lock().unwrap();
+ let mut p = self.pointers.lock();
if let Some(tail) = p.tail {
set_next(tail, Some(batch_head));
@@ -566,7 +566,7 @@ impl<T: 'static> Inject<T> {
return None;
}
- let mut p = self.pointers.lock().unwrap();
+ let mut p = self.pointers.lock();
// It is possible to hit null here if another thread poped the last
// task between us checking `len` and acquiring the lock.
diff --git a/src/runtime/shell.rs b/src/runtime/shell.rs
index a65869d..486d4fa 100644
--- a/src/runtime/shell.rs
+++ b/src/runtime/shell.rs
@@ -1,52 +1,84 @@
#![allow(clippy::redundant_clone)]
+use crate::future::poll_fn;
use crate::park::{Park, Unpark};
-use crate::runtime::enter;
-use crate::runtime::time;
+use crate::runtime::driver::Driver;
+use crate::sync::Notify;
use crate::util::{waker_ref, Wake};
-use std::future::Future;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
use std::task::Context;
-use std::task::Poll::Ready;
+use std::task::Poll::{Pending, Ready};
+use std::{future::Future, sync::PoisonError};
#[derive(Debug)]
pub(super) struct Shell {
- driver: time::Driver,
+ driver: Mutex<Option<Driver>>,
+
+ notify: Notify,
/// TODO: don't store this
unpark: Arc<Handle>,
}
#[derive(Debug)]
-struct Handle(<time::Driver as Park>::Unpark);
+struct Handle(<Driver as Park>::Unpark);
impl Shell {
- pub(super) fn new(driver: time::Driver) -> Shell {
+ pub(super) fn new(driver: Driver) -> Shell {
let unpark = Arc::new(Handle(driver.unpark()));
- Shell { driver, unpark }
+ Shell {
+ driver: Mutex::new(Some(driver)),
+ notify: Notify::new(),
+ unpark,
+ }
}
- pub(super) fn block_on<F>(&mut self, f: F) -> F::Output
+ pub(super) fn block_on<F>(&self, f: F) -> F::Output
where
F: Future,
{
- let _e = enter(true);
+ let mut enter = crate::runtime::enter(true);
pin!(f);
- let waker = waker_ref(&self.unpark);
- let mut cx = Context::from_waker(&waker);
-
loop {
- if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
- return v;
- }
+ if let Some(driver) = &mut self.take_driver() {
+ return driver.block_on(f);
+ } else {
+ let notified = self.notify.notified();
+ pin!(notified);
+
+ if let Some(out) = enter
+ .block_on(poll_fn(|cx| {
+ if notified.as_mut().poll(cx).is_ready() {
+ return Ready(None);
+ }
+
+ if let Ready(out) = f.as_mut().poll(cx) {
+ return Ready(Some(out));
+ }
- self.driver.park().unwrap();
+ Pending
+ }))
+ .expect("Failed to `Enter::block_on`")
+ {
+ return out;
+ }
+ }
}
}
+
+ fn take_driver(&self) -> Option<DriverGuard<'_>> {
+ let mut lock = self.driver.lock().unwrap();
+ let driver = lock.take()?;
+
+ Some(DriverGuard {
+ inner: Some(driver),
+ shell: &self,
+ })
+ }
}
impl Wake for Handle {
@@ -60,3 +92,41 @@ impl Wake for Handle {
arc_self.0.unpark();
}
}
+
+struct DriverGuard<'a> {
+ inner: Option<Driver>,
+ shell: &'a Shell,
+}
+
+impl DriverGuard<'_> {
+ fn block_on<F: Future>(&mut self, f: F) -> F::Output {
+ let driver = self.inner.as_mut().unwrap();
+
+ pin!(f);
+
+ let waker = waker_ref(&self.shell.unpark);
+ let mut cx = Context::from_waker(&waker);
+
+ loop {
+ if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
+ return v;
+ }
+
+ driver.park().unwrap();
+ }
+ }
+}
+
+impl Drop for DriverGuard<'_> {
+ fn drop(&mut self) {
+ if let Some(inner) = self.inner.take() {
+ self.shell
+ .driver
+ .lock()
+ .unwrap_or_else(PoisonError::into_inner)
+ .replace(inner);
+
+ self.shell.notify.notify_one();
+ }
+ }
+}
diff --git a/src/runtime/spawner.rs b/src/runtime/spawner.rs
index d136945..a37c667 100644
--- a/src/runtime/spawner.rs
+++ b/src/runtime/spawner.rs
@@ -1,24 +1,34 @@
-cfg_rt_core! {
+cfg_rt! {
use crate::runtime::basic_scheduler;
use crate::task::JoinHandle;
use std::future::Future;
}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
use crate::runtime::thread_pool;
}
#[derive(Debug, Clone)]
pub(crate) enum Spawner {
- Shell,
- #[cfg(feature = "rt-core")]
+ #[cfg(feature = "rt")]
Basic(basic_scheduler::Spawner),
- #[cfg(feature = "rt-threaded")]
+ #[cfg(feature = "rt-multi-thread")]
ThreadPool(thread_pool::Spawner),
}
-cfg_rt_core! {
+impl Spawner {
+ pub(crate) fn shutdown(&mut self) {
+ #[cfg(feature = "rt-multi-thread")]
+ {
+ if let Spawner::ThreadPool(spawner) = self {
+ spawner.shutdown();
+ }
+ }
+ }
+}
+
+cfg_rt! {
impl Spawner {
pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
@@ -26,10 +36,9 @@ cfg_rt_core! {
F::Output: Send + 'static,
{
match self {
- Spawner::Shell => panic!("spawning not enabled for runtime"),
- #[cfg(feature = "rt-core")]
+ #[cfg(feature = "rt")]
Spawner::Basic(spawner) => spawner.spawn(future),
- #[cfg(feature = "rt-threaded")]
+ #[cfg(feature = "rt-multi-thread")]
Spawner::ThreadPool(spawner) => spawner.spawn(future),
}
}
diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs
index f4756c2..dfa8764 100644
--- a/src/runtime/task/core.rs
+++ b/src/runtime/task/core.rs
@@ -269,7 +269,7 @@ impl<T: Future, S: Schedule> Core<T, S> {
}
}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
impl Header {
pub(crate) fn shutdown(&self) {
use crate::runtime::task::RawTask;
diff --git a/src/runtime/task/error.rs b/src/runtime/task/error.rs
index d5f65a4..177fe65 100644
--- a/src/runtime/task/error.rs
+++ b/src/runtime/task/error.rs
@@ -3,7 +3,7 @@ use std::fmt;
use std::io;
use std::sync::Mutex;
-doc_rt_core! {
+cfg_rt! {
/// Task failed to execute to completion.
pub struct JoinError {
repr: Repr,
@@ -16,25 +16,13 @@ enum Repr {
}
impl JoinError {
- #[doc(hidden)]
- #[deprecated]
- pub fn cancelled() -> JoinError {
- Self::cancelled2()
- }
-
- pub(crate) fn cancelled2() -> JoinError {
+ pub(crate) fn cancelled() -> JoinError {
JoinError {
repr: Repr::Cancelled,
}
}
- #[doc(hidden)]
- #[deprecated]
- pub fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError {
- Self::panic2(err)
- }
-
- pub(crate) fn panic2(err: Box<dyn Any + Send + 'static>) -> JoinError {
+ pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError {
JoinError {
repr: Repr::Panic(Mutex::new(err)),
}
@@ -42,10 +30,7 @@ impl JoinError {
/// Returns true if the error was caused by the task being cancelled
pub fn is_cancelled(&self) -> bool {
- match &self.repr {
- Repr::Cancelled => true,
- _ => false,
- }
+ matches!(&self.repr, Repr::Cancelled)
}
/// Returns true if the error was caused by the task panicking
@@ -65,10 +50,7 @@ impl JoinError {
/// }
/// ```
pub fn is_panic(&self) -> bool {
- match &self.repr {
- Repr::Panic(_) => true,
- _ => false,
- }
+ matches!(&self.repr, Repr::Panic(_))
}
/// Consumes the join error, returning the object with which the task panicked.
diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs
index e86b29e..208d48c 100644
--- a/src/runtime/task/harness.rs
+++ b/src/runtime/task/harness.rs
@@ -102,7 +102,7 @@ where
// If the task is cancelled, avoid polling it, instead signalling it
// is complete.
if snapshot.is_cancelled() {
- Poll::Ready(Err(JoinError::cancelled2()))
+ Poll::Ready(Err(JoinError::cancelled()))
} else {
let res = guard.core.poll(self.header());
@@ -132,7 +132,7 @@ where
}
}
Err(err) => {
- self.complete(Err(JoinError::panic2(err)), snapshot.is_join_interested());
+ self.complete(Err(JoinError::panic(err)), snapshot.is_join_interested());
}
}
}
@@ -297,9 +297,9 @@ where
// Dropping the future panicked, complete the join
// handle with the panic to avoid dropping the panic
// on the ground.
- self.complete(Err(JoinError::panic2(err)), true);
+ self.complete(Err(JoinError::panic(err)), true);
} else {
- self.complete(Err(JoinError::cancelled2()), true);
+ self.complete(Err(JoinError::cancelled()), true);
}
}
diff --git a/src/runtime/task/join.rs b/src/runtime/task/join.rs
index 3c4aabb..dedfb38 100644
--- a/src/runtime/task/join.rs
+++ b/src/runtime/task/join.rs
@@ -6,7 +6,7 @@ use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
-doc_rt_core! {
+cfg_rt! {
/// An owned permission to join on a task (await its termination).
///
/// This can be thought of as the equivalent of [`std::thread::JoinHandle`] for
@@ -45,6 +45,71 @@ doc_rt_core! {
/// # }
/// ```
///
+ /// The generic parameter `T` in `JoinHandle<T>` is the return type of the spawned task.
+ /// If the return value is an i32, the join handle has type `JoinHandle<i32>`:
+ ///
+ /// ```
+ /// use tokio::task;
+ ///
+ /// # async fn doc() {
+ /// let join_handle: task::JoinHandle<i32> = task::spawn(async {
+ /// 5 + 3
+ /// });
+ /// # }
+ ///
+ /// ```
+ ///
+ /// If the task does not have a return value, the join handle has type `JoinHandle<()>`:
+ ///
+ /// ```
+ /// use tokio::task;
+ ///
+ /// # async fn doc() {
+ /// let join_handle: task::JoinHandle<()> = task::spawn(async {
+ /// println!("I return nothing.");
+ /// });
+ /// # }
+ /// ```
+ ///
+ /// Note that `handle.await` doesn't give you the return type directly. It is wrapped in a
+ /// `Result` because panics in the spawned task are caught by Tokio. The `?` operator has
+ /// to be double chained to extract the returned value:
+ ///
+ /// ```
+ /// use tokio::task;
+ /// use std::io;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async {
+ /// Ok(5 + 3)
+ /// });
+ ///
+ /// let result = join_handle.await??;
+ /// assert_eq!(result, 8);
+ /// Ok(())
+ /// }
+ /// ```
+ ///
+ /// If the task panics, the error is a [`JoinError`] that contains the panic:
+ ///
+ /// ```
+ /// use tokio::task;
+ /// use std::io;
+ /// use std::panic;
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let join_handle: task::JoinHandle<Result<i32, io::Error>> = tokio::spawn(async {
+ /// panic!("boom");
+ /// });
+ ///
+ /// let err = join_handle.await.unwrap_err();
+ /// assert!(err.is_panic());
+ /// Ok(())
+ /// }
+ ///
+ /// ```
/// Child being detached and outliving its parent:
///
/// ```no_run
@@ -56,7 +121,7 @@ doc_rt_core! {
/// let original_task = task::spawn(async {
/// let _detached_task = task::spawn(async {
/// // Here we sleep to make sure that the first task returns before.
- /// time::delay_for(Duration::from_millis(10)).await;
+ /// time::sleep(Duration::from_millis(10)).await;
/// // This will be called, even though the JoinHandle is dropped.
/// println!("♫ Still alive ♫");
/// });
@@ -68,13 +133,14 @@ doc_rt_core! {
/// // We make sure that the new task has time to run, before the main
/// // task returns.
///
- /// time::delay_for(Duration::from_millis(1000)).await;
+ /// time::sleep(Duration::from_millis(1000)).await;
/// # }
/// ```
///
/// [`task::spawn`]: crate::task::spawn()
/// [`task::spawn_blocking`]: crate::task::spawn_blocking
/// [`std::thread::JoinHandle`]: std::thread::JoinHandle
+ /// [`JoinError`]: crate::task::JoinError
pub struct JoinHandle<T> {
raw: Option<RawTask>,
_p: PhantomData<T>,
@@ -91,6 +157,44 @@ impl<T> JoinHandle<T> {
_p: PhantomData,
}
}
+
+ /// Abort the task associated with the handle.
+ ///
+ /// Awaiting a cancelled task might complete as usual if the task was
+ /// already completed at the time it was cancelled, but most likely it
+ /// will complete with a `Err(JoinError::Cancelled)`.
+ ///
+ /// ```rust
+ /// use tokio::time;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let mut handles = Vec::new();
+ ///
+ /// handles.push(tokio::spawn(async {
+ /// time::sleep(time::Duration::from_secs(10)).await;
+ /// true
+ /// }));
+ ///
+ /// handles.push(tokio::spawn(async {
+ /// time::sleep(time::Duration::from_secs(10)).await;
+ /// false
+ /// }));
+ ///
+ /// for handle in &handles {
+ /// handle.abort();
+ /// }
+ ///
+ /// for handle in handles {
+ /// assert!(handle.await.unwrap_err().is_cancelled());
+ /// }
+ /// }
+ /// ```
+ pub fn abort(&self) {
+ if let Some(raw) = self.raw {
+ raw.shutdown();
+ }
+ }
}
impl<T> Unpin for JoinHandle<T> {}
diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs
index 17b5157..7b49e95 100644
--- a/src/runtime/task/mod.rs
+++ b/src/runtime/task/mod.rs
@@ -21,7 +21,7 @@ use self::state::State;
mod waker;
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
mod stack;
pub(crate) use self::stack::TransferStack;
}
@@ -79,25 +79,27 @@ pub(crate) trait Schedule: Sync + Sized + 'static {
}
}
-/// Create a new task with an associated join handle
-pub(crate) fn joinable<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>)
-where
- T: Future + Send + 'static,
- S: Schedule,
-{
- let raw = RawTask::new::<_, S>(task);
+cfg_rt! {
+ /// Create a new task with an associated join handle
+ pub(crate) fn joinable<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>)
+ where
+ T: Future + Send + 'static,
+ S: Schedule,
+ {
+ let raw = RawTask::new::<_, S>(task);
- let task = Task {
- raw,
- _p: PhantomData,
- };
+ let task = Task {
+ raw,
+ _p: PhantomData,
+ };
- let join = JoinHandle::new(raw);
+ let join = JoinHandle::new(raw);
- (Notified(task), join)
+ (Notified(task), join)
+ }
}
-cfg_rt_util! {
+cfg_rt! {
/// Create a new `!Send` task with an associated join handle
pub(crate) unsafe fn joinable_local<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>)
where
@@ -130,7 +132,7 @@ impl<S: 'static> Task<S> {
}
}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
impl<S: 'static> Notified<S> {
pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> {
Notified(Task::from_raw(ptr))
diff --git a/src/runtime/tests/loom_blocking.rs b/src/runtime/tests/loom_blocking.rs
index db7048e..8fb54c5 100644
--- a/src/runtime/tests/loom_blocking.rs
+++ b/src/runtime/tests/loom_blocking.rs
@@ -8,14 +8,15 @@ fn blocking_shutdown() {
let v = Arc::new(());
let rt = mk_runtime(1);
- rt.enter(|| {
+ {
+ let _enter = rt.enter();
for _ in 0..2 {
let v = v.clone();
crate::task::spawn_blocking(move || {
assert!(1 < Arc::strong_count(&v));
});
}
- });
+ }
drop(rt);
assert_eq!(1, Arc::strong_count(&v));
@@ -23,9 +24,8 @@ fn blocking_shutdown() {
}
fn mk_runtime(num_threads: usize) -> Runtime {
- runtime::Builder::new()
- .threaded_scheduler()
- .core_threads(num_threads)
+ runtime::Builder::new_multi_thread()
+ .worker_threads(num_threads)
.build()
.unwrap()
}
diff --git a/src/runtime/tests/loom_pool.rs b/src/runtime/tests/loom_pool.rs
index c08658c..06ad641 100644
--- a/src/runtime/tests/loom_pool.rs
+++ b/src/runtime/tests/loom_pool.rs
@@ -178,7 +178,7 @@ mod group_b {
#[test]
fn join_output() {
loom::model(|| {
- let mut rt = mk_pool(1);
+ let rt = mk_pool(1);
rt.block_on(async {
let t = crate::spawn(track(async { "hello" }));
@@ -192,7 +192,7 @@ mod group_b {
#[test]
fn poll_drop_handle_then_drop() {
loom::model(|| {
- let mut rt = mk_pool(1);
+ let rt = mk_pool(1);
rt.block_on(async move {
let mut t = crate::spawn(track(async { "hello" }));
@@ -209,7 +209,7 @@ mod group_b {
#[test]
fn complete_block_on_under_load() {
loom::model(|| {
- let mut pool = mk_pool(1);
+ let pool = mk_pool(1);
pool.block_on(async {
// Trigger a re-schedule
@@ -296,9 +296,8 @@ mod group_d {
}
fn mk_pool(num_threads: usize) -> Runtime {
- runtime::Builder::new()
- .threaded_scheduler()
- .core_threads(num_threads)
+ runtime::Builder::new_multi_thread()
+ .worker_threads(num_threads)
.build()
.unwrap()
}
diff --git a/src/runtime/tests/task.rs b/src/runtime/tests/task.rs
index 82315a0..a34526f 100644
--- a/src/runtime/tests/task.rs
+++ b/src/runtime/tests/task.rs
@@ -1,5 +1,5 @@
use crate::runtime::task::{self, Schedule, Task};
-use crate::util::linked_list::LinkedList;
+use crate::util::linked_list::{Link, LinkedList};
use crate::util::TryLock;
use std::collections::VecDeque;
@@ -72,7 +72,7 @@ struct Inner {
struct Core {
queue: VecDeque<task::Notified<Runtime>>,
- tasks: LinkedList<Task<Runtime>>,
+ tasks: LinkedList<Task<Runtime>, <Task<Runtime> as Link>::Target>,
}
static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None);
diff --git a/src/runtime/thread_pool/atomic_cell.rs b/src/runtime/thread_pool/atomic_cell.rs
index 2bda0fc..98847e6 100644
--- a/src/runtime/thread_pool/atomic_cell.rs
+++ b/src/runtime/thread_pool/atomic_cell.rs
@@ -22,7 +22,6 @@ impl<T> AtomicCell<T> {
from_raw(old)
}
- #[cfg(feature = "blocking")]
pub(super) fn set(&self, val: Box<T>) {
let _ = self.swap(Some(val));
}
diff --git a/src/runtime/thread_pool/idle.rs b/src/runtime/thread_pool/idle.rs
index ae87ca4..6e692fd 100644
--- a/src/runtime/thread_pool/idle.rs
+++ b/src/runtime/thread_pool/idle.rs
@@ -55,7 +55,7 @@ impl Idle {
}
// Acquire the lock
- let mut sleepers = self.sleepers.lock().unwrap();
+ let mut sleepers = self.sleepers.lock();
// Check again, now that the lock is acquired
if !self.notify_should_wakeup() {
@@ -77,7 +77,7 @@ impl Idle {
/// work.
pub(super) fn transition_worker_to_parked(&self, worker: usize, is_searching: bool) -> bool {
// Acquire the lock
- let mut sleepers = self.sleepers.lock().unwrap();
+ let mut sleepers = self.sleepers.lock();
// Decrement the number of unparked threads
let ret = State::dec_num_unparked(&self.state, is_searching);
@@ -112,7 +112,7 @@ impl Idle {
/// Unpark a specific worker. This happens if tasks are submitted from
/// within the worker's park routine.
pub(super) fn unpark_worker_by_id(&self, worker_id: usize) {
- let mut sleepers = self.sleepers.lock().unwrap();
+ let mut sleepers = self.sleepers.lock();
for index in 0..sleepers.len() {
if sleepers[index] == worker_id {
@@ -128,7 +128,7 @@ impl Idle {
/// Returns `true` if `worker_id` is contained in the sleep set
pub(super) fn is_parked(&self, worker_id: usize) -> bool {
- let sleepers = self.sleepers.lock().unwrap();
+ let sleepers = self.sleepers.lock();
sleepers.contains(&worker_id)
}
diff --git a/src/runtime/thread_pool/mod.rs b/src/runtime/thread_pool/mod.rs
index ced9712..e39695a 100644
--- a/src/runtime/thread_pool/mod.rs
+++ b/src/runtime/thread_pool/mod.rs
@@ -9,9 +9,7 @@ use self::idle::Idle;
mod worker;
pub(crate) use worker::Launch;
-cfg_blocking! {
- pub(crate) use worker::block_in_place;
-}
+pub(crate) use worker::block_in_place;
use crate::loom::sync::Arc;
use crate::runtime::task::{self, JoinHandle};
@@ -91,7 +89,7 @@ impl fmt::Debug for ThreadPool {
impl Drop for ThreadPool {
fn drop(&mut self) {
- self.spawner.shared.close();
+ self.spawner.shutdown();
}
}
@@ -108,6 +106,10 @@ impl Spawner {
self.shared.schedule(task, false);
handle
}
+
+ pub(crate) fn shutdown(&mut self) {
+ self.shared.close();
+ }
}
impl fmt::Debug for Spawner {
diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs
index abe20da..bc544c9 100644
--- a/src/runtime/thread_pool/worker.rs
+++ b/src/runtime/thread_pool/worker.rs
@@ -9,10 +9,11 @@ use crate::loom::rand::seed;
use crate::loom::sync::{Arc, Mutex};
use crate::park::{Park, Unpark};
use crate::runtime;
+use crate::runtime::enter::EnterContext;
use crate::runtime::park::{Parker, Unparker};
use crate::runtime::thread_pool::{AtomicCell, Idle};
use crate::runtime::{queue, task};
-use crate::util::linked_list::LinkedList;
+use crate::util::linked_list::{Link, LinkedList};
use crate::util::FastRand;
use std::cell::RefCell;
@@ -53,7 +54,7 @@ struct Core {
is_shutdown: bool,
/// Tasks owned by the core
- tasks: LinkedList<Task>,
+ tasks: LinkedList<Task, <Task as Link>::Target>,
/// Parker
///
@@ -172,104 +173,96 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc<Shared>, Launch) {
(shared, launch)
}
-cfg_blocking! {
- use crate::runtime::enter::EnterContext;
-
- pub(crate) fn block_in_place<F, R>(f: F) -> R
- where
- F: FnOnce() -> R,
- {
- // Try to steal the worker core back
- struct Reset(coop::Budget);
-
- impl Drop for Reset {
- fn drop(&mut self) {
- CURRENT.with(|maybe_cx| {
- if let Some(cx) = maybe_cx {
- let core = cx.worker.core.take();
- let mut cx_core = cx.core.borrow_mut();
- assert!(cx_core.is_none());
- *cx_core = core;
-
- // Reset the task budget as we are re-entering the
- // runtime.
- coop::set(self.0);
- }
- });
- }
+pub(crate) fn block_in_place<F, R>(f: F) -> R
+where
+ F: FnOnce() -> R,
+{
+ // Try to steal the worker core back
+ struct Reset(coop::Budget);
+
+ impl Drop for Reset {
+ fn drop(&mut self) {
+ CURRENT.with(|maybe_cx| {
+ if let Some(cx) = maybe_cx {
+ let core = cx.worker.core.take();
+ let mut cx_core = cx.core.borrow_mut();
+ assert!(cx_core.is_none());
+ *cx_core = core;
+
+ // Reset the task budget as we are re-entering the
+ // runtime.
+ coop::set(self.0);
+ }
+ });
}
+ }
- let mut had_core = false;
- let mut had_entered = false;
+ let mut had_entered = false;
- CURRENT.with(|maybe_cx| {
- match (crate::runtime::enter::context(), maybe_cx.is_some()) {
- (EnterContext::Entered { .. }, true) => {
- // We are on a thread pool runtime thread, so we just need to set up blocking.
+ CURRENT.with(|maybe_cx| {
+ match (crate::runtime::enter::context(), maybe_cx.is_some()) {
+ (EnterContext::Entered { .. }, true) => {
+ // We are on a thread pool runtime thread, so we just need to set up blocking.
+ had_entered = true;
+ }
+ (EnterContext::Entered { allow_blocking }, false) => {
+ // We are on an executor, but _not_ on the thread pool.
+ // That is _only_ okay if we are in a thread pool runtime's block_on method:
+ if allow_blocking {
had_entered = true;
- }
- (EnterContext::Entered { allow_blocking }, false) => {
- // We are on an executor, but _not_ on the thread pool.
- // That is _only_ okay if we are in a thread pool runtime's block_on method:
- if allow_blocking {
- had_entered = true;
- return;
- } else {
- // This probably means we are on the basic_scheduler or in a LocalSet,
- // where it is _not_ okay to block.
- panic!("can call blocking only when running on the multi-threaded runtime");
- }
- }
- (EnterContext::NotEntered, true) => {
- // This is a nested call to block_in_place (we already exited).
- // All the necessary setup has already been done.
- return;
- }
- (EnterContext::NotEntered, false) => {
- // We are outside of the tokio runtime, so blocking is fine.
- // We can also skip all of the thread pool blocking setup steps.
return;
+ } else {
+ // This probably means we are on the basic_scheduler or in a LocalSet,
+ // where it is _not_ okay to block.
+ panic!("can call blocking only when running on the multi-threaded runtime");
}
}
+ (EnterContext::NotEntered, true) => {
+ // This is a nested call to block_in_place (we already exited).
+ // All the necessary setup has already been done.
+ return;
+ }
+ (EnterContext::NotEntered, false) => {
+ // We are outside of the tokio runtime, so blocking is fine.
+ // We can also skip all of the thread pool blocking setup steps.
+ return;
+ }
+ }
- let cx = maybe_cx.expect("no .is_some() == false cases above should lead here");
-
- // Get the worker core. If none is set, then blocking is fine!
- let core = match cx.core.borrow_mut().take() {
- Some(core) => core,
- None => return,
- };
-
- // The parker should be set here
- assert!(core.park.is_some());
+ let cx = maybe_cx.expect("no .is_some() == false cases above should lead here");
- // In order to block, the core must be sent to another thread for
- // execution.
- //
- // First, move the core back into the worker's shared core slot.
- cx.worker.core.set(core);
- had_core = true;
+ // Get the worker core. If none is set, then blocking is fine!
+ let core = match cx.core.borrow_mut().take() {
+ Some(core) => core,
+ None => return,
+ };
- // Next, clone the worker handle and send it to a new thread for
- // processing.
- //
- // Once the blocking task is done executing, we will attempt to
- // steal the core back.
- let worker = cx.worker.clone();
- runtime::spawn_blocking(move || run(worker));
- });
+ // The parker should be set here
+ assert!(core.park.is_some());
+
+ // In order to block, the core must be sent to another thread for
+ // execution.
+ //
+ // First, move the core back into the worker's shared core slot.
+ cx.worker.core.set(core);
+
+ // Next, clone the worker handle and send it to a new thread for
+ // processing.
+ //
+ // Once the blocking task is done executing, we will attempt to
+ // steal the core back.
+ let worker = cx.worker.clone();
+ runtime::spawn_blocking(move || run(worker));
+ });
- if had_core {
- // Unset the current task's budget. Blocking sections are not
- // constrained by task budgets.
- let _reset = Reset(coop::stop());
+ if had_entered {
+ // Unset the current task's budget. Blocking sections are not
+ // constrained by task budgets.
+ let _reset = Reset(coop::stop());
- crate::runtime::enter::exit(f)
- } else if had_entered {
- crate::runtime::enter::exit(f)
- } else {
- f()
- }
+ crate::runtime::enter::exit(f)
+ } else {
+ f()
}
}
@@ -576,6 +569,8 @@ impl Core {
// Drain the queue
while self.next_local_task().is_some() {}
+
+ park.shutdown();
}
fn drain_pending_drop(&mut self, worker: &Worker) {
@@ -785,7 +780,7 @@ impl Shared {
///
/// If all workers have reached this point, the final cleanup is performed.
fn shutdown(&self, core: Box<Core>, worker: Arc<Worker>) {
- let mut workers = self.shutdown_workers.lock().unwrap();
+ let mut workers = self.shutdown_workers.lock();
workers.push((core, worker));
if workers.len() != self.remotes.len() {
diff --git a/src/runtime/time.rs b/src/runtime/time.rs
deleted file mode 100644
index c623d96..0000000
--- a/src/runtime/time.rs
+++ /dev/null
@@ -1,59 +0,0 @@
-//! Abstracts out the APIs necessary to `Runtime` for integrating the time
-//! driver. When the `time` feature flag is **not** enabled. These APIs are
-//! shells. This isolates the complexity of dealing with conditional
-//! compilation.
-
-pub(crate) use variant::*;
-
-#[cfg(feature = "time")]
-mod variant {
- use crate::park::Either;
- use crate::runtime::io;
- use crate::time::{self, driver};
-
- pub(crate) type Clock = time::Clock;
- pub(crate) type Driver = Either<driver::Driver<io::Driver>, io::Driver>;
- pub(crate) type Handle = Option<driver::Handle>;
-
- pub(crate) fn create_clock() -> Clock {
- Clock::new()
- }
-
- /// Create a new timer driver / handle pair
- pub(crate) fn create_driver(
- enable: bool,
- io_driver: io::Driver,
- clock: Clock,
- ) -> (Driver, Handle) {
- if enable {
- let driver = driver::Driver::new(io_driver, clock);
- let handle = driver.handle();
-
- (Either::A(driver), Some(handle))
- } else {
- (Either::B(io_driver), None)
- }
- }
-}
-
-#[cfg(not(feature = "time"))]
-mod variant {
- use crate::runtime::io;
-
- pub(crate) type Clock = ();
- pub(crate) type Driver = io::Driver;
- pub(crate) type Handle = ();
-
- pub(crate) fn create_clock() -> Clock {
- ()
- }
-
- /// Create a new timer driver / handle pair
- pub(crate) fn create_driver(
- _enable: bool,
- io_driver: io::Driver,
- _clock: Clock,
- ) -> (Driver, Handle) {
- (io_driver, ())
- }
-}
diff --git a/src/signal/registry.rs b/src/signal/registry.rs
index 50edd2b..5d6f608 100644
--- a/src/signal/registry.rs
+++ b/src/signal/registry.rs
@@ -185,7 +185,7 @@ mod tests {
#[test]
fn smoke() {
- let mut rt = rt();
+ let rt = rt();
rt.block_on(async move {
let registry = Registry::new(vec![
EventInfo::default(),
@@ -247,7 +247,7 @@ mod tests {
#[test]
fn broadcast_cleans_up_disconnected_listeners() {
- let mut rt = Runtime::new().unwrap();
+ let rt = Runtime::new().unwrap();
rt.block_on(async {
let registry = Registry::new(vec![EventInfo::default()]);
@@ -306,7 +306,7 @@ mod tests {
}
fn rt() -> Runtime {
- runtime::Builder::new().basic_scheduler().build().unwrap()
+ runtime::Builder::new_current_thread().build().unwrap()
}
async fn collect(mut rx: crate::sync::mpsc::Receiver<()>) -> Vec<()> {
diff --git a/src/signal/unix.rs b/src/signal/unix.rs
index b46b15c..aaaa75e 100644
--- a/src/signal/unix.rs
+++ b/src/signal/unix.rs
@@ -5,18 +5,21 @@
#![cfg(unix)]
-use crate::io::{AsyncRead, PollEvented};
use crate::signal::registry::{globals, EventId, EventInfo, Globals, Init, Storage};
+use crate::sync::mpsc::error::TryRecvError;
use crate::sync::mpsc::{channel, Receiver};
use libc::c_int;
-use mio_uds::UnixStream;
+use mio::net::UnixStream;
use std::io::{self, Error, ErrorKind, Write};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Once;
use std::task::{Context, Poll};
+pub(crate) mod driver;
+use self::driver::Handle;
+
pub(crate) type OsStorage = Vec<SignalInfo>;
// Number of different unix signals
@@ -202,9 +205,9 @@ impl Default for SignalInfo {
/// The purpose of this signal handler is to primarily:
///
/// 1. Flag that our specific signal was received (e.g. store an atomic flag)
-/// 2. Wake up driver tasks by writing a byte to a pipe
+/// 2. Wake up the driver by writing a byte to a pipe
///
-/// Those two operations shoudl both be async-signal safe.
+/// Those two operations should both be async-signal safe.
fn action(globals: Pin<&'static Globals>, signal: c_int) {
globals.record_event(signal as EventId);
@@ -219,7 +222,7 @@ fn action(globals: Pin<&'static Globals>, signal: c_int) {
///
/// This will register the signal handler if it hasn't already been registered,
/// returning any error along the way if that fails.
-fn signal_enable(signal: c_int) -> io::Result<()> {
+fn signal_enable(signal: c_int, handle: Handle) -> io::Result<()> {
if signal < 0 || signal_hook_registry::FORBIDDEN.contains(&signal) {
return Err(Error::new(
ErrorKind::Other,
@@ -227,6 +230,9 @@ fn signal_enable(signal: c_int) -> io::Result<()> {
));
}
+ // Check that we have a signal driver running
+ handle.check_inner()?;
+
let globals = globals();
let siginfo = match globals.storage().get(signal as EventId) {
Some(slot) => slot,
@@ -254,63 +260,6 @@ fn signal_enable(signal: c_int) -> io::Result<()> {
}
}
-#[derive(Debug)]
-struct Driver {
- wakeup: PollEvented<UnixStream>,
-}
-
-impl Driver {
- fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
- // Drain the data from the pipe and maintain interest in getting more
- self.drain(cx);
- // Broadcast any signals which were received
- globals().broadcast();
-
- Poll::Pending
- }
-}
-
-impl Driver {
- fn new() -> io::Result<Driver> {
- // NB: We give each driver a "fresh" reciever file descriptor to avoid
- // the issues described in alexcrichton/tokio-process#42.
- //
- // In the past we would reuse the actual receiver file descriptor and
- // swallow any errors around double registration of the same descriptor.
- // I'm not sure if the second (failed) registration simply doesn't end up
- // receiving wake up notifications, or there could be some race condition
- // when consuming readiness events, but having distinct descriptors for
- // distinct PollEvented instances appears to mitigate this.
- //
- // Unfortunately we cannot just use a single global PollEvented instance
- // either, since we can't compare Handles or assume they will always
- // point to the exact same reactor.
- let stream = globals().receiver.try_clone()?;
- let wakeup = PollEvented::new(stream)?;
-
- Ok(Driver { wakeup })
- }
-
- /// Drain all data in the global receiver, ensuring we'll get woken up when
- /// there is a write on the other end.
- ///
- /// We do *NOT* use the existence of any read bytes as evidence a signal was
- /// received since the `pending` flags would have already been set if that
- /// was the case. See
- /// [#38](https://github.com/alexcrichton/tokio-signal/issues/38) for more
- /// info.
- fn drain(&mut self, cx: &mut Context<'_>) {
- loop {
- match Pin::new(&mut self.wakeup).poll_read(cx, &mut [0; 128]) {
- Poll::Ready(Ok(0)) => panic!("EOF on self-pipe"),
- Poll::Ready(Ok(_)) => {}
- Poll::Ready(Err(e)) => panic!("Bad read on self-pipe: {}", e),
- Poll::Pending => break,
- }
- }
- }
-}
-
/// A stream of events for receiving a particular type of OS signal.
///
/// In general signal handling on Unix is a pretty tricky topic, and this
@@ -376,7 +325,6 @@ impl Driver {
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Signal {
- driver: Driver,
rx: Receiver<()>,
}
@@ -403,21 +351,21 @@ pub struct Signal {
/// * If the signal is one of
/// [`signal_hook::FORBIDDEN`](fn@signal_hook_registry::register#panics)
pub fn signal(kind: SignalKind) -> io::Result<Signal> {
+ signal_with_handle(kind, Handle::current())
+}
+
+pub(crate) fn signal_with_handle(kind: SignalKind, handle: Handle) -> io::Result<Signal> {
let signal = kind.0;
// Turn the signal delivery on once we are ready for it
- signal_enable(signal)?;
-
- // Ensure there's a driver for our associated event loop processing
- // signals.
- let driver = Driver::new()?;
+ signal_enable(signal, handle)?;
// One wakeup in a queue is enough, no need for us to buffer up any
// more.
let (tx, rx) = channel(1);
globals().register_listener(signal as EventId, tx);
- Ok(Signal { driver, rx })
+ Ok(Signal { rx })
}
impl Signal {
@@ -449,38 +397,14 @@ impl Signal {
poll_fn(|cx| self.poll_recv(cx)).await
}
- /// Polls to receive the next signal notification event, outside of an
- /// `async` context.
- ///
- /// `None` is returned if no more events can be received by this stream.
- ///
- /// # Examples
- ///
- /// Polling from a manually implemented future
- ///
- /// ```rust,no_run
- /// use std::pin::Pin;
- /// use std::future::Future;
- /// use std::task::{Context, Poll};
- /// use tokio::signal::unix::Signal;
- ///
- /// struct MyFuture {
- /// signal: Signal,
- /// }
- ///
- /// impl Future for MyFuture {
- /// type Output = Option<()>;
- ///
- /// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- /// println!("polling MyFuture");
- /// self.signal.poll_recv(cx)
- /// }
- /// }
- /// ```
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
- let _ = self.driver.poll(cx);
+ pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
self.rx.poll_recv(cx)
}
+
+ /// Try to receive a signal notification without blocking or registering a waker.
+ pub(crate) fn try_recv(&mut self) -> Result<(), TryRecvError> {
+ self.rx.try_recv()
+ }
}
cfg_stream! {
@@ -493,6 +417,22 @@ cfg_stream! {
}
}
+// Work around for abstracting streams internally
+pub(crate) trait InternalStream: Unpin {
+ fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>;
+ fn try_recv(&mut self) -> Result<(), TryRecvError>;
+}
+
+impl InternalStream for Signal {
+ fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
+ self.poll_recv(cx)
+ }
+
+ fn try_recv(&mut self) -> Result<(), TryRecvError> {
+ self.try_recv()
+ }
+}
+
pub(crate) fn ctrl_c() -> io::Result<Signal> {
signal(SignalKind::interrupt())
}
@@ -503,11 +443,11 @@ mod tests {
#[test]
fn signal_enable_error_on_invalid_input() {
- signal_enable(-1).unwrap_err();
+ signal_enable(-1, Handle::default()).unwrap_err();
}
#[test]
fn signal_enable_error_on_forbidden_input() {
- signal_enable(signal_hook_registry::FORBIDDEN[0]).unwrap_err();
+ signal_enable(signal_hook_registry::FORBIDDEN[0], Handle::default()).unwrap_err();
}
}
diff --git a/src/signal/unix/driver.rs b/src/signal/unix/driver.rs
new file mode 100644
index 0000000..8e5ed7d
--- /dev/null
+++ b/src/signal/unix/driver.rs
@@ -0,0 +1,207 @@
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
+//! Signal driver
+
+use crate::io::driver::Driver as IoDriver;
+use crate::io::PollEvented;
+use crate::park::Park;
+use crate::signal::registry::globals;
+
+use mio::net::UnixStream;
+use std::io::{self, Read};
+use std::ptr;
+use std::sync::{Arc, Weak};
+use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
+use std::time::Duration;
+
+/// Responsible for registering wakeups when an OS signal is received, and
+/// subsequently dispatching notifications to any signal listeners as appropriate.
+///
+/// Note: this driver relies on having an enabled IO driver in order to listen to
+/// pipe write wakeups.
+#[derive(Debug)]
+pub(crate) struct Driver {
+ /// Thread parker. The `Driver` park implementation delegates to this.
+ park: IoDriver,
+
+ /// A pipe for receiving wake events from the signal handler
+ receiver: PollEvented<UnixStream>,
+
+ /// Shared state
+ inner: Arc<Inner>,
+}
+
+#[derive(Clone, Debug, Default)]
+pub(crate) struct Handle {
+ inner: Weak<Inner>,
+}
+
+#[derive(Debug)]
+pub(super) struct Inner(());
+
+// ===== impl Driver =====
+
+impl Driver {
+ /// Creates a new signal `Driver` instance that delegates wakeups to `park`.
+ pub(crate) fn new(park: IoDriver) -> io::Result<Self> {
+ use std::mem::ManuallyDrop;
+ use std::os::unix::io::{AsRawFd, FromRawFd};
+
+ // NB: We give each driver a "fresh" reciever file descriptor to avoid
+ // the issues described in alexcrichton/tokio-process#42.
+ //
+ // In the past we would reuse the actual receiver file descriptor and
+ // swallow any errors around double registration of the same descriptor.
+ // I'm not sure if the second (failed) registration simply doesn't end
+ // up receiving wake up notifications, or there could be some race
+ // condition when consuming readiness events, but having distinct
+ // descriptors for distinct PollEvented instances appears to mitigate
+ // this.
+ //
+ // Unfortunately we cannot just use a single global PollEvented instance
+ // either, since we can't compare Handles or assume they will always
+ // point to the exact same reactor.
+ //
+ // Mio 0.7 removed `try_clone()` as an API due to unexpected behavior
+ // with registering dups with the same reactor. In this case, duping is
+ // safe as each dup is registered with separate reactors **and** we
+ // only expect at least one dup to receive the notification.
+
+ // Manually drop as we don't actually own this instance of UnixStream.
+ let receiver_fd = globals().receiver.as_raw_fd();
+
+ // safety: there is nothing unsafe about this, but the `from_raw_fd` fn is marked as unsafe.
+ let original =
+ ManuallyDrop::new(unsafe { std::os::unix::net::UnixStream::from_raw_fd(receiver_fd) });
+ let receiver = UnixStream::from_std(original.try_clone()?);
+ let receiver = PollEvented::new_with_interest_and_handle(
+ receiver,
+ mio::Interest::READABLE | mio::Interest::WRITABLE,
+ park.handle(),
+ )?;
+
+ Ok(Self {
+ park,
+ receiver,
+ inner: Arc::new(Inner(())),
+ })
+ }
+
+ /// Returns a handle to this event loop which can be sent across threads
+ /// and can be used as a proxy to the event loop itself.
+ pub(crate) fn handle(&self) -> Handle {
+ Handle {
+ inner: Arc::downgrade(&self.inner),
+ }
+ }
+
+ fn process(&self) {
+ // Check if the pipe is ready to read and therefore has "woken" us up
+ //
+ // To do so, we will `poll_read_ready` with a noop waker, since we don't
+ // need to actually be notified when read ready...
+ let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)) };
+ let mut cx = Context::from_waker(&waker);
+
+ let ev = match self.receiver.poll_read_ready(&mut cx) {
+ Poll::Ready(Ok(ev)) => ev,
+ Poll::Ready(Err(e)) => panic!("reactor gone: {}", e),
+ Poll::Pending => return, // No wake has arrived, bail
+ };
+
+ // Drain the pipe completely so we can receive a new readiness event
+ // if another signal has come in.
+ let mut buf = [0; 128];
+ loop {
+ match self.receiver.get_ref().read(&mut buf) {
+ Ok(0) => panic!("EOF on self-pipe"),
+ Ok(_) => continue, // Keep reading
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
+ Err(e) => panic!("Bad read on self-pipe: {}", e),
+ }
+ }
+
+ self.receiver.clear_readiness(ev);
+
+ // Broadcast any signals which were received
+ globals().broadcast();
+ }
+}
+
+const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
+
+unsafe fn noop_clone(_data: *const ()) -> RawWaker {
+ RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE)
+}
+
+unsafe fn noop(_data: *const ()) {}
+
+// ===== impl Park for Driver =====
+
+impl Park for Driver {
+ type Unpark = <IoDriver as Park>::Unpark;
+ type Error = io::Error;
+
+ fn unpark(&self) -> Self::Unpark {
+ self.park.unpark()
+ }
+
+ fn park(&mut self) -> Result<(), Self::Error> {
+ self.park.park()?;
+ self.process();
+ Ok(())
+ }
+
+ fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> {
+ self.park.park_timeout(duration)?;
+ self.process();
+ Ok(())
+ }
+
+ fn shutdown(&mut self) {
+ self.park.shutdown()
+ }
+}
+
+// ===== impl Handle =====
+
+impl Handle {
+ pub(super) fn check_inner(&self) -> io::Result<()> {
+ if self.inner.strong_count() > 0 {
+ Ok(())
+ } else {
+ Err(io::Error::new(io::ErrorKind::Other, "signal driver gone"))
+ }
+ }
+}
+
+cfg_rt! {
+ impl Handle {
+ /// Returns a handle to the current driver
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current signal driver set.
+ pub(super) fn current() -> Self {
+ crate::runtime::context::signal_handle().expect(
+ "there is no signal driver running, must be called from the context of Tokio runtime",
+ )
+ }
+ }
+}
+
+cfg_not_rt! {
+ impl Handle {
+ /// Returns a handle to the current driver
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current signal driver set.
+ pub(super) fn current() -> Self {
+ panic!(
+ "there is no signal driver running, must be called from the context of Tokio runtime or with\
+ `rt` enabled.",
+ )
+ }
+ }
+}
diff --git a/src/signal/windows.rs b/src/signal/windows.rs
index f55e504..1e78362 100644
--- a/src/signal/windows.rs
+++ b/src/signal/windows.rs
@@ -14,9 +14,9 @@ use std::convert::TryFrom;
use std::io;
use std::sync::Once;
use std::task::{Context, Poll};
-use winapi::shared::minwindef::*;
+use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE};
use winapi::um::consoleapi::SetConsoleCtrlHandler;
-use winapi::um::wincon::*;
+use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT};
#[derive(Debug)]
pub(crate) struct OsStorage {
@@ -253,26 +253,25 @@ mod tests {
#[test]
fn ctrl_c() {
let rt = rt();
+ let _enter = rt.enter();
- rt.enter(|| {
- let mut ctrl_c = task::spawn(crate::signal::ctrl_c());
+ let mut ctrl_c = task::spawn(crate::signal::ctrl_c());
- assert_pending!(ctrl_c.poll());
+ assert_pending!(ctrl_c.poll());
- // Windows doesn't have a good programmatic way of sending events
- // like sending signals on Unix, so we'll stub out the actual OS
- // integration and test that our handling works.
- unsafe {
- super::handler(CTRL_C_EVENT);
- }
+ // Windows doesn't have a good programmatic way of sending events
+ // like sending signals on Unix, so we'll stub out the actual OS
+ // integration and test that our handling works.
+ unsafe {
+ super::handler(CTRL_C_EVENT);
+ }
- assert_ready_ok!(ctrl_c.poll());
- });
+ assert_ready_ok!(ctrl_c.poll());
}
#[test]
fn ctrl_break() {
- let mut rt = rt();
+ let rt = rt();
rt.block_on(async {
let mut ctrl_break = assert_ok!(super::ctrl_break());
@@ -289,8 +288,7 @@ mod tests {
}
fn rt() -> Runtime {
- crate::runtime::Builder::new()
- .basic_scheduler()
+ crate::runtime::Builder::new_current_thread()
.build()
.unwrap()
}
diff --git a/src/stream/all.rs b/src/stream/all.rs
index 615665d..353d61a 100644
--- a/src/stream/all.rs
+++ b/src/stream/all.rs
@@ -1,25 +1,34 @@
use crate::stream::Stream;
use core::future::Future;
+use core::marker::PhantomPinned;
use core::pin::Pin;
use core::task::{Context, Poll};
+use pin_project_lite::pin_project;
-/// Future for the [`all`](super::StreamExt::all) method.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct AllFuture<'a, St: ?Sized, F> {
- stream: &'a mut St,
- f: F,
+pin_project! {
+ /// Future for the [`all`](super::StreamExt::all) method.
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct AllFuture<'a, St: ?Sized, F> {
+ stream: &'a mut St,
+ f: F,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
impl<'a, St: ?Sized, F> AllFuture<'a, St, F> {
pub(super) fn new(stream: &'a mut St, f: F) -> Self {
- Self { stream, f }
+ Self {
+ stream,
+ f,
+ _pin: PhantomPinned,
+ }
}
}
-impl<St: ?Sized + Unpin, F> Unpin for AllFuture<'_, St, F> {}
-
impl<St, F> Future for AllFuture<'_, St, F>
where
St: ?Sized + Stream + Unpin,
@@ -27,12 +36,13 @@ where
{
type Output = bool;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let next = futures_core::ready!(Pin::new(&mut self.stream).poll_next(cx));
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ let next = futures_core::ready!(Pin::new(me.stream).poll_next(cx));
match next {
Some(v) => {
- if !(&mut self.f)(v) {
+ if !(me.f)(v) {
Poll::Ready(false)
} else {
cx.waker().wake_by_ref();
diff --git a/src/stream/any.rs b/src/stream/any.rs
index f2ecad5..aac0ec7 100644
--- a/src/stream/any.rs
+++ b/src/stream/any.rs
@@ -1,25 +1,34 @@
use crate::stream::Stream;
use core::future::Future;
+use core::marker::PhantomPinned;
use core::pin::Pin;
use core::task::{Context, Poll};
+use pin_project_lite::pin_project;
-/// Future for the [`any`](super::StreamExt::any) method.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct AnyFuture<'a, St: ?Sized, F> {
- stream: &'a mut St,
- f: F,
+pin_project! {
+ /// Future for the [`any`](super::StreamExt::any) method.
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct AnyFuture<'a, St: ?Sized, F> {
+ stream: &'a mut St,
+ f: F,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
impl<'a, St: ?Sized, F> AnyFuture<'a, St, F> {
pub(super) fn new(stream: &'a mut St, f: F) -> Self {
- Self { stream, f }
+ Self {
+ stream,
+ f,
+ _pin: PhantomPinned,
+ }
}
}
-impl<St: ?Sized + Unpin, F> Unpin for AnyFuture<'_, St, F> {}
-
impl<St, F> Future for AnyFuture<'_, St, F>
where
St: ?Sized + Stream + Unpin,
@@ -27,12 +36,13 @@ where
{
type Output = bool;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let next = futures_core::ready!(Pin::new(&mut self.stream).poll_next(cx));
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ let next = futures_core::ready!(Pin::new(me.stream).poll_next(cx));
match next {
Some(v) => {
- if (&mut self.f)(v) {
+ if (me.f)(v) {
Poll::Ready(true)
} else {
cx.waker().wake_by_ref();
diff --git a/src/stream/collect.rs b/src/stream/collect.rs
index 4649428..1aafc30 100644
--- a/src/stream/collect.rs
+++ b/src/stream/collect.rs
@@ -1,7 +1,7 @@
use crate::stream::Stream;
-use bytes::{Buf, BufMut, Bytes, BytesMut};
use core::future::Future;
+use core::marker::PhantomPinned;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
@@ -10,7 +10,7 @@ use pin_project_lite::pin_project;
// Do not export this struct until `FromStream` can be unsealed.
pin_project! {
/// Future returned by the [`collect`](super::StreamExt::collect) method.
- #[must_use = "streams do nothing unless polled"]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Collect<T, U>
where
@@ -19,7 +19,10 @@ pin_project! {
{
#[pin]
stream: T,
- collection: U::Collection,
+ collection: U::InternalCollection,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -42,9 +45,13 @@ where
{
pub(super) fn new(stream: T) -> Collect<T, U> {
let (lower, upper) = stream.size_hint();
- let collection = U::initialize(lower, upper);
+ let collection = U::initialize(sealed::Internal, lower, upper);
- Collect { stream, collection }
+ Collect {
+ stream,
+ collection,
+ _pin: PhantomPinned,
+ }
}
}
@@ -64,12 +71,12 @@ where
let item = match ready!(me.stream.poll_next(cx)) {
Some(item) => item,
None => {
- return Ready(U::finalize(&mut me.collection));
+ return Ready(U::finalize(sealed::Internal, &mut me.collection));
}
};
- if !U::extend(&mut me.collection, item) {
- return Ready(U::finalize(&mut me.collection));
+ if !U::extend(sealed::Internal, &mut me.collection, item) {
+ return Ready(U::finalize(sealed::Internal, &mut me.collection));
}
}
}
@@ -80,32 +87,32 @@ where
impl FromStream<()> for () {}
impl sealed::FromStreamPriv<()> for () {
- type Collection = ();
+ type InternalCollection = ();
- fn initialize(_lower: usize, _upper: Option<usize>) {}
+ fn initialize(_: sealed::Internal, _lower: usize, _upper: Option<usize>) {}
- fn extend(_collection: &mut (), _item: ()) -> bool {
+ fn extend(_: sealed::Internal, _collection: &mut (), _item: ()) -> bool {
true
}
- fn finalize(_collection: &mut ()) {}
+ fn finalize(_: sealed::Internal, _collection: &mut ()) {}
}
impl<T: AsRef<str>> FromStream<T> for String {}
impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String {
- type Collection = String;
+ type InternalCollection = String;
- fn initialize(_lower: usize, _upper: Option<usize>) -> String {
+ fn initialize(_: sealed::Internal, _lower: usize, _upper: Option<usize>) -> String {
String::new()
}
- fn extend(collection: &mut String, item: T) -> bool {
+ fn extend(_: sealed::Internal, collection: &mut String, item: T) -> bool {
collection.push_str(item.as_ref());
true
}
- fn finalize(collection: &mut String) -> String {
+ fn finalize(_: sealed::Internal, collection: &mut String) -> String {
mem::replace(collection, String::new())
}
}
@@ -113,18 +120,18 @@ impl<T: AsRef<str>> sealed::FromStreamPriv<T> for String {
impl<T> FromStream<T> for Vec<T> {}
impl<T> sealed::FromStreamPriv<T> for Vec<T> {
- type Collection = Vec<T>;
+ type InternalCollection = Vec<T>;
- fn initialize(lower: usize, _upper: Option<usize>) -> Vec<T> {
+ fn initialize(_: sealed::Internal, lower: usize, _upper: Option<usize>) -> Vec<T> {
Vec::with_capacity(lower)
}
- fn extend(collection: &mut Vec<T>, item: T) -> bool {
+ fn extend(_: sealed::Internal, collection: &mut Vec<T>, item: T) -> bool {
collection.push(item);
true
}
- fn finalize(collection: &mut Vec<T>) -> Vec<T> {
+ fn finalize(_: sealed::Internal, collection: &mut Vec<T>) -> Vec<T> {
mem::replace(collection, vec![])
}
}
@@ -132,18 +139,19 @@ impl<T> sealed::FromStreamPriv<T> for Vec<T> {
impl<T> FromStream<T> for Box<[T]> {}
impl<T> sealed::FromStreamPriv<T> for Box<[T]> {
- type Collection = Vec<T>;
+ type InternalCollection = Vec<T>;
- fn initialize(lower: usize, upper: Option<usize>) -> Vec<T> {
- <Vec<T> as sealed::FromStreamPriv<T>>::initialize(lower, upper)
+ fn initialize(_: sealed::Internal, lower: usize, upper: Option<usize>) -> Vec<T> {
+ <Vec<T> as sealed::FromStreamPriv<T>>::initialize(sealed::Internal, lower, upper)
}
- fn extend(collection: &mut Vec<T>, item: T) -> bool {
- <Vec<T> as sealed::FromStreamPriv<T>>::extend(collection, item)
+ fn extend(_: sealed::Internal, collection: &mut Vec<T>, item: T) -> bool {
+ <Vec<T> as sealed::FromStreamPriv<T>>::extend(sealed::Internal, collection, item)
}
- fn finalize(collection: &mut Vec<T>) -> Box<[T]> {
- <Vec<T> as sealed::FromStreamPriv<T>>::finalize(collection).into_boxed_slice()
+ fn finalize(_: sealed::Internal, collection: &mut Vec<T>) -> Box<[T]> {
+ <Vec<T> as sealed::FromStreamPriv<T>>::finalize(sealed::Internal, collection)
+ .into_boxed_slice()
}
}
@@ -153,18 +161,26 @@ impl<T, U, E> sealed::FromStreamPriv<Result<T, E>> for Result<U, E>
where
U: FromStream<T>,
{
- type Collection = Result<U::Collection, E>;
+ type InternalCollection = Result<U::InternalCollection, E>;
- fn initialize(lower: usize, upper: Option<usize>) -> Result<U::Collection, E> {
- Ok(U::initialize(lower, upper))
+ fn initialize(
+ _: sealed::Internal,
+ lower: usize,
+ upper: Option<usize>,
+ ) -> Result<U::InternalCollection, E> {
+ Ok(U::initialize(sealed::Internal, lower, upper))
}
- fn extend(collection: &mut Self::Collection, item: Result<T, E>) -> bool {
+ fn extend(
+ _: sealed::Internal,
+ collection: &mut Self::InternalCollection,
+ item: Result<T, E>,
+ ) -> bool {
assert!(collection.is_ok());
match item {
Ok(item) => {
let collection = collection.as_mut().ok().expect("invalid state");
- U::extend(collection, item)
+ U::extend(sealed::Internal, collection, item)
}
Err(err) => {
*collection = Err(err);
@@ -173,11 +189,11 @@ where
}
}
- fn finalize(collection: &mut Self::Collection) -> Result<U, E> {
+ fn finalize(_: sealed::Internal, collection: &mut Self::InternalCollection) -> Result<U, E> {
if let Ok(collection) = collection.as_mut() {
- Ok(U::finalize(collection))
+ Ok(U::finalize(sealed::Internal, collection))
} else {
- let res = mem::replace(collection, Ok(U::initialize(0, Some(0))));
+ let res = mem::replace(collection, Ok(U::initialize(sealed::Internal, 0, Some(0))));
if let Err(err) = res {
Err(err)
@@ -188,59 +204,30 @@ where
}
}
-impl<T: Buf> FromStream<T> for Bytes {}
-
-impl<T: Buf> sealed::FromStreamPriv<T> for Bytes {
- type Collection = BytesMut;
-
- fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut {
- BytesMut::new()
- }
-
- fn extend(collection: &mut BytesMut, item: T) -> bool {
- collection.put(item);
- true
- }
-
- fn finalize(collection: &mut BytesMut) -> Bytes {
- mem::replace(collection, BytesMut::new()).freeze()
- }
-}
-
-impl<T: Buf> FromStream<T> for BytesMut {}
-
-impl<T: Buf> sealed::FromStreamPriv<T> for BytesMut {
- type Collection = BytesMut;
-
- fn initialize(_lower: usize, _upper: Option<usize>) -> BytesMut {
- BytesMut::new()
- }
-
- fn extend(collection: &mut BytesMut, item: T) -> bool {
- collection.put(item);
- true
- }
-
- fn finalize(collection: &mut BytesMut) -> BytesMut {
- mem::replace(collection, BytesMut::new())
- }
-}
-
pub(crate) mod sealed {
#[doc(hidden)]
pub trait FromStreamPriv<T> {
/// Intermediate type used during collection process
- type Collection;
+ ///
+ /// The name of this type is internal and cannot be relied upon.
+ type InternalCollection;
/// Initialize the collection
- fn initialize(lower: usize, upper: Option<usize>) -> Self::Collection;
+ fn initialize(
+ internal: Internal,
+ lower: usize,
+ upper: Option<usize>,
+ ) -> Self::InternalCollection;
/// Extend the collection with the received item
///
/// Return `true` to continue streaming, `false` complete collection.
- fn extend(collection: &mut Self::Collection, item: T) -> bool;
+ fn extend(internal: Internal, collection: &mut Self::InternalCollection, item: T) -> bool;
/// Finalize collection into target type.
- fn finalize(collection: &mut Self::Collection) -> Self;
+ fn finalize(internal: Internal, collection: &mut Self::InternalCollection) -> Self;
}
+
+ #[allow(missing_debug_implementations)]
+ pub struct Internal;
}
diff --git a/src/stream/fold.rs b/src/stream/fold.rs
index 7b9fead..5cf2bfa 100644
--- a/src/stream/fold.rs
+++ b/src/stream/fold.rs
@@ -1,6 +1,7 @@
use crate::stream::Stream;
use core::future::Future;
+use core::marker::PhantomPinned;
use core::pin::Pin;
use core::task::{Context, Poll};
use pin_project_lite::pin_project;
@@ -8,11 +9,15 @@ use pin_project_lite::pin_project;
pin_project! {
/// Future returned by the [`fold`](super::StreamExt::fold) method.
#[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FoldFuture<St, B, F> {
#[pin]
stream: St,
acc: Option<B>,
f: F,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -22,6 +27,7 @@ impl<St, B, F> FoldFuture<St, B, F> {
stream,
acc: Some(init),
f,
+ _pin: PhantomPinned,
}
}
}
diff --git a/src/stream/mod.rs b/src/stream/mod.rs
index 7b061ef..6bf4232 100644
--- a/src/stream/mod.rs
+++ b/src/stream/mod.rs
@@ -1,8 +1,57 @@
//! Stream utilities for Tokio.
//!
-//! A `Stream` is an asynchronous sequence of values. It can be thought of as an asynchronous version of the standard library's `Iterator` trait.
+//! A `Stream` is an asynchronous sequence of values. It can be thought of as
+//! an asynchronous version of the standard library's `Iterator` trait.
//!
-//! This module provides helpers to work with them.
+//! This module provides helpers to work with them. For examples of usage and a more in-depth
+//! description of streams you can also refer to the [streams
+//! tutorial](https://tokio.rs/tokio/tutorial/streams) on the tokio website.
+//!
+//! # Iterating over a Stream
+//!
+//! Due to similarities with the standard library's `Iterator` trait, some new
+//! users may assume that they can use `for in` syntax to iterate over a
+//! `Stream`, but this is unfortunately not possible. Instead, you can use a
+//! `while let` loop as follows:
+//!
+//! ```rust
+//! use tokio::stream::{self, StreamExt};
+//!
+//! #[tokio::main]
+//! async fn main() {
+//! let mut stream = stream::iter(vec![0, 1, 2]);
+//!
+//! while let Some(value) = stream.next().await {
+//! println!("Got {}", value);
+//! }
+//! }
+//! ```
+//!
+//! # Returning a Stream from a function
+//!
+//! A common way to stream values from a function is to pass in the sender
+//! half of a channel and use the receiver as the stream. This requires awaiting
+//! both futures to ensure progress is made. Another alternative is the
+//! [async-stream] crate, which contains macros that provide a `yield` keyword
+//! and allow you to return an `impl Stream`.
+//!
+//! [async-stream]: https://docs.rs/async-stream
+//!
+//! # Conversion to and from AsyncRead/AsyncWrite
+//!
+//! It is often desirable to convert a `Stream` into an [`AsyncRead`],
+//! especially when dealing with plaintext formats streamed over the network.
+//! The opposite conversion from an [`AsyncRead`] into a `Stream` is also
+//! another commonly required feature. To enable these conversions,
+//! [`tokio-util`] provides the [`StreamReader`] and [`ReaderStream`]
+//! types when the io feature is enabled.
+//!
+//! [tokio-util]: https://docs.rs/tokio-util/0.3/tokio_util/codec/index.html
+//! [`tokio::io`]: crate::io
+//! [`AsyncRead`]: crate::io::AsyncRead
+//! [`AsyncWrite`]: crate::io::AsyncWrite
+//! [`ReaderStream`]: https://docs.rs/tokio-util/0.4/tokio_util/io/struct.ReaderStream.html
+//! [`StreamReader`]: https://docs.rs/tokio-util/0.4/tokio_util/io/struct.StreamReader.html
mod all;
use all::AllFuture;
@@ -71,9 +120,12 @@ use take_while::TakeWhile;
cfg_time! {
mod timeout;
use timeout::Timeout;
- use std::time::Duration;
+ use crate::time::Duration;
+ mod throttle;
+ use crate::stream::throttle::{throttle, Throttle};
}
+#[doc(no_inline)]
pub use futures_core::Stream;
/// An extension trait for `Stream`s that provides a variety of convenient
@@ -215,11 +267,11 @@ pub trait StreamExt: Stream {
/// # /*
/// #[tokio::main]
/// # */
- /// # #[tokio::main(basic_scheduler)]
+ /// # #[tokio::main(flavor = "current_thread")]
/// async fn main() {
/// # time::pause();
- /// let (mut tx1, rx1) = mpsc::channel(10);
- /// let (mut tx2, rx2) = mpsc::channel(10);
+ /// let (tx1, rx1) = mpsc::channel(10);
+ /// let (tx2, rx2) = mpsc::channel(10);
///
/// let mut rx = rx1.merge(rx2);
///
@@ -229,18 +281,18 @@ pub trait StreamExt: Stream {
/// tx1.send(2).await.unwrap();
///
/// // Let the other task send values
- /// time::delay_for(Duration::from_millis(20)).await;
+ /// time::sleep(Duration::from_millis(20)).await;
///
/// tx1.send(4).await.unwrap();
/// });
///
/// tokio::spawn(async move {
/// // Wait for the first task to send values
- /// time::delay_for(Duration::from_millis(5)).await;
+ /// time::sleep(Duration::from_millis(5)).await;
///
/// tx2.send(3).await.unwrap();
///
- /// time::delay_for(Duration::from_millis(25)).await;
+ /// time::sleep(Duration::from_millis(25)).await;
///
/// // Send the final value
/// tx2.send(5).await.unwrap();
@@ -520,6 +572,12 @@ pub trait StreamExt: Stream {
/// Tests if every element of the stream matches a predicate.
///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn all<F>(&mut self, f: F) -> bool;
+ /// ```
+ ///
/// `all()` takes a closure that returns `true` or `false`. It applies
/// this closure to each element of the stream, and if they all return
/// `true`, then so does `all`. If any of them return `false`, it
@@ -575,6 +633,12 @@ pub trait StreamExt: Stream {
/// Tests if any element of the stream matches a predicate.
///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn any<F>(&mut self, f: F) -> bool;
+ /// ```
+ ///
/// `any()` takes a closure that returns `true` or `false`. It applies
/// this closure to each element of the stream, and if any of them return
/// `true`, then so does `any()`. If they all return `false`, it
@@ -664,6 +728,12 @@ pub trait StreamExt: Stream {
/// A combinator that applies a function to every element in a stream
/// producing a single, final value.
///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn fold<B, F>(self, init: B, f: F) -> B;
+ /// ```
+ ///
/// # Examples
/// Basic usage:
/// ```
@@ -687,6 +757,12 @@ pub trait StreamExt: Stream {
/// Drain stream pushing all emitted values into a collection.
///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn collect<T>(self) -> T;
+ /// ```
+ ///
/// `collect` streams all values, awaiting as needed. Values are pushed into
/// a collection. A number of different target collection types are
/// supported, including [`Vec`](std::vec::Vec),
@@ -819,6 +895,33 @@ pub trait StreamExt: Stream {
{
Timeout::new(self, duration)
}
+
+ /// Slows down a stream by enforcing a delay between items.
+ ///
+ /// # Example
+ ///
+ /// Create a throttled stream.
+ /// ```rust,no_run
+ /// use std::time::Duration;
+ /// use tokio::stream::StreamExt;
+ ///
+ /// # async fn dox() {
+ /// let mut item_stream = futures::stream::repeat("one").throttle(Duration::from_secs(2));
+ ///
+ /// loop {
+ /// // The string will be produced at most every 2 seconds
+ /// println!("{:?}", item_stream.next().await);
+ /// }
+ /// # }
+ /// ```
+ #[cfg(all(feature = "time"))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "time")))]
+ fn throttle(self, duration: Duration) -> Throttle<Self>
+ where
+ Self: Sized,
+ {
+ throttle(duration, self)
+ }
}
impl<St: ?Sized> StreamExt for St where St: Stream {}
diff --git a/src/stream/next.rs b/src/stream/next.rs
index 3909c0c..d9b1f92 100644
--- a/src/stream/next.rs
+++ b/src/stream/next.rs
@@ -1,28 +1,37 @@
use crate::stream::Stream;
use core::future::Future;
+use core::marker::PhantomPinned;
use core::pin::Pin;
use core::task::{Context, Poll};
+use pin_project_lite::pin_project;
-/// Future for the [`next`](super::StreamExt::next) method.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct Next<'a, St: ?Sized> {
- stream: &'a mut St,
+pin_project! {
+ /// Future for the [`next`](super::StreamExt::next) method.
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct Next<'a, St: ?Sized> {
+ stream: &'a mut St,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
-impl<St: ?Sized + Unpin> Unpin for Next<'_, St> {}
-
impl<'a, St: ?Sized> Next<'a, St> {
pub(super) fn new(stream: &'a mut St) -> Self {
- Next { stream }
+ Next {
+ stream,
+ _pin: PhantomPinned,
+ }
}
}
impl<St: ?Sized + Stream + Unpin> Future for Next<'_, St> {
type Output = Option<St::Item>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- Pin::new(&mut self.stream).poll_next(cx)
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ Pin::new(me.stream).poll_next(cx)
}
}
diff --git a/src/stream/stream_map.rs b/src/stream/stream_map.rs
index 2f60ea4..8539e4d 100644
--- a/src/stream/stream_map.rs
+++ b/src/stream/stream_map.rs
@@ -57,8 +57,8 @@ use std::task::{Context, Poll};
///
/// #[tokio::main]
/// async fn main() {
-/// let (mut tx1, rx1) = mpsc::channel(10);
-/// let (mut tx2, rx2) = mpsc::channel(10);
+/// let (tx1, rx1) = mpsc::channel(10);
+/// let (tx2, rx2) = mpsc::channel(10);
///
/// tokio::spawn(async move {
/// tx1.send(1).await.unwrap();
@@ -163,6 +163,52 @@ pub struct StreamMap<K, V> {
}
impl<K, V> StreamMap<K, V> {
+ /// An iterator visiting all key-value pairs in arbitrary order.
+ ///
+ /// The iterator element type is &'a (K, V).
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::stream::{StreamMap, pending};
+ ///
+ /// let mut map = StreamMap::new();
+ ///
+ /// map.insert("a", pending::<i32>());
+ /// map.insert("b", pending());
+ /// map.insert("c", pending());
+ ///
+ /// for (key, stream) in map.iter() {
+ /// println!("({}, {:?})", key, stream);
+ /// }
+ /// ```
+ pub fn iter(&self) -> impl Iterator<Item = &(K, V)> {
+ self.entries.iter()
+ }
+
+ /// An iterator visiting all key-value pairs mutably in arbitrary order.
+ ///
+ /// The iterator element type is &'a mut (K, V).
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::stream::{StreamMap, pending};
+ ///
+ /// let mut map = StreamMap::new();
+ ///
+ /// map.insert("a", pending::<i32>());
+ /// map.insert("b", pending());
+ /// map.insert("c", pending());
+ ///
+ /// for (key, stream) in map.iter_mut() {
+ /// println!("({}, {:?})", key, stream);
+ /// }
+ /// ```
+ pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (K, V)> {
+ self.entries.iter_mut()
+ }
+
/// Creates an empty `StreamMap`.
///
/// The stream map is initially created with a capacity of `0`, so it will
@@ -217,7 +263,7 @@ impl<K, V> StreamMap<K, V> {
/// }
/// ```
pub fn keys(&self) -> impl Iterator<Item = &K> {
- self.entries.iter().map(|(k, _)| k)
+ self.iter().map(|(k, _)| k)
}
/// An iterator visiting all values in arbitrary order.
@@ -240,7 +286,7 @@ impl<K, V> StreamMap<K, V> {
/// }
/// ```
pub fn values(&self) -> impl Iterator<Item = &V> {
- self.entries.iter().map(|(_, v)| v)
+ self.iter().map(|(_, v)| v)
}
/// An iterator visiting all values mutably in arbitrary order.
@@ -263,7 +309,7 @@ impl<K, V> StreamMap<K, V> {
/// }
/// ```
pub fn values_mut(&mut self) -> impl Iterator<Item = &mut V> {
- self.entries.iter_mut().map(|(_, v)| v)
+ self.iter_mut().map(|(_, v)| v)
}
/// Returns the number of streams the map can hold without reallocating.
diff --git a/src/time/throttle.rs b/src/stream/throttle.rs
index d53a6f7..8f4a256 100644
--- a/src/time/throttle.rs
+++ b/src/stream/throttle.rs
@@ -1,7 +1,7 @@
//! Slow down a stream by enforcing a delay between items.
use crate::stream::Stream;
-use crate::time::{Delay, Duration, Instant};
+use crate::time::{Duration, Instant, Sleep};
use std::future::Future;
use std::marker::Unpin;
@@ -10,34 +10,14 @@ use std::task::{self, Poll};
use pin_project_lite::pin_project;
-/// Slows down a stream by enforcing a delay between items.
-/// They will be produced not more often than the specified interval.
-///
-/// # Example
-///
-/// Create a throttled stream.
-/// ```rust,no_run
-/// use std::time::Duration;
-/// use tokio::stream::StreamExt;
-/// use tokio::time::throttle;
-///
-/// # async fn dox() {
-/// let mut item_stream = throttle(Duration::from_secs(2), futures::stream::repeat("one"));
-///
-/// loop {
-/// // The string will be produced at most every 2 seconds
-/// println!("{:?}", item_stream.next().await);
-/// }
-/// # }
-/// ```
-pub fn throttle<T>(duration: Duration, stream: T) -> Throttle<T>
+pub(super) fn throttle<T>(duration: Duration, stream: T) -> Throttle<T>
where
T: Stream,
{
let delay = if duration == Duration::from_millis(0) {
None
} else {
- Some(Delay::new_timeout(Instant::now() + duration, duration))
+ Some(Sleep::new_timeout(Instant::now() + duration, duration))
};
Throttle {
@@ -54,7 +34,7 @@ pin_project! {
#[must_use = "streams do nothing unless polled"]
pub struct Throttle<T> {
// `None` when duration is zero.
- delay: Option<Delay>,
+ delay: Option<Sleep>,
duration: Duration,
// Set to true when `delay` has returned ready, but `stream` hasn't.
diff --git a/src/stream/timeout.rs b/src/stream/timeout.rs
index b8a2024..669973f 100644
--- a/src/stream/timeout.rs
+++ b/src/stream/timeout.rs
@@ -1,5 +1,5 @@
use crate::stream::{Fuse, Stream};
-use crate::time::{Delay, Elapsed, Instant};
+use crate::time::{error::Elapsed, Instant, Sleep};
use core::future::Future;
use core::pin::Pin;
@@ -14,7 +14,7 @@ pin_project! {
pub struct Timeout<S> {
#[pin]
stream: Fuse<S>,
- deadline: Delay,
+ deadline: Sleep,
duration: Duration,
poll_deadline: bool,
}
@@ -23,7 +23,7 @@ pin_project! {
impl<S: Stream> Timeout<S> {
pub(super) fn new(stream: S, duration: Duration) -> Self {
let next = Instant::now() + duration;
- let deadline = Delay::new_timeout(next, duration);
+ let deadline = Sleep::new_timeout(next, duration);
Timeout {
stream: Fuse::new(stream),
diff --git a/src/stream/try_next.rs b/src/stream/try_next.rs
index 59e0eb1..b21d279 100644
--- a/src/stream/try_next.rs
+++ b/src/stream/try_next.rs
@@ -1,22 +1,29 @@
use crate::stream::{Next, Stream};
use core::future::Future;
+use core::marker::PhantomPinned;
use core::pin::Pin;
use core::task::{Context, Poll};
+use pin_project_lite::pin_project;
-/// Future for the [`try_next`](super::StreamExt::try_next) method.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct TryNext<'a, St: ?Sized> {
- inner: Next<'a, St>,
+pin_project! {
+ /// Future for the [`try_next`](super::StreamExt::try_next) method.
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct TryNext<'a, St: ?Sized> {
+ #[pin]
+ inner: Next<'a, St>,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
-impl<St: ?Sized + Unpin> Unpin for TryNext<'_, St> {}
-
impl<'a, St: ?Sized> TryNext<'a, St> {
pub(super) fn new(stream: &'a mut St) -> Self {
Self {
inner: Next::new(stream),
+ _pin: PhantomPinned,
}
}
}
@@ -24,7 +31,8 @@ impl<'a, St: ?Sized> TryNext<'a, St> {
impl<T, E, St: ?Sized + Stream<Item = Result<T, E>> + Unpin> Future for TryNext<'_, St> {
type Output = Result<Option<T>, E>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- Pin::new(&mut self.inner).poll(cx).map(Option::transpose)
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ me.inner.poll(cx).map(Option::transpose)
}
}
diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs
index 6286334..fddb3a5 100644
--- a/src/sync/barrier.rs
+++ b/src/sync/barrier.rs
@@ -96,7 +96,7 @@ impl Barrier {
// wake everyone, increment the generation, and return
state
.waker
- .broadcast(state.generation)
+ .send(state.generation)
.expect("there is at least one receiver");
state.arrived = 0;
state.generation += 1;
@@ -110,9 +110,11 @@ impl Barrier {
let mut wait = self.wait.clone();
loop {
+ let _ = wait.changed().await;
+
// note that the first time through the loop, this _will_ yield a generation
// immediately, since we cloned a receiver that has never seen any values.
- if wait.recv().await.expect("sender hasn't been closed") >= generation {
+ if *wait.borrow() >= generation {
break;
}
}
diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs
index 070cd20..0b50e4f 100644
--- a/src/sync/batch_semaphore.rs
+++ b/src/sync/batch_semaphore.rs
@@ -1,3 +1,4 @@
+#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
//! # Implementation Details
//!
//! The semaphore is implemented using an intrusive linked list of waiters. An
@@ -36,7 +37,7 @@ pub(crate) struct Semaphore {
}
struct Waitlist {
- queue: LinkedList<Waiter>,
+ queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
closed: bool,
}
@@ -96,10 +97,13 @@ impl Semaphore {
/// Note that this reserves three bits of flags in the permit counter, but
/// we only actually use one of them. However, the previous semaphore
/// implementation used three bits, so we will continue to reserve them to
- /// avoid a breaking change if additional flags need to be aadded in the
+ /// avoid a breaking change if additional flags need to be added in the
/// future.
pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3;
const CLOSED: usize = 1;
+ // The least-significant bit in the number of permits is reserved to use
+ // as a flag indicating that the semaphore has been closed. Consequently
+ // PERMIT_SHIFT is used to leave that bit for that purpose.
const PERMIT_SHIFT: usize = 1;
/// Creates a new semaphore with the initial number of permits
@@ -120,6 +124,27 @@ impl Semaphore {
}
}
+ /// Creates a new semaphore with the initial number of permits
+ ///
+ /// Maximum number of permits on 32-bit platforms is `1<<29`.
+ ///
+ /// If the specified number of permits exceeds the maximum permit amount
+ /// Then the value will get clamped to the maximum number of permits.
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ pub(crate) const fn const_new(mut permits: usize) -> Self {
+ // NOTE: assertions and by extension panics are still being worked on: https://github.com/rust-lang/rust/issues/74925
+ // currently we just clamp the permit count when it exceeds the max
+ permits &= Self::MAX_PERMITS;
+
+ Self {
+ permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT),
+ waiters: Mutex::const_new(Waitlist {
+ queue: LinkedList::new(),
+ closed: false,
+ }),
+ }
+ }
+
/// Returns the current number of available permits
pub(crate) fn available_permits(&self) -> usize {
self.permits.load(Acquire) >> Self::PERMIT_SHIFT
@@ -134,16 +159,15 @@ impl Semaphore {
}
// Assign permits to the wait queue
- self.add_permits_locked(added, self.waiters.lock().unwrap());
+ self.add_permits_locked(added, self.waiters.lock());
}
/// Closes the semaphore. This prevents the semaphore from issuing new
/// permits and notifies all pending waiters.
// This will be used once the bounded MPSC is updated to use the new
// semaphore implementation.
- #[allow(dead_code)]
pub(crate) fn close(&self) {
- let mut waiters = self.waiters.lock().unwrap();
+ let mut waiters = self.waiters.lock();
// If the semaphore's permits counter has enough permits for an
// unqueued waiter to acquire all the permits it needs immediately,
// it won't touch the wait list. Therefore, we have to set a bit on
@@ -161,6 +185,11 @@ impl Semaphore {
}
}
+ /// Returns true if the semaphore is closed
+ pub(crate) fn is_closed(&self) -> bool {
+ self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
+ }
+
pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
assert!(
num_permits as usize <= Self::MAX_PERMITS,
@@ -170,8 +199,8 @@ impl Semaphore {
let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
let mut curr = self.permits.load(Acquire);
loop {
- // Has the semaphore closed?git
- if curr & Self::CLOSED > 0 {
+ // Has the semaphore closed?
+ if curr & Self::CLOSED == Self::CLOSED {
return Err(TryAcquireError::Closed);
}
@@ -203,7 +232,7 @@ impl Semaphore {
let mut lock = Some(waiters);
let mut is_empty = false;
while rem > 0 {
- let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap());
+ let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock());
'inner: for slot in &mut wakers[..] {
// Was the waiter assigned enough permits to wake it?
match waiters.queue.last() {
@@ -296,7 +325,7 @@ impl Semaphore {
// counter. Otherwise, if we subtract the permits and then
// acquire the lock, we might miss additional permits being
// added while waiting for the lock.
- lock = Some(self.waiters.lock().unwrap());
+ lock = Some(self.waiters.lock());
}
match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
@@ -306,7 +335,7 @@ impl Semaphore {
if !queued {
return Ready(Ok(()));
} else if lock.is_none() {
- break self.waiters.lock().unwrap();
+ break self.waiters.lock();
}
}
break lock.expect("lock must be acquired before waiting");
@@ -357,7 +386,7 @@ impl Semaphore {
impl fmt::Debug for Semaphore {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Semaphore")
- .field("permits", &self.permits.load(Relaxed))
+ .field("permits", &self.available_permits())
.finish()
}
}
@@ -456,14 +485,7 @@ impl Drop for Acquire<'_> {
// This is where we ensure safety. The future is being dropped,
// which means we must ensure that the waiter entry is no longer stored
// in the linked list.
- let mut waiters = match self.semaphore.waiters.lock() {
- Ok(lock) => lock,
- // Removing the node from the linked list is necessary to ensure
- // safety. Even if the lock was poisoned, we need to make sure it is
- // removed from the linked list before dropping it --- otherwise,
- // the list will contain a dangling pointer to this node.
- Err(e) => e.into_inner(),
- };
+ let mut waiters = self.semaphore.waiters.lock();
// remove the entry from the list
let node = NonNull::from(&mut self.node);
@@ -506,20 +528,14 @@ impl TryAcquireError {
/// Returns `true` if the error was caused by a closed semaphore.
#[allow(dead_code)] // may be used later!
pub(crate) fn is_closed(&self) -> bool {
- match self {
- TryAcquireError::Closed => true,
- _ => false,
- }
+ matches!(self, TryAcquireError::Closed)
}
/// Returns `true` if the error was caused by calling `try_acquire` on a
/// semaphore with no available permits.
#[allow(dead_code)] // may be used later!
pub(crate) fn is_no_permits(&self) -> bool {
- match self {
- TryAcquireError::NoPermits => true,
- _ => false,
- }
+ matches!(self, TryAcquireError::NoPermits)
}
}
diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs
index 0c8716f..ee9aba0 100644
--- a/src/sync/broadcast.rs
+++ b/src/sync/broadcast.rs
@@ -21,7 +21,7 @@
//! ## Lagging
//!
//! As sent messages must be retained until **all** [`Receiver`] handles receive
-//! a clone, broadcast channels are suspectible to the "slow receiver" problem.
+//! a clone, broadcast channels are susceptible to the "slow receiver" problem.
//! In this case, all but one receiver are able to receive values at the rate
//! they are sent. Because one receiver is stalled, the channel starts to fill
//! up.
@@ -55,8 +55,8 @@
//! [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
//! [`Receiver`]: crate::sync::broadcast::Receiver
//! [`channel`]: crate::sync::broadcast::channel
-//! [`RecvError::Lagged`]: crate::sync::broadcast::RecvError::Lagged
-//! [`RecvError::Closed`]: crate::sync::broadcast::RecvError::Closed
+//! [`RecvError::Lagged`]: crate::sync::broadcast::error::RecvError::Lagged
+//! [`RecvError::Closed`]: crate::sync::broadcast::error::RecvError::Closed
//! [`recv`]: crate::sync::broadcast::Receiver::recv
//!
//! # Examples
@@ -107,6 +107,7 @@
//! assert_eq!(20, rx.recv().await.unwrap());
//! assert_eq!(30, rx.recv().await.unwrap());
//! }
+//! ```
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicUsize;
@@ -194,58 +195,99 @@ pub struct Receiver<T> {
/// Next position to read from
next: u64,
-
- /// Used to support the deprecated `poll_recv` fn
- waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
}
-/// Error returned by [`Sender::send`][Sender::send].
-///
-/// A **send** operation can only fail if there are no active receivers,
-/// implying that the message could never be received. The error contains the
-/// message being sent as a payload so it can be recovered.
-#[derive(Debug)]
-pub struct SendError<T>(pub T);
+pub mod error {
+ //! Broadcast error types
-/// An error returned from the [`recv`] function on a [`Receiver`].
-///
-/// [`recv`]: crate::sync::broadcast::Receiver::recv
-/// [`Receiver`]: crate::sync::broadcast::Receiver
-#[derive(Debug, PartialEq)]
-pub enum RecvError {
- /// There are no more active senders implying no further messages will ever
- /// be sent.
- Closed,
+ use std::fmt;
- /// The receiver lagged too far behind. Attempting to receive again will
- /// return the oldest message still retained by the channel.
+ /// Error returned by from the [`send`] function on a [`Sender`].
///
- /// Includes the number of skipped messages.
- Lagged(u64),
-}
+ /// A **send** operation can only fail if there are no active receivers,
+ /// implying that the message could never be received. The error contains the
+ /// message being sent as a payload so it can be recovered.
+ ///
+ /// [`send`]: crate::sync::broadcast::Sender::send
+ /// [`Sender`]: crate::sync::broadcast::Sender
+ #[derive(Debug)]
+ pub struct SendError<T>(pub T);
-/// An error returned from the [`try_recv`] function on a [`Receiver`].
-///
-/// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
-/// [`Receiver`]: crate::sync::broadcast::Receiver
-#[derive(Debug, PartialEq)]
-pub enum TryRecvError {
- /// The channel is currently empty. There are still active
- /// [`Sender`][Sender] handles, so data may yet become available.
- Empty,
-
- /// There are no more active senders implying no further messages will ever
- /// be sent.
- Closed,
-
- /// The receiver lagged too far behind and has been forcibly disconnected.
- /// Attempting to receive again will return the oldest message still
- /// retained by the channel.
- ///
- /// Includes the number of skipped messages.
- Lagged(u64),
+ impl<T> fmt::Display for SendError<T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "channel closed")
+ }
+ }
+
+ impl<T: fmt::Debug> std::error::Error for SendError<T> {}
+
+ /// An error returned from the [`recv`] function on a [`Receiver`].
+ ///
+ /// [`recv`]: crate::sync::broadcast::Receiver::recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ #[derive(Debug, PartialEq)]
+ pub enum RecvError {
+ /// There are no more active senders implying no further messages will ever
+ /// be sent.
+ Closed,
+
+ /// The receiver lagged too far behind. Attempting to receive again will
+ /// return the oldest message still retained by the channel.
+ ///
+ /// Includes the number of skipped messages.
+ Lagged(u64),
+ }
+
+ impl fmt::Display for RecvError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ RecvError::Closed => write!(f, "channel closed"),
+ RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
+ }
+ }
+ }
+
+ impl std::error::Error for RecvError {}
+
+ /// An error returned from the [`try_recv`] function on a [`Receiver`].
+ ///
+ /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ #[derive(Debug, PartialEq)]
+ pub enum TryRecvError {
+ /// The channel is currently empty. There are still active
+ /// [`Sender`] handles, so data may yet become available.
+ ///
+ /// [`Sender`]: crate::sync::broadcast::Sender
+ Empty,
+
+ /// There are no more active senders implying no further messages will ever
+ /// be sent.
+ Closed,
+
+ /// The receiver lagged too far behind and has been forcibly disconnected.
+ /// Attempting to receive again will return the oldest message still
+ /// retained by the channel.
+ ///
+ /// Includes the number of skipped messages.
+ Lagged(u64),
+ }
+
+ impl fmt::Display for TryRecvError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ TryRecvError::Empty => write!(f, "channel empty"),
+ TryRecvError::Closed => write!(f, "channel closed"),
+ TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
+ }
+ }
+ }
+
+ impl std::error::Error for TryRecvError {}
}
+use self::error::*;
+
/// Data shared between senders and receivers
struct Shared<T> {
/// slots in the channel
@@ -273,7 +315,7 @@ struct Tail {
closed: bool,
/// Receivers waiting for a value
- waiters: LinkedList<Waiter>,
+ waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
}
/// Slot in the buffer
@@ -373,8 +415,8 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// [`Sender::subscribe`]: crate::sync::broadcast::Sender::subscribe
/// [`Receiver`]: crate::sync::broadcast::Receiver
/// [`recv`]: crate::sync::broadcast::Receiver::recv
-/// [`SendError`]: crate::sync::broadcast::SendError
-/// [`RecvError`]: crate::sync::broadcast::RecvError
+/// [`SendError`]: crate::sync::broadcast::error::SendError
+/// [`RecvError`]: crate::sync::broadcast::error::RecvError
///
/// # Examples
///
@@ -400,7 +442,7 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
/// tx.send(20).unwrap();
/// }
/// ```
-pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
+pub fn channel<T: Clone>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
assert!(capacity > 0, "capacity is empty");
assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
@@ -433,7 +475,6 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
let rx = Receiver {
shared: shared.clone(),
next: 0,
- waiter: None,
};
let tx = Sender { shared };
@@ -528,23 +569,7 @@ impl<T> Sender<T> {
/// ```
pub fn subscribe(&self) -> Receiver<T> {
let shared = self.shared.clone();
-
- let mut tail = shared.tail.lock().unwrap();
-
- if tail.rx_cnt == MAX_RECEIVERS {
- panic!("max receivers");
- }
-
- tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
- let next = tail.pos;
-
- drop(tail);
-
- Receiver {
- shared,
- next,
- waiter: None,
- }
+ new_receiver(shared)
}
/// Returns the number of active receivers
@@ -584,12 +609,12 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn receiver_count(&self) -> usize {
- let tail = self.shared.tail.lock().unwrap();
+ let tail = self.shared.tail.lock();
tail.rx_cnt
}
fn send2(&self, value: Option<T>) -> Result<usize, SendError<Option<T>>> {
- let mut tail = self.shared.tail.lock().unwrap();
+ let mut tail = self.shared.tail.lock();
if tail.rx_cnt == 0 {
return Err(SendError(value));
@@ -634,6 +659,22 @@ impl<T> Sender<T> {
}
}
+fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
+ let mut tail = shared.tail.lock();
+
+ if tail.rx_cnt == MAX_RECEIVERS {
+ panic!("max receivers");
+ }
+
+ tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
+
+ let next = tail.pos;
+
+ drop(tail);
+
+ Receiver { shared, next }
+}
+
impl Tail {
fn notify_rx(&mut self) {
while let Some(mut waiter) = self.waiters.pop_back() {
@@ -695,7 +736,7 @@ impl<T> Receiver<T> {
// the slot lock.
drop(slot);
- let mut tail = self.shared.tail.lock().unwrap();
+ let mut tail = self.shared.tail.lock();
// Acquire slot lock again
slot = self.shared.buffer[idx].read().unwrap();
@@ -784,106 +825,7 @@ impl<T> Receiver<T> {
}
}
-impl<T> Receiver<T>
-where
- T: Clone,
-{
- /// Attempts to return a pending value on this receiver without awaiting.
- ///
- /// This is useful for a flavor of "optimistic check" before deciding to
- /// await on a receiver.
- ///
- /// Compared with [`recv`], this function has three failure cases instead of one
- /// (one for closed, one for an empty buffer, one for a lagging receiver).
- ///
- /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have
- /// dropped, indicating that no further values can be sent on the channel.
- ///
- /// If the [`Receiver`] handle falls behind, once the channel is full, newly
- /// sent values will overwrite old values. At this point, a call to [`recv`]
- /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s
- /// internal cursor is updated to point to the oldest value still held by
- /// the channel. A subsequent call to [`try_recv`] will return this value
- /// **unless** it has been since overwritten. If there are no values to
- /// receive, `Err(TryRecvError::Empty)` is returned.
- ///
- /// [`recv`]: crate::sync::broadcast::Receiver::recv
- /// [`Receiver`]: crate::sync::broadcast::Receiver
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::sync::broadcast;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let (tx, mut rx) = broadcast::channel(16);
- ///
- /// assert!(rx.try_recv().is_err());
- ///
- /// tx.send(10).unwrap();
- ///
- /// let value = rx.try_recv().unwrap();
- /// assert_eq!(10, value);
- /// }
- /// ```
- pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
- let guard = self.recv_ref(None)?;
- guard.clone_value().ok_or(TryRecvError::Closed)
- }
-
- #[doc(hidden)]
- #[deprecated(since = "0.2.21", note = "use async fn recv()")]
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
- use Poll::{Pending, Ready};
-
- // The borrow checker prohibits calling `self.poll_ref` while passing in
- // a mutable ref to a field (as it should). To work around this,
- // `waiter` is first *removed* from `self` then `poll_recv` is called.
- //
- // However, for safety, we must ensure that `waiter` is **not** dropped.
- // It could be contained in the intrusive linked list. The `Receiver`
- // drop implementation handles cleanup.
- //
- // The guard pattern is used to ensure that, on return, even due to
- // panic, the waiter node is replaced on `self`.
-
- struct Guard<'a, T> {
- waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
- receiver: &'a mut Receiver<T>,
- }
-
- impl<'a, T> Drop for Guard<'a, T> {
- fn drop(&mut self) {
- self.receiver.waiter = self.waiter.take();
- }
- }
-
- let waiter = self.waiter.take().or_else(|| {
- Some(Box::pin(UnsafeCell::new(Waiter {
- queued: false,
- waker: None,
- pointers: linked_list::Pointers::new(),
- _p: PhantomPinned,
- })))
- });
-
- let guard = Guard {
- waiter,
- receiver: self,
- };
- let res = guard
- .receiver
- .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker())));
-
- match res {
- Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)),
- Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)),
- Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))),
- Err(TryRecvError::Empty) => Pending,
- }
- }
-
+impl<T: Clone> Receiver<T> {
/// Receives the next value for this receiver.
///
/// Each [`Receiver`] handle will receive a clone of all values sent
@@ -948,54 +890,103 @@ where
/// assert_eq!(20, rx.recv().await.unwrap());
/// assert_eq!(30, rx.recv().await.unwrap());
/// }
+ /// ```
pub async fn recv(&mut self) -> Result<T, RecvError> {
let fut = Recv::<_, T>::new(Borrow(self));
fut.await
}
-}
-#[cfg(feature = "stream")]
-#[doc(hidden)]
-#[deprecated(since = "0.2.21", note = "use `into_stream()`")]
-impl<T> crate::stream::Stream for Receiver<T>
-where
- T: Clone,
-{
- type Item = Result<T, RecvError>;
-
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Result<T, RecvError>>> {
- #[allow(deprecated)]
- self.poll_recv(cx).map(|v| match v {
- Ok(v) => Some(Ok(v)),
- lag @ Err(RecvError::Lagged(_)) => Some(lag),
- Err(RecvError::Closed) => None,
- })
+ /// Attempts to return a pending value on this receiver without awaiting.
+ ///
+ /// This is useful for a flavor of "optimistic check" before deciding to
+ /// await on a receiver.
+ ///
+ /// Compared with [`recv`], this function has three failure cases instead of two
+ /// (one for closed, one for an empty buffer, one for a lagging receiver).
+ ///
+ /// `Err(TryRecvError::Closed)` is returned when all `Sender` halves have
+ /// dropped, indicating that no further values can be sent on the channel.
+ ///
+ /// If the [`Receiver`] handle falls behind, once the channel is full, newly
+ /// sent values will overwrite old values. At this point, a call to [`recv`]
+ /// will return with `Err(TryRecvError::Lagged)` and the [`Receiver`]'s
+ /// internal cursor is updated to point to the oldest value still held by
+ /// the channel. A subsequent call to [`try_recv`] will return this value
+ /// **unless** it has been since overwritten. If there are no values to
+ /// receive, `Err(TryRecvError::Empty)` is returned.
+ ///
+ /// [`recv`]: crate::sync::broadcast::Receiver::recv
+ /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv
+ /// [`Receiver`]: crate::sync::broadcast::Receiver
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = broadcast::channel(16);
+ ///
+ /// assert!(rx.try_recv().is_err());
+ ///
+ /// tx.send(10).unwrap();
+ ///
+ /// let value = rx.try_recv().unwrap();
+ /// assert_eq!(10, value);
+ /// }
+ /// ```
+ pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
+ let guard = self.recv_ref(None)?;
+ guard.clone_value().ok_or(TryRecvError::Closed)
+ }
+
+ /// Convert the receiver into a `Stream`.
+ ///
+ /// The conversion allows using `Receiver` with APIs that require stream
+ /// values.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::stream::StreamExt;
+ /// use tokio::sync::broadcast;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = broadcast::channel(128);
+ ///
+ /// tokio::spawn(async move {
+ /// for i in 0..10_i32 {
+ /// tx.send(i).unwrap();
+ /// }
+ /// });
+ ///
+ /// // Streams must be pinned to iterate.
+ /// tokio::pin! {
+ /// let stream = rx
+ /// .into_stream()
+ /// .filter(Result::is_ok)
+ /// .map(Result::unwrap)
+ /// .filter(|v| v % 2 == 0)
+ /// .map(|v| v + 1);
+ /// }
+ ///
+ /// while let Some(i) = stream.next().await {
+ /// println!("{}", i);
+ /// }
+ /// }
+ /// ```
+ #[cfg(feature = "stream")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
+ pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
+ Recv::new(Borrow(self))
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
- let mut tail = self.shared.tail.lock().unwrap();
-
- if let Some(waiter) = &self.waiter {
- // safety: tail lock is held
- let queued = waiter.with(|ptr| unsafe { (*ptr).queued });
-
- if queued {
- // Remove the node
- //
- // safety: tail lock is held and the wait node is verified to be in
- // the list.
- unsafe {
- waiter.with_mut(|ptr| {
- tail.waiters.remove((&mut *ptr).into());
- });
- }
- }
- }
+ let mut tail = self.shared.tail.lock();
tail.rx_cnt -= 1;
let until = tail.pos;
@@ -1070,48 +1061,6 @@ where
cfg_stream! {
use futures_core::Stream;
- impl<T: Clone> Receiver<T> {
- /// Convert the receiver into a `Stream`.
- ///
- /// The conversion allows using `Receiver` with APIs that require stream
- /// values.
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::stream::StreamExt;
- /// use tokio::sync::broadcast;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let (tx, rx) = broadcast::channel(128);
- ///
- /// tokio::spawn(async move {
- /// for i in 0..10_i32 {
- /// tx.send(i).unwrap();
- /// }
- /// });
- ///
- /// // Streams must be pinned to iterate.
- /// tokio::pin! {
- /// let stream = rx
- /// .into_stream()
- /// .filter(Result::is_ok)
- /// .map(Result::unwrap)
- /// .filter(|v| v % 2 == 0)
- /// .map(|v| v + 1);
- /// }
- ///
- /// while let Some(i) = stream.next().await {
- /// println!("{}", i);
- /// }
- /// }
- /// ```
- pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
- Recv::new(Borrow(self))
- }
- }
-
impl<R, T: Clone> Stream for Recv<R, T>
where
R: AsMut<Receiver<T>>,
@@ -1141,7 +1090,7 @@ where
fn drop(&mut self) {
// Acquire the tail lock. This is required for safety before accessing
// the waiter node.
- let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap();
+ let mut tail = self.receiver.as_mut().shared.tail.lock();
// safety: tail lock is held
let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
@@ -1211,27 +1160,4 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
}
}
-impl fmt::Display for RecvError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- RecvError::Closed => write!(f, "channel closed"),
- RecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
- }
- }
-}
-
-impl std::error::Error for RecvError {}
-
-impl fmt::Display for TryRecvError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- TryRecvError::Empty => write!(f, "channel empty"),
- TryRecvError::Closed => write!(f, "channel closed"),
- TryRecvError::Lagged(amt) => write!(f, "channel lagged by {}", amt),
- }
- }
-}
-
-impl std::error::Error for TryRecvError {}
-
fn is_unpin<T: Unpin>() {}
diff --git a/src/sync/cancellation_token.rs b/src/sync/cancellation_token.rs
deleted file mode 100644
index d60d8e0..0000000
--- a/src/sync/cancellation_token.rs
+++ /dev/null
@@ -1,861 +0,0 @@
-//! An asynchronously awaitable `CancellationToken`.
-//! The token allows to signal a cancellation request to one or more tasks.
-
-use crate::loom::sync::atomic::AtomicUsize;
-use crate::loom::sync::Mutex;
-use crate::util::intrusive_double_linked_list::{LinkedList, ListNode};
-
-use core::future::Future;
-use core::pin::Pin;
-use core::ptr::NonNull;
-use core::sync::atomic::Ordering;
-use core::task::{Context, Poll, Waker};
-
-/// A token which can be used to signal a cancellation request to one or more
-/// tasks.
-///
-/// Tasks can call [`CancellationToken::cancelled()`] in order to
-/// obtain a Future which will be resolved when cancellation is requested.
-///
-/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
-///
-/// # Examples
-///
-/// ```ignore
-/// use tokio::select;
-/// use tokio::scope::CancellationToken;
-///
-/// #[tokio::main]
-/// async fn main() {
-/// let token = CancellationToken::new();
-/// let cloned_token = token.clone();
-///
-/// let join_handle = tokio::spawn(async move {
-/// // Wait for either cancellation or a very long time
-/// select! {
-/// _ = cloned_token.cancelled() => {
-/// // The token was cancelled
-/// 5
-/// }
-/// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => {
-/// 99
-/// }
-/// }
-/// });
-///
-/// tokio::spawn(async move {
-/// tokio::time::delay_for(std::time::Duration::from_millis(10)).await;
-/// token.cancel();
-/// });
-///
-/// assert_eq!(5, join_handle.await.unwrap());
-/// }
-/// ```
-pub struct CancellationToken {
- inner: NonNull<CancellationTokenState>,
-}
-
-// Safety: The CancellationToken is thread-safe and can be moved between threads,
-// since all methods are internally synchronized.
-unsafe impl Send for CancellationToken {}
-unsafe impl Sync for CancellationToken {}
-
-/// A Future that is resolved once the corresponding [`CancellationToken`]
-/// was cancelled
-#[must_use = "futures do nothing unless polled"]
-pub struct WaitForCancellationFuture<'a> {
- /// The CancellationToken that is associated with this WaitForCancellationFuture
- cancellation_token: Option<&'a CancellationToken>,
- /// Node for waiting at the cancellation_token
- wait_node: ListNode<WaitQueueEntry>,
- /// Whether this future was registered at the token yet as a waiter
- is_registered: bool,
-}
-
-// Safety: Futures can be sent between threads as long as the underlying
-// cancellation_token is thread-safe (Sync),
-// which allows to poll/register/unregister from a different thread.
-unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
-
-// ===== impl CancellationToken =====
-
-impl core::fmt::Debug for CancellationToken {
- fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
- f.debug_struct("CancellationToken")
- .field("is_cancelled", &self.is_cancelled())
- .finish()
- }
-}
-
-impl Clone for CancellationToken {
- fn clone(&self) -> Self {
- // Safety: The state inside a `CancellationToken` is always valid, since
- // is reference counted
- let inner = self.state();
-
- // Tokens are cloned by increasing their refcount
- let current_state = inner.snapshot();
- inner.increment_refcount(current_state);
-
- CancellationToken { inner: self.inner }
- }
-}
-
-impl Drop for CancellationToken {
- fn drop(&mut self) {
- let token_state_pointer = self.inner;
-
- // Safety: The state inside a `CancellationToken` is always valid, since
- // is reference counted
- let inner = unsafe { &mut *self.inner.as_ptr() };
-
- let mut current_state = inner.snapshot();
-
- // We need to safe the parent, since the state might be released by the
- // next call
- let parent = inner.parent;
-
- // Drop our own refcount
- current_state = inner.decrement_refcount(current_state);
-
- // If this was the last reference, unregister from the parent
- if current_state.refcount == 0 {
- if let Some(mut parent) = parent {
- // Safety: Since we still retain a reference on the parent, it must be valid.
- let parent = unsafe { parent.as_mut() };
- parent.unregister_child(token_state_pointer, current_state);
- }
- }
- }
-}
-
-impl CancellationToken {
- /// Creates a new CancellationToken in the non-cancelled state.
- pub fn new() -> CancellationToken {
- let state = Box::new(CancellationTokenState::new(
- None,
- StateSnapshot {
- cancel_state: CancellationState::NotCancelled,
- has_parent_ref: false,
- refcount: 1,
- },
- ));
-
- // Safety: We just created the Box. The pointer is guaranteed to be
- // not null
- CancellationToken {
- inner: unsafe { NonNull::new_unchecked(Box::into_raw(state)) },
- }
- }
-
- /// Returns a reference to the utilized `CancellationTokenState`.
- fn state(&self) -> &CancellationTokenState {
- // Safety: The state inside a `CancellationToken` is always valid, since
- // is reference counted
- unsafe { &*self.inner.as_ptr() }
- }
-
- /// Creates a `CancellationToken` which will get cancelled whenever the
- /// current token gets cancelled.
- ///
- /// If the current token is already cancelled, the child token will get
- /// returned in cancelled state.
- ///
- /// # Examples
- ///
- /// ```ignore
- /// use tokio::select;
- /// use tokio::scope::CancellationToken;
- ///
- /// #[tokio::main]
- /// async fn main() {
- /// let token = CancellationToken::new();
- /// let child_token = token.child_token();
- ///
- /// let join_handle = tokio::spawn(async move {
- /// // Wait for either cancellation or a very long time
- /// select! {
- /// _ = child_token.cancelled() => {
- /// // The token was cancelled
- /// 5
- /// }
- /// _ = tokio::time::delay_for(std::time::Duration::from_secs(9999)) => {
- /// 99
- /// }
- /// }
- /// });
- ///
- /// tokio::spawn(async move {
- /// tokio::time::delay_for(std::time::Duration::from_millis(10)).await;
- /// token.cancel();
- /// });
- ///
- /// assert_eq!(5, join_handle.await.unwrap());
- /// }
- /// ```
- pub fn child_token(&self) -> CancellationToken {
- let inner = self.state();
-
- // Increment the refcount of this token. It will be referenced by the
- // child, independent of whether the child is immediately cancelled or
- // not.
- let _current_state = inner.increment_refcount(inner.snapshot());
-
- let mut unpacked_child_state = StateSnapshot {
- has_parent_ref: true,
- refcount: 1,
- cancel_state: CancellationState::NotCancelled,
- };
- let mut child_token_state = Box::new(CancellationTokenState::new(
- Some(self.inner),
- unpacked_child_state,
- ));
-
- {
- let mut guard = inner.synchronized.lock().unwrap();
- if guard.is_cancelled {
- // This task was already cancelled. In this case we should not
- // insert the child into the list, since it would never get removed
- // from the list.
- (*child_token_state.synchronized.lock().unwrap()).is_cancelled = true;
- unpacked_child_state.cancel_state = CancellationState::Cancelled;
- // Since it's not in the list, the parent doesn't need to retain
- // a reference to it.
- unpacked_child_state.has_parent_ref = false;
- child_token_state
- .state
- .store(unpacked_child_state.pack(), Ordering::SeqCst);
- } else {
- if let Some(mut first_child) = guard.first_child {
- child_token_state.from_parent.next_peer = Some(first_child);
- // Safety: We manipulate other child task inside the Mutex
- // and retain a parent reference on it. The child token can't
- // get invalidated while the Mutex is held.
- unsafe {
- first_child.as_mut().from_parent.prev_peer =
- Some((&mut *child_token_state).into())
- };
- }
- guard.first_child = Some((&mut *child_token_state).into());
- }
- };
-
- let child_token_ptr = Box::into_raw(child_token_state);
- // Safety: We just created the pointer from a `Box`
- CancellationToken {
- inner: unsafe { NonNull::new_unchecked(child_token_ptr) },
- }
- }
-
- /// Cancel the [`CancellationToken`] and all child tokens which had been
- /// derived from it.
- ///
- /// This will wake up all tasks which are waiting for cancellation.
- pub fn cancel(&self) {
- self.state().cancel();
- }
-
- /// Returns `true` if the `CancellationToken` had been cancelled
- pub fn is_cancelled(&self) -> bool {
- self.state().is_cancelled()
- }
-
- /// Returns a `Future` that gets fulfilled when cancellation is requested.
- pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
- WaitForCancellationFuture {
- cancellation_token: Some(self),
- wait_node: ListNode::new(WaitQueueEntry::new()),
- is_registered: false,
- }
- }
-
- unsafe fn register(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- self.state().register(wait_node, cx)
- }
-
- fn check_for_cancellation(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- self.state().check_for_cancellation(wait_node, cx)
- }
-
- fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
- self.state().unregister(wait_node)
- }
-}
-
-// ===== impl WaitForCancellationFuture =====
-
-impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
- fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
- f.debug_struct("WaitForCancellationFuture").finish()
- }
-}
-
-impl<'a> Future for WaitForCancellationFuture<'a> {
- type Output = ();
-
- fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
- // Safety: We do not move anything out of `WaitForCancellationFuture`
- let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) };
-
- let cancellation_token = mut_self
- .cancellation_token
- .expect("polled WaitForCancellationFuture after completion");
-
- let poll_res = if !mut_self.is_registered {
- // Safety: The `ListNode` is pinned through the Future,
- // and we will unregister it in `WaitForCancellationFuture::drop`
- // before the Future is dropped and the memory reference is invalidated.
- unsafe { cancellation_token.register(&mut mut_self.wait_node, cx) }
- } else {
- cancellation_token.check_for_cancellation(&mut mut_self.wait_node, cx)
- };
-
- if let Poll::Ready(()) = poll_res {
- // The cancellation_token was signalled
- mut_self.cancellation_token = None;
- // A signalled Token means the Waker won't be enqueued anymore
- mut_self.is_registered = false;
- mut_self.wait_node.task = None;
- } else {
- // This `Future` and its stored `Waker` stay registered at the
- // `CancellationToken`
- mut_self.is_registered = true;
- }
-
- poll_res
- }
-}
-
-impl<'a> Drop for WaitForCancellationFuture<'a> {
- fn drop(&mut self) {
- // If this WaitForCancellationFuture has been polled and it was added to the
- // wait queue at the cancellation_token, it must be removed before dropping.
- // Otherwise the cancellation_token would access invalid memory.
- if let Some(token) = self.cancellation_token {
- if self.is_registered {
- token.unregister(&mut self.wait_node);
- }
- }
- }
-}
-
-/// Tracks how the future had interacted with the [`CancellationToken`]
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-enum PollState {
- /// The task has never interacted with the [`CancellationToken`].
- New,
- /// The task was added to the wait queue at the [`CancellationToken`].
- Waiting,
- /// The task has been polled to completion.
- Done,
-}
-
-/// Tracks the WaitForCancellationFuture waiting state.
-/// Access to this struct is synchronized through the mutex in the CancellationToken.
-struct WaitQueueEntry {
- /// The task handle of the waiting task
- task: Option<Waker>,
- // Current polling state. This state is only updated inside the Mutex of
- // the CancellationToken.
- state: PollState,
-}
-
-impl WaitQueueEntry {
- /// Creates a new WaitQueueEntry
- fn new() -> WaitQueueEntry {
- WaitQueueEntry {
- task: None,
- state: PollState::New,
- }
- }
-}
-
-struct SynchronizedState {
- waiters: LinkedList<WaitQueueEntry>,
- first_child: Option<NonNull<CancellationTokenState>>,
- is_cancelled: bool,
-}
-
-impl SynchronizedState {
- fn new() -> Self {
- Self {
- waiters: LinkedList::new(),
- first_child: None,
- is_cancelled: false,
- }
- }
-}
-
-/// Information embedded in child tokens which is synchronized through the Mutex
-/// in their parent.
-struct SynchronizedThroughParent {
- next_peer: Option<NonNull<CancellationTokenState>>,
- prev_peer: Option<NonNull<CancellationTokenState>>,
-}
-
-/// Possible states of a `CancellationToken`
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-enum CancellationState {
- NotCancelled = 0,
- Cancelling = 1,
- Cancelled = 2,
-}
-
-impl CancellationState {
- fn pack(self) -> usize {
- self as usize
- }
-
- fn unpack(value: usize) -> Self {
- match value {
- 0 => CancellationState::NotCancelled,
- 1 => CancellationState::Cancelling,
- 2 => CancellationState::Cancelled,
- _ => unreachable!("Invalid value"),
- }
- }
-}
-
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-struct StateSnapshot {
- /// The amount of references to this particular CancellationToken.
- /// `CancellationToken` structs hold these references to a `CancellationTokenState`.
- /// Also the state is referenced by the state of each child.
- refcount: usize,
- /// Whether the state is still referenced by it's parent and can therefore
- /// not be freed.
- has_parent_ref: bool,
- /// Whether the token is cancelled
- cancel_state: CancellationState,
-}
-
-impl StateSnapshot {
- /// Packs the snapshot into a `usize`
- fn pack(self) -> usize {
- self.refcount << 3 | if self.has_parent_ref { 4 } else { 0 } | self.cancel_state.pack()
- }
-
- /// Unpacks the snapshot from a `usize`
- fn unpack(value: usize) -> Self {
- let refcount = value >> 3;
- let has_parent_ref = value & 4 != 0;
- let cancel_state = CancellationState::unpack(value & 0x03);
-
- StateSnapshot {
- refcount,
- has_parent_ref,
- cancel_state,
- }
- }
-
- /// Whether this `CancellationTokenState` is still referenced by any
- /// `CancellationToken`.
- fn has_refs(&self) -> bool {
- self.refcount != 0 || self.has_parent_ref
- }
-}
-
-/// The maximum permitted amount of references to a CancellationToken. This
-/// is derived from the intent to never use more than 32bit in the `Snapshot`.
-const MAX_REFS: u32 = (std::u32::MAX - 7) >> 3;
-
-/// Internal state of the `CancellationToken` pair above
-struct CancellationTokenState {
- state: AtomicUsize,
- parent: Option<NonNull<CancellationTokenState>>,
- from_parent: SynchronizedThroughParent,
- synchronized: Mutex<SynchronizedState>,
-}
-
-impl CancellationTokenState {
- fn new(
- parent: Option<NonNull<CancellationTokenState>>,
- state: StateSnapshot,
- ) -> CancellationTokenState {
- CancellationTokenState {
- parent,
- from_parent: SynchronizedThroughParent {
- prev_peer: None,
- next_peer: None,
- },
- state: AtomicUsize::new(state.pack()),
- synchronized: Mutex::new(SynchronizedState::new()),
- }
- }
-
- /// Returns a snapshot of the current atomic state of the token
- fn snapshot(&self) -> StateSnapshot {
- StateSnapshot::unpack(self.state.load(Ordering::SeqCst))
- }
-
- fn atomic_update_state<F>(&self, mut current_state: StateSnapshot, func: F) -> StateSnapshot
- where
- F: Fn(StateSnapshot) -> StateSnapshot,
- {
- let mut current_packed_state = current_state.pack();
- loop {
- let next_state = func(current_state);
- match self.state.compare_exchange(
- current_packed_state,
- next_state.pack(),
- Ordering::SeqCst,
- Ordering::SeqCst,
- ) {
- Ok(_) => {
- return next_state;
- }
- Err(actual) => {
- current_packed_state = actual;
- current_state = StateSnapshot::unpack(actual);
- }
- }
- }
- }
-
- fn increment_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
- self.atomic_update_state(current_state, |mut state: StateSnapshot| {
- if state.refcount >= MAX_REFS as usize {
- eprintln!("[ERROR] Maximum reference count for CancellationToken was exceeded");
- std::process::abort();
- }
- state.refcount += 1;
- state
- })
- }
-
- fn decrement_refcount(&self, current_state: StateSnapshot) -> StateSnapshot {
- let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
- state.refcount -= 1;
- state
- });
-
- // Drop the State if it is not referenced anymore
- if !current_state.has_refs() {
- // Safety: `CancellationTokenState` is always stored in refcounted
- // Boxes
- let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
- }
-
- current_state
- }
-
- fn remove_parent_ref(&self, current_state: StateSnapshot) -> StateSnapshot {
- let current_state = self.atomic_update_state(current_state, |mut state: StateSnapshot| {
- state.has_parent_ref = false;
- state
- });
-
- // Drop the State if it is not referenced anymore
- if !current_state.has_refs() {
- // Safety: `CancellationTokenState` is always stored in refcounted
- // Boxes
- let _ = unsafe { Box::from_raw(self as *const Self as *mut Self) };
- }
-
- current_state
- }
-
- /// Unregisters a child from the parent token.
- /// The child tokens state is not exactly known at this point in time.
- /// If the parent token is cancelled, the child token gets removed from the
- /// parents list, and might therefore already have been freed. If the parent
- /// token is not cancelled, the child token is still valid.
- fn unregister_child(
- &mut self,
- mut child_state: NonNull<CancellationTokenState>,
- current_child_state: StateSnapshot,
- ) {
- let removed_child = {
- // Remove the child toke from the parents linked list
- let mut guard = self.synchronized.lock().unwrap();
- if !guard.is_cancelled {
- // Safety: Since the token was not cancelled, the child must
- // still be in the list and valid.
- let mut child_state = unsafe { child_state.as_mut() };
- debug_assert!(child_state.snapshot().has_parent_ref);
-
- if guard.first_child == Some(child_state.into()) {
- guard.first_child = child_state.from_parent.next_peer;
- }
- // Safety: If peers wouldn't be valid anymore, they would try
- // to remove themselves from the list. This would require locking
- // the Mutex that we currently own.
- unsafe {
- if let Some(mut prev_peer) = child_state.from_parent.prev_peer {
- prev_peer.as_mut().from_parent.next_peer =
- child_state.from_parent.next_peer;
- }
- if let Some(mut next_peer) = child_state.from_parent.next_peer {
- next_peer.as_mut().from_parent.prev_peer =
- child_state.from_parent.prev_peer;
- }
- }
- child_state.from_parent.prev_peer = None;
- child_state.from_parent.next_peer = None;
-
- // The child is no longer referenced by the parent, since we were able
- // to remove its reference from the parents list.
- true
- } else {
- // Do not touch the linked list anymore. If the parent is cancelled
- // it will move all childs outside of the Mutex and manipulate
- // the pointers there. Manipulating the pointers here too could
- // lead to races. Therefore leave them just as as and let the
- // parent deal with it. The parent will make sure to retain a
- // reference to this state as long as it manipulates the list
- // pointers. Therefore the pointers are not dangling.
- false
- }
- };
-
- if removed_child {
- // If the token removed itself from the parents list, it can reset
- // the the parent ref status. If it is isn't able to do so, because the
- // parent removed it from the list, there is no need to do this.
- // The parent ref acts as as another reference count. Therefore
- // removing this reference can free the object.
- // Safety: The token was in the list. This means the parent wasn't
- // cancelled before, and the token must still be alive.
- unsafe { child_state.as_mut().remove_parent_ref(current_child_state) };
- }
-
- // Decrement the refcount on the parent and free it if necessary
- self.decrement_refcount(self.snapshot());
- }
-
- fn cancel(&self) {
- // Move the state of the CancellationToken from `NotCancelled` to `Cancelling`
- let mut current_state = self.snapshot();
-
- let state_after_cancellation = loop {
- if current_state.cancel_state != CancellationState::NotCancelled {
- // Another task already initiated the cancellation
- return;
- }
-
- let mut next_state = current_state;
- next_state.cancel_state = CancellationState::Cancelling;
- match self.state.compare_exchange(
- current_state.pack(),
- next_state.pack(),
- Ordering::SeqCst,
- Ordering::SeqCst,
- ) {
- Ok(_) => break next_state,
- Err(actual) => current_state = StateSnapshot::unpack(actual),
- }
- };
-
- // This task cancelled the token
-
- // Take the task list out of the Token
- // We do not want to cancel child token inside this lock. If one of the
- // child tasks would have additional child tokens, we would recursively
- // take locks.
-
- // Doing this action has an impact if the child token is dropped concurrently:
- // It will try to deregister itself from the parent task, but can not find
- // itself in the task list anymore. Therefore it needs to assume the parent
- // has extracted the list and will process it. It may not modify the list.
- // This is OK from a memory safety perspective, since the parent still
- // retains a reference to the child task until it finished iterating over
- // it.
-
- let mut first_child = {
- let mut guard = self.synchronized.lock().unwrap();
- // Save the cancellation also inside the Mutex
- // This allows child tokens which want to detach themselves to detect
- // that this is no longer required since the parent cleared the list.
- guard.is_cancelled = true;
-
- // Wakeup all waiters
- // This happens inside the lock to make cancellation reliable
- // If we would access waiters outside of the lock, the pointers
- // may no longer be valid.
- // Typically this shouldn't be an issue, since waking a task should
- // only move it from the blocked into the ready state and not have
- // further side effects.
-
- // Use a reverse iterator, so that the oldest waiter gets
- // scheduled first
- guard.waiters.reverse_drain(|waiter| {
- // We are not allowed to move the `Waker` out of the list node.
- // The `Future` relies on the fact that the old `Waker` stays there
- // as long as the `Future` has not completed in order to perform
- // the `will_wake()` check.
- // Therefore `wake_by_ref` is used instead of `wake()`
- if let Some(handle) = &mut waiter.task {
- handle.wake_by_ref();
- }
- // Mark the waiter to have been removed from the list.
- waiter.state = PollState::Done;
- });
-
- guard.first_child.take()
- };
-
- while let Some(mut child) = first_child {
- // Safety: We know this is a valid pointer since it is in our child pointer
- // list. It can't have been freed in between, since we retain a a reference
- // to each child.
- let mut_child = unsafe { child.as_mut() };
-
- // Get the next child and clean up list pointers
- first_child = mut_child.from_parent.next_peer;
- mut_child.from_parent.prev_peer = None;
- mut_child.from_parent.next_peer = None;
-
- // Cancel the child task
- mut_child.cancel();
-
- // Drop the parent reference. This `CancellationToken` is not interested
- // in interacting with the child anymore.
- // This is ONLY allowed once we promised not to touch the state anymore
- // after this interaction.
- mut_child.remove_parent_ref(mut_child.snapshot());
- }
-
- // The cancellation has completed
- // At this point in time tasks which registered a wait node can be sure
- // that this wait node already had been dequeued from the list without
- // needing to inspect the list.
- self.atomic_update_state(state_after_cancellation, |mut state| {
- state.cancel_state = CancellationState::Cancelled;
- state
- });
- }
-
- /// Returns `true` if the `CancellationToken` had been cancelled
- fn is_cancelled(&self) -> bool {
- let current_state = self.snapshot();
- current_state.cancel_state != CancellationState::NotCancelled
- }
-
- /// Registers a waiting task at the `CancellationToken`.
- /// Safety: This method is only safe as long as the waiting waiting task
- /// will properly unregister the wait node before it gets moved.
- unsafe fn register(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- debug_assert_eq!(PollState::New, wait_node.state);
- let current_state = self.snapshot();
-
- // Perform an optimistic cancellation check before. This is not strictly
- // necessary since we also check for cancellation in the Mutex, but
- // reduces the necessary work to be performed for tasks which already
- // had been cancelled.
- if current_state.cancel_state != CancellationState::NotCancelled {
- return Poll::Ready(());
- }
-
- // So far the token is not cancelled. However it could be cancelld before
- // we get the chance to store the `Waker`. Therfore we need to check
- // for cancellation again inside the mutex.
- let mut guard = self.synchronized.lock().unwrap();
- if guard.is_cancelled {
- // Cancellation was signalled
- wait_node.state = PollState::Done;
- Poll::Ready(())
- } else {
- // Added the task to the wait queue
- wait_node.task = Some(cx.waker().clone());
- wait_node.state = PollState::Waiting;
- guard.waiters.add_front(wait_node);
- Poll::Pending
- }
- }
-
- fn check_for_cancellation(
- &self,
- wait_node: &mut ListNode<WaitQueueEntry>,
- cx: &mut Context<'_>,
- ) -> Poll<()> {
- debug_assert!(
- wait_node.task.is_some(),
- "Method can only be called after task had been registered"
- );
-
- let current_state = self.snapshot();
-
- if current_state.cancel_state != CancellationState::NotCancelled {
- // If the cancellation had been fully completed we know that our `Waker`
- // is no longer registered at the `CancellationToken`.
- // Otherwise the cancel call may or may not yet have iterated
- // through the waiters list and removed the wait nodes.
- // If it hasn't yet, we need to remove it. Otherwise an attempt to
- // reuse the `wait_node´ might get freed due to the `WaitForCancellationFuture`
- // getting dropped before the cancellation had interacted with it.
- if current_state.cancel_state != CancellationState::Cancelled {
- self.unregister(wait_node);
- }
- Poll::Ready(())
- } else {
- // Check if we need to swap the `Waker`. This will make the check more
- // expensive, since the `Waker` is synchronized through the Mutex.
- // If we don't need to perform a `Waker` update, an atomic check for
- // cancellation is sufficient.
- let need_waker_update = wait_node
- .task
- .as_ref()
- .map(|waker| waker.will_wake(cx.waker()))
- .unwrap_or(true);
-
- if need_waker_update {
- let guard = self.synchronized.lock().unwrap();
- if guard.is_cancelled {
- // Cancellation was signalled. Since this cancellation signal
- // is set inside the Mutex, the old waiter must already have
- // been removed from the waiting list
- debug_assert_eq!(PollState::Done, wait_node.state);
- wait_node.task = None;
- Poll::Ready(())
- } else {
- // The WaitForCancellationFuture is already in the queue.
- // The CancellationToken can't have been cancelled,
- // since this would change the is_cancelled flag inside the mutex.
- // Therefore we just have to update the Waker. A follow-up
- // cancellation will always use the new waker.
- wait_node.task = Some(cx.waker().clone());
- Poll::Pending
- }
- } else {
- // Do nothing. If the token gets cancelled, this task will get
- // woken again and can fetch the cancellation.
- Poll::Pending
- }
- }
- }
-
- fn unregister(&self, wait_node: &mut ListNode<WaitQueueEntry>) {
- debug_assert!(
- wait_node.task.is_some(),
- "waiter can not be active without task"
- );
-
- let mut guard = self.synchronized.lock().unwrap();
- // WaitForCancellationFuture only needs to get removed if it has been added to
- // the wait queue of the CancellationToken.
- // This has happened in the PollState::Waiting case.
- if let PollState::Waiting = wait_node.state {
- // Safety: Due to the state, we know that the node must be part
- // of the waiter list
- if !unsafe { guard.waiters.remove(wait_node) } {
- // Panic if the address isn't found. This can only happen if the contract was
- // violated, e.g. the WaitQueueEntry got moved after the initial poll.
- panic!("Future could not be removed from wait queue");
- }
- wait_node.state = PollState::Done;
- }
- wait_node.task = None;
- }
-}
diff --git a/src/sync/mod.rs b/src/sync/mod.rs
index 3d96106..57ae277 100644
--- a/src/sync/mod.rs
+++ b/src/sync/mod.rs
@@ -20,7 +20,7 @@
//! few flavors of channels provided by Tokio. Each channel flavor supports
//! different message passing patterns. When a channel supports multiple
//! producers, many separate tasks may **send** messages. When a channel
-//! supports muliple consumers, many different separate tasks may **receive**
+//! supports multiple consumers, many different separate tasks may **receive**
//! messages.
//!
//! Tokio provides many different channel flavors as different message passing
@@ -106,7 +106,7 @@
//!
//! #[tokio::main]
//! async fn main() {
-//! let (mut tx, mut rx) = mpsc::channel(100);
+//! let (tx, mut rx) = mpsc::channel(100);
//!
//! tokio::spawn(async move {
//! for i in 0..10 {
@@ -150,7 +150,7 @@
//! for _ in 0..10 {
//! // Each task needs its own `tx` handle. This is done by cloning the
//! // original handle.
-//! let mut tx = tx.clone();
+//! let tx = tx.clone();
//!
//! tokio::spawn(async move {
//! tx.send(&b"data to write"[..]).await.unwrap();
@@ -213,7 +213,7 @@
//!
//! // Spawn tasks that will send the increment command.
//! for _ in 0..10 {
-//! let mut cmd_tx = cmd_tx.clone();
+//! let cmd_tx = cmd_tx.clone();
//!
//! join_handles.push(tokio::spawn(async move {
//! let (resp_tx, resp_rx) = oneshot::channel();
@@ -322,7 +322,7 @@
//! tokio::spawn(async move {
//! loop {
//! // Wait 10 seconds between checks
-//! time::delay_for(Duration::from_secs(10)).await;
+//! time::sleep(Duration::from_secs(10)).await;
//!
//! // Load the configuration file
//! let new_config = Config::load_from_file().await.unwrap();
@@ -330,7 +330,7 @@
//! // If the configuration changed, send the new config value
//! // on the watch channel.
//! if new_config != config {
-//! tx.broadcast(new_config.clone()).unwrap();
+//! tx.send(new_config.clone()).unwrap();
//! config = new_config;
//! }
//! }
@@ -355,17 +355,15 @@
//! let op = my_async_operation();
//! tokio::pin!(op);
//!
-//! // Receive the **initial** configuration value. As this is the
-//! // first time the config is received from the watch, it will
-//! // always complete immediatedly.
-//! let mut conf = rx.recv().await.unwrap();
+//! // Get the initial config value
+//! let mut conf = rx.borrow().clone();
//!
//! let mut op_start = Instant::now();
-//! let mut delay = time::delay_until(op_start + conf.timeout);
+//! let mut sleep = time::sleep_until(op_start + conf.timeout);
//!
//! loop {
//! tokio::select! {
-//! _ = &mut delay => {
+//! _ = &mut sleep => {
//! // The operation elapsed. Restart it
//! op.set(my_async_operation());
//!
@@ -373,14 +371,14 @@
//! op_start = Instant::now();
//!
//! // Restart the timeout
-//! delay = time::delay_until(op_start + conf.timeout);
+//! sleep = time::sleep_until(op_start + conf.timeout);
//! }
-//! new_conf = rx.recv() => {
-//! conf = new_conf.unwrap();
+//! _ = rx.changed() => {
+//! conf = rx.borrow().clone();
//!
//! // The configuration has been updated. Update the
-//! // `delay` using the new `timeout` value.
-//! delay.reset(op_start + conf.timeout);
+//! // `sleep` using the new `timeout` value.
+//! sleep.reset(op_start + conf.timeout);
//! }
//! _ = &mut op => {
//! // The operation completed!
@@ -399,14 +397,14 @@
//! }
//! ```
//!
-//! [`watch` channel]: crate::sync::watch
-//! [`broadcast` channel]: crate::sync::broadcast
+//! [`watch` channel]: mod@crate::sync::watch
+//! [`broadcast` channel]: mod@crate::sync::broadcast
//!
//! # State synchronization
//!
//! The remaining synchronization primitives focus on synchronizing state.
//! These are asynchronous equivalents to versions provided by `std`. They
-//! operate in a similar way as their `std` counterparts parts but will wait
+//! operate in a similar way as their `std` counterparts but will wait
//! asynchronously instead of blocking the thread.
//!
//! * [`Barrier`](Barrier) Ensures multiple tasks will wait for each other to
@@ -434,23 +432,17 @@ cfg_sync! {
pub mod broadcast;
- cfg_unstable! {
- mod cancellation_token;
- pub use cancellation_token::{CancellationToken, WaitForCancellationFuture};
- }
-
pub mod mpsc;
mod mutex;
pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard};
- mod notify;
+ pub(crate) mod notify;
pub use notify::Notify;
pub mod oneshot;
pub(crate) mod batch_semaphore;
- pub(crate) mod semaphore_ll;
mod semaphore;
pub use semaphore::{Semaphore, SemaphorePermit, OwnedSemaphorePermit};
@@ -464,20 +456,30 @@ cfg_sync! {
}
cfg_not_sync! {
+ #[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))]
+ pub(crate) mod batch_semaphore;
+
+ cfg_fs! {
+ mod mutex;
+ pub(crate) use mutex::Mutex;
+ }
+
+ #[cfg(any(feature = "rt", feature = "signal", all(unix, feature = "process")))]
+ pub(crate) mod notify;
+
cfg_atomic_waker_impl! {
mod task;
pub(crate) use task::AtomicWaker;
}
#[cfg(any(
- feature = "rt-core",
+ feature = "rt",
feature = "process",
feature = "signal"))]
pub(crate) mod oneshot;
- cfg_signal! {
+ cfg_signal_internal! {
pub(crate) mod mpsc;
- pub(crate) mod semaphore_ll;
}
}
diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs
index 7bf1619..e062f2b 100644
--- a/src/sync/mpsc/block.rs
+++ b/src/sync/mpsc/block.rs
@@ -1,8 +1,6 @@
-use crate::loom::{
- cell::UnsafeCell,
- sync::atomic::{AtomicPtr, AtomicUsize},
- thread,
-};
+use crate::loom::cell::UnsafeCell;
+use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
+use crate::loom::thread;
use std::mem::MaybeUninit;
use std::ops;
diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs
index afca8c5..06b3717 100644
--- a/src/sync/mpsc/bounded.rs
+++ b/src/sync/mpsc/bounded.rs
@@ -1,6 +1,6 @@
+use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
use crate::sync::mpsc::chan;
-use crate::sync::mpsc::error::{ClosedError, SendError, TryRecvError, TrySendError};
-use crate::sync::semaphore_ll as semaphore;
+use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
cfg_time! {
use crate::sync::mpsc::error::SendTimeoutError;
@@ -8,6 +8,7 @@ cfg_time! {
}
use std::fmt;
+#[cfg(any(feature = "signal", feature = "process", feature = "stream"))]
use std::task::{Context, Poll};
/// Send values to the associated `Receiver`.
@@ -17,20 +18,14 @@ pub struct Sender<T> {
chan: chan::Tx<T, Semaphore>,
}
-impl<T> Clone for Sender<T> {
- fn clone(&self) -> Self {
- Sender {
- chan: self.chan.clone(),
- }
- }
-}
-
-impl<T> fmt::Debug for Sender<T> {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Sender")
- .field("chan", &self.chan)
- .finish()
- }
+/// Permit to send one value into the channel.
+///
+/// `Permit` values are returned by [`Sender::reserve()`] and are used to
+/// guarantee channel capacity before generating a message to send.
+///
+/// [`Sender::reserve()`]: Sender::reserve
+pub struct Permit<'a, T> {
+ chan: &'a chan::Tx<T, Semaphore>,
}
/// Receive values from the associated `Sender`.
@@ -41,16 +36,12 @@ pub struct Receiver<T> {
chan: chan::Rx<T, Semaphore>,
}
-impl<T> fmt::Debug for Receiver<T> {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Receiver")
- .field("chan", &self.chan)
- .finish()
- }
-}
-
-/// Creates a bounded mpsc channel for communicating between asynchronous tasks,
-/// returning the sender/receiver halves.
+/// Creates a bounded mpsc channel for communicating between asynchronous tasks
+/// with backpressure.
+///
+/// The channel will buffer up to the provided number of messages. Once the
+/// buffer is full, attempts to `send` new messages will wait until a message is
+/// received from the channel. The provided buffer capacity must be at least 1.
///
/// All data sent on `Sender` will become available on `Receiver` in the same
/// order as it was sent.
@@ -62,6 +53,10 @@ impl<T> fmt::Debug for Receiver<T> {
/// will return a `SendError`. Similarly, if `Sender` is disconnected while
/// trying to `recv`, the `recv` method will return a `RecvError`.
///
+/// # Panics
+///
+/// Panics if the buffer capacity is 0.
+///
/// # Examples
///
/// ```rust
@@ -69,7 +64,7 @@ impl<T> fmt::Debug for Receiver<T> {
///
/// #[tokio::main]
/// async fn main() {
-/// let (mut tx, mut rx) = mpsc::channel(100);
+/// let (tx, mut rx) = mpsc::channel(100);
///
/// tokio::spawn(async move {
/// for i in 0..10 {
@@ -117,7 +112,7 @@ impl<T> Receiver<T> {
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(100);
+ /// let (tx, mut rx) = mpsc::channel(100);
///
/// tokio::spawn(async move {
/// tx.send("hello").await.unwrap();
@@ -135,7 +130,7 @@ impl<T> Receiver<T> {
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(100);
+ /// let (tx, mut rx) = mpsc::channel(100);
///
/// tx.send("hello").await.unwrap();
/// tx.send("world").await.unwrap();
@@ -146,15 +141,48 @@ impl<T> Receiver<T> {
/// ```
pub async fn recv(&mut self) -> Option<T> {
use crate::future::poll_fn;
-
- poll_fn(|cx| self.poll_recv(cx)).await
+ poll_fn(|cx| self.chan.recv(cx)).await
}
- #[doc(hidden)] // TODO: document
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
+ #[cfg(any(feature = "signal", feature = "process"))]
+ pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
self.chan.recv(cx)
}
+ /// Blocking receive to call outside of asynchronous contexts.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if called within an asynchronous execution
+ /// context.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::thread;
+ /// use tokio::runtime::Runtime;
+ /// use tokio::sync::mpsc;
+ ///
+ /// fn main() {
+ /// let (tx, mut rx) = mpsc::channel::<u8>(10);
+ ///
+ /// let sync_code = thread::spawn(move || {
+ /// assert_eq!(Some(10), rx.blocking_recv());
+ /// });
+ ///
+ /// Runtime::new()
+ /// .unwrap()
+ /// .block_on(async move {
+ /// let _ = tx.send(10).await;
+ /// });
+ /// sync_code.join().unwrap()
+ /// }
+ /// ```
+ #[cfg(feature = "sync")]
+ pub fn blocking_recv(&mut self) -> Option<T> {
+ crate::future::block_on(self.recv())
+ }
+
/// Attempts to return a pending value on this receiver without blocking.
///
/// This method will never block the caller in order to wait for data to
@@ -173,12 +201,53 @@ impl<T> Receiver<T> {
/// Closes the receiving half of a channel, without dropping it.
///
/// This prevents any further messages from being sent on the channel while
- /// still enabling the receiver to drain messages that are buffered.
+ /// still enabling the receiver to drain messages that are buffered. Any
+ /// outstanding [`Permit`] values will still be able to send messages.
+ ///
+ /// In order to guarantee no messages are dropped, after calling `close()`,
+ /// `recv()` must be called until `None` is returned.
+ ///
+ /// [`Permit`]: Permit
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = mpsc::channel(20);
+ ///
+ /// tokio::spawn(async move {
+ /// let mut i = 0;
+ /// while let Ok(permit) = tx.reserve().await {
+ /// permit.send(i);
+ /// i += 1;
+ /// }
+ /// });
+ ///
+ /// rx.close();
+ ///
+ /// while let Some(msg) = rx.recv().await {
+ /// println!("got {}", msg);
+ /// }
+ ///
+ /// // Channel closed and no messages are lost.
+ /// }
+ /// ```
pub fn close(&mut self) {
self.chan.close();
}
}
+impl<T> fmt::Debug for Receiver<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Receiver")
+ .field("chan", &self.chan)
+ .finish()
+ }
+}
+
impl<T> Unpin for Receiver<T> {}
cfg_stream! {
@@ -186,7 +255,7 @@ cfg_stream! {
type Item = T;
fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
- self.poll_recv(cx)
+ self.chan.recv(cx)
}
}
}
@@ -225,7 +294,7 @@ impl<T> Sender<T> {
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(1);
+ /// let (tx, mut rx) = mpsc::channel(1);
///
/// tokio::spawn(async move {
/// for i in 0..10 {
@@ -241,18 +310,49 @@ impl<T> Sender<T> {
/// }
/// }
/// ```
- pub async fn send(&mut self, value: T) -> Result<(), SendError<T>> {
- use crate::future::poll_fn;
-
- if poll_fn(|cx| self.poll_ready(cx)).await.is_err() {
- return Err(SendError(value));
+ pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
+ match self.reserve().await {
+ Ok(permit) => {
+ permit.send(value);
+ Ok(())
+ }
+ Err(_) => Err(SendError(value)),
}
+ }
- match self.try_send(value) {
- Ok(()) => Ok(()),
- Err(TrySendError::Full(_)) => unreachable!(),
- Err(TrySendError::Closed(value)) => Err(SendError(value)),
- }
+ /// Completes when the receiver has dropped.
+ ///
+ /// This allows the producers to get notified when interest in the produced
+ /// values is canceled and immediately stop doing work.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx1, rx) = mpsc::channel::<()>(1);
+ /// let tx2 = tx1.clone();
+ /// let tx3 = tx1.clone();
+ /// let tx4 = tx1.clone();
+ /// let tx5 = tx1.clone();
+ /// tokio::spawn(async move {
+ /// drop(rx);
+ /// });
+ ///
+ /// futures::join!(
+ /// tx1.closed(),
+ /// tx2.closed(),
+ /// tx3.closed(),
+ /// tx4.closed(),
+ /// tx5.closed()
+ /// );
+ //// println!("Receiver dropped");
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ self.chan.closed().await
}
/// Attempts to immediately send a message on this `Sender`
@@ -262,9 +362,6 @@ impl<T> Sender<T> {
/// with [`send`], this function has two failure cases instead of one (one for
/// disconnection, one for a full buffer).
///
- /// This function may be paired with [`poll_ready`] in order to wait for
- /// channel capacity before trying to send a value.
- ///
/// # Errors
///
/// If the channel capacity has been reached, i.e., the channel has `n`
@@ -276,7 +373,6 @@ impl<T> Sender<T> {
/// an error. The error includes the value passed to `send`.
///
/// [`send`]: Sender::send
- /// [`poll_ready`]: Sender::poll_ready
/// [`channel`]: channel
/// [`close`]: Receiver::close
///
@@ -288,8 +384,8 @@ impl<T> Sender<T> {
/// #[tokio::main]
/// async fn main() {
/// // Create a channel with buffer size 1
- /// let (mut tx1, mut rx) = mpsc::channel(1);
- /// let mut tx2 = tx1.clone();
+ /// let (tx1, mut rx) = mpsc::channel(1);
+ /// let tx2 = tx1.clone();
///
/// tokio::spawn(async move {
/// tx1.send(1).await.unwrap();
@@ -317,8 +413,15 @@ impl<T> Sender<T> {
/// }
/// }
/// ```
- pub fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
- self.chan.try_send(message)?;
+ pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
+ match self.chan.semaphore().0.try_acquire(1) {
+ Ok(_) => {}
+ Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(message)),
+ Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(message)),
+ }
+
+ // Send the message
+ self.chan.send(message);
Ok(())
}
@@ -346,11 +449,11 @@ impl<T> Sender<T> {
///
/// ```rust
/// use tokio::sync::mpsc;
- /// use tokio::time::{delay_for, Duration};
+ /// use tokio::time::{sleep, Duration};
///
/// #[tokio::main]
/// async fn main() {
- /// let (mut tx, mut rx) = mpsc::channel(1);
+ /// let (tx, mut rx) = mpsc::channel(1);
///
/// tokio::spawn(async move {
/// for i in 0..10 {
@@ -363,117 +466,213 @@ impl<T> Sender<T> {
///
/// while let Some(i) = rx.recv().await {
/// println!("got = {}", i);
- /// delay_for(Duration::from_millis(200)).await;
+ /// sleep(Duration::from_millis(200)).await;
/// }
/// }
/// ```
#[cfg(feature = "time")]
#[cfg_attr(docsrs, doc(cfg(feature = "time")))]
pub async fn send_timeout(
- &mut self,
+ &self,
value: T,
timeout: Duration,
) -> Result<(), SendTimeoutError<T>> {
- use crate::future::poll_fn;
-
- match crate::time::timeout(timeout, poll_fn(|cx| self.poll_ready(cx))).await {
+ let permit = match crate::time::timeout(timeout, self.reserve()).await {
Err(_) => {
return Err(SendTimeoutError::Timeout(value));
}
Ok(Err(_)) => {
return Err(SendTimeoutError::Closed(value));
}
- Ok(_) => {}
- }
+ Ok(Ok(permit)) => permit,
+ };
- match self.try_send(value) {
- Ok(()) => Ok(()),
- Err(TrySendError::Full(_)) => unreachable!(),
- Err(TrySendError::Closed(value)) => Err(SendTimeoutError::Closed(value)),
- }
+ permit.send(value);
+ Ok(())
}
- /// Returns `Poll::Ready(Ok(()))` when the channel is able to accept another item.
+ /// Blocking send to call outside of asynchronous contexts.
///
- /// If the channel is full, then `Poll::Pending` is returned and the task is notified when a
- /// slot becomes available.
+ /// # Panics
///
- /// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a call to `try_send` will succeed unless
- /// the channel has since been closed. To provide this guarantee, the channel reserves one slot
- /// in the channel for the coming send. This reserved slot is not available to other `Sender`
- /// instances, so you need to be careful to not end up with deadlocks by blocking after calling
- /// `poll_ready` but before sending an element.
+ /// This function panics if called within an asynchronous execution
+ /// context.
///
- /// If, after `poll_ready` succeeds, you decide you do not wish to send an item after all, you
- /// can use [`disarm`](Sender::disarm) to release the reserved slot.
+ /// # Examples
///
- /// Until an item is sent or [`disarm`](Sender::disarm) is called, repeated calls to
- /// `poll_ready` will return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))` if channel
- /// is closed.
- pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> {
- self.chan.poll_ready(cx).map_err(|_| ClosedError::new())
+ /// ```
+ /// use std::thread;
+ /// use tokio::runtime::Runtime;
+ /// use tokio::sync::mpsc;
+ ///
+ /// fn main() {
+ /// let (tx, mut rx) = mpsc::channel::<u8>(1);
+ ///
+ /// let sync_code = thread::spawn(move || {
+ /// tx.blocking_send(10).unwrap();
+ /// });
+ ///
+ /// Runtime::new().unwrap().block_on(async move {
+ /// assert_eq!(Some(10), rx.recv().await);
+ /// });
+ /// sync_code.join().unwrap()
+ /// }
+ /// ```
+ #[cfg(feature = "sync")]
+ pub fn blocking_send(&self, value: T) -> Result<(), SendError<T>> {
+ crate::future::block_on(self.send(value))
}
- /// Undo a successful call to `poll_ready`.
+ /// Checks if the channel has been closed. This happens when the
+ /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is
+ /// called.
///
- /// Once a call to `poll_ready` returns `Poll::Ready(Ok(()))`, it holds up one slot in the
- /// channel to make room for the coming send. `disarm` allows you to give up that slot if you
- /// decide you do not wish to send an item after all. After calling `disarm`, you must call
- /// `poll_ready` until it returns `Poll::Ready(Ok(()))` before attempting to send again.
+ /// [`Receiver`]: crate::sync::mpsc::Receiver
+ /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
///
- /// Returns `false` if no slot is reserved for this sender (usually because `poll_ready` was
- /// not previously called, or did not succeed).
+ /// ```
+ /// let (tx, rx) = tokio::sync::mpsc::channel::<()>(42);
+ /// assert!(!tx.is_closed());
///
- /// # Motivation
+ /// let tx2 = tx.clone();
+ /// assert!(!tx2.is_closed());
///
- /// Since `poll_ready` takes up one of the finite number of slots in a bounded channel, callers
- /// need to send an item shortly after `poll_ready` succeeds. If they do not, idle senders may
- /// take up all the slots of the channel, and prevent active senders from getting any requests
- /// through. Consider this code that forwards from one channel to another:
+ /// drop(rx);
+ /// assert!(tx.is_closed());
+ /// assert!(tx2.is_closed());
+ /// ```
+ pub fn is_closed(&self) -> bool {
+ self.chan.is_closed()
+ }
+
+ /// Wait for channel capacity. Once capacity to send one message is
+ /// available, it is reserved for the caller.
+ ///
+ /// If the channel is full, the function waits for the number of unreceived
+ /// messages to become less than the channel capacity. Capacity to send one
+ /// message is reserved for the caller. A [`Permit`] is returned to track
+ /// the reserved capacity. The [`send`] function on [`Permit`] consumes the
+ /// reserved capacity.
+ ///
+ /// Dropping [`Permit`] without sending a message releases the capacity back
+ /// to the channel.
+ ///
+ /// [`Permit`]: Permit
+ /// [`send`]: Permit::send
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = mpsc::channel(1);
+ ///
+ /// // Reserve capacity
+ /// let permit = tx.reserve().await.unwrap();
///
- /// ```rust,ignore
- /// loop {
- /// ready!(tx.poll_ready(cx))?;
- /// if let Some(item) = ready!(rx.poll_recv(cx)) {
- /// tx.try_send(item)?;
- /// } else {
- /// break;
- /// }
+ /// // Trying to send directly on the `tx` will fail due to no
+ /// // available capacity.
+ /// assert!(tx.try_send(123).is_err());
+ ///
+ /// // Sending on the permit succeeds
+ /// permit.send(456);
+ ///
+ /// // The value sent on the permit is received
+ /// assert_eq!(rx.recv().await.unwrap(), 456);
/// }
/// ```
+ pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
+ match self.chan.semaphore().0.acquire(1).await {
+ Ok(_) => {}
+ Err(_) => return Err(SendError(())),
+ }
+
+ Ok(Permit { chan: &self.chan })
+ }
+}
+
+impl<T> Clone for Sender<T> {
+ fn clone(&self) -> Self {
+ Sender {
+ chan: self.chan.clone(),
+ }
+ }
+}
+
+impl<T> fmt::Debug for Sender<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Sender")
+ .field("chan", &self.chan)
+ .finish()
+ }
+}
+
+// ===== impl Permit =====
+
+impl<T> Permit<'_, T> {
+ /// Sends a value using the reserved capacity.
+ ///
+ /// Capacity for the message has already been reserved. The message is sent
+ /// to the receiver and the permit is consumed. The operation will succeed
+ /// even if the receiver half has been closed. See [`Receiver::close`] for
+ /// more details on performing a clean shutdown.
+ ///
+ /// [`Receiver::close`]: Receiver::close
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, mut rx) = mpsc::channel(1);
+ ///
+ /// // Reserve capacity
+ /// let permit = tx.reserve().await.unwrap();
///
- /// If many such forwarders exist, and they all forward into a single (cloned) `Sender`, then
- /// any number of forwarders may be waiting for `rx.poll_recv` at the same time. While they do,
- /// they are effectively each reducing the channel's capacity by 1. If enough of these
- /// forwarders are idle, forwarders whose `rx` _do_ have elements will be unable to find a spot
- /// for them through `poll_ready`, and the system will deadlock.
- ///
- /// `disarm` solves this problem by allowing you to give up the reserved slot if you find that
- /// you have to block. We can then fix the code above by writing:
- ///
- /// ```rust,ignore
- /// loop {
- /// ready!(tx.poll_ready(cx))?;
- /// let item = rx.poll_recv(cx);
- /// if let Poll::Ready(Ok(_)) = item {
- /// // we're going to send the item below, so don't disarm
- /// } else {
- /// // give up our send slot, we won't need it for a while
- /// tx.disarm();
- /// }
- /// if let Some(item) = ready!(item) {
- /// tx.try_send(item)?;
- /// } else {
- /// break;
- /// }
+ /// // Trying to send directly on the `tx` will fail due to no
+ /// // available capacity.
+ /// assert!(tx.try_send(123).is_err());
+ ///
+ /// // Send a message on the permit
+ /// permit.send(456);
+ ///
+ /// // The value sent on the permit is received
+ /// assert_eq!(rx.recv().await.unwrap(), 456);
/// }
/// ```
- pub fn disarm(&mut self) -> bool {
- if self.chan.is_ready() {
- self.chan.disarm();
- true
- } else {
- false
+ pub fn send(self, value: T) {
+ use std::mem;
+
+ self.chan.send(value);
+
+ // Avoid the drop logic
+ mem::forget(self);
+ }
+}
+
+impl<T> Drop for Permit<'_, T> {
+ fn drop(&mut self) {
+ use chan::Semaphore;
+
+ let semaphore = self.chan.semaphore();
+
+ // Add the permit back to the semaphore
+ semaphore.add_permit();
+
+ if semaphore.is_closed() && semaphore.is_idle() {
+ self.chan.wake_rx();
}
}
}
+
+impl<T> fmt::Debug for Permit<'_, T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Permit")
+ .field("chan", &self.chan)
+ .finish()
+ }
+}
diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs
index 148ee3a..c78fb50 100644
--- a/src/sync/mpsc/chan.rs
+++ b/src/sync/mpsc/chan.rs
@@ -2,8 +2,9 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
-use crate::sync::mpsc::error::{ClosedError, TryRecvError};
-use crate::sync::mpsc::{error, list};
+use crate::sync::mpsc::error::TryRecvError;
+use crate::sync::mpsc::list;
+use crate::sync::notify::Notify;
use std::fmt;
use std::process;
@@ -12,21 +13,13 @@ use std::task::Poll::{Pending, Ready};
use std::task::{Context, Poll};
/// Channel sender
-pub(crate) struct Tx<T, S: Semaphore> {
+pub(crate) struct Tx<T, S> {
inner: Arc<Chan<T, S>>,
- permit: S::Permit,
}
-impl<T, S: Semaphore> fmt::Debug for Tx<T, S>
-where
- S::Permit: fmt::Debug,
- S: fmt::Debug,
-{
+impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Tx")
- .field("inner", &self.inner)
- .field("permit", &self.permit)
- .finish()
+ fmt.debug_struct("Tx").field("inner", &self.inner).finish()
}
}
@@ -35,70 +28,26 @@ pub(crate) struct Rx<T, S: Semaphore> {
inner: Arc<Chan<T, S>>,
}
-impl<T, S: Semaphore> fmt::Debug for Rx<T, S>
-where
- S: fmt::Debug,
-{
+impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Rx").field("inner", &self.inner).finish()
}
}
-#[derive(Debug, Eq, PartialEq)]
-pub(crate) enum TrySendError {
- Closed,
- Full,
-}
-
-impl<T> From<(T, TrySendError)> for error::SendError<T> {
- fn from(src: (T, TrySendError)) -> error::SendError<T> {
- match src.1 {
- TrySendError::Closed => error::SendError(src.0),
- TrySendError::Full => unreachable!(),
- }
- }
-}
-
-impl<T> From<(T, TrySendError)> for error::TrySendError<T> {
- fn from(src: (T, TrySendError)) -> error::TrySendError<T> {
- match src.1 {
- TrySendError::Closed => error::TrySendError::Closed(src.0),
- TrySendError::Full => error::TrySendError::Full(src.0),
- }
- }
-}
-
pub(crate) trait Semaphore {
- type Permit;
-
- fn new_permit() -> Self::Permit;
-
- /// The permit is dropped without a value being sent. In this case, the
- /// permit must be returned to the semaphore.
- fn drop_permit(&self, permit: &mut Self::Permit);
-
fn is_idle(&self) -> bool;
fn add_permit(&self);
- fn poll_acquire(
- &self,
- cx: &mut Context<'_>,
- permit: &mut Self::Permit,
- ) -> Poll<Result<(), ClosedError>>;
-
- fn try_acquire(&self, permit: &mut Self::Permit) -> Result<(), TrySendError>;
-
- /// A value was sent into the channel and the permit held by `tx` is
- /// dropped. In this case, the permit should not immeditely be returned to
- /// the semaphore. Instead, the permit is returnred to the semaphore once
- /// the sent value is read by the rx handle.
- fn forget(&self, permit: &mut Self::Permit);
-
fn close(&self);
+
+ fn is_closed(&self) -> bool;
}
struct Chan<T, S> {
+ /// Notifies all tasks listening for the receiver being dropped
+ notify_rx_closed: Notify,
+
/// Handle to the push half of the lock-free list.
tx: list::Tx<T>,
@@ -153,13 +102,11 @@ impl<T> fmt::Debug for RxFields<T> {
unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
-pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>)
-where
- S: Semaphore,
-{
+pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
let (tx, rx) = list::channel();
let chan = Arc::new(Chan {
+ notify_rx_closed: Notify::new(),
tx,
semaphore,
rx_waker: AtomicWaker::new(),
@@ -175,48 +122,60 @@ where
// ===== impl Tx =====
-impl<T, S> Tx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S> Tx<T, S> {
fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
- Tx {
- inner: chan,
- permit: S::new_permit(),
- }
- }
-
- pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ClosedError>> {
- self.inner.semaphore.poll_acquire(cx, &mut self.permit)
+ Tx { inner: chan }
}
- pub(crate) fn disarm(&mut self) {
- // TODO: should this error if not acquired?
- self.inner.semaphore.drop_permit(&mut self.permit)
+ pub(super) fn semaphore(&self) -> &S {
+ &self.inner.semaphore
}
/// Send a message and notify the receiver.
- pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> {
- self.inner.try_send(value, &mut self.permit)
+ pub(crate) fn send(&self, value: T) {
+ self.inner.send(value);
}
-}
-impl<T> Tx<T, (crate::sync::semaphore_ll::Semaphore, usize)> {
- pub(crate) fn is_ready(&self) -> bool {
- self.permit.is_acquired()
+ /// Wake the receive half
+ pub(crate) fn wake_rx(&self) {
+ self.inner.rx_waker.wake();
}
}
-impl<T> Tx<T, AtomicUsize> {
- pub(crate) fn send_unbounded(&self, value: T) -> Result<(), (T, TrySendError)> {
- self.inner.try_send(value, &mut ())
+impl<T, S: Semaphore> Tx<T, S> {
+ pub(crate) fn is_closed(&self) -> bool {
+ self.inner.semaphore.is_closed()
+ }
+
+ pub(crate) async fn closed(&self) {
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::task::Poll;
+
+ // In order to avoid a race condition, we first request a notification,
+ // **then** check the current value's version. If a new version exists,
+ // the notification request is dropped. Requesting the notification
+ // requires polling the future once.
+ let notified = self.inner.notify_rx_closed.notified();
+ pin!(notified);
+
+ // Polling the future once is guaranteed to return `Pending` as `watch`
+ // only notifies using `notify_waiters`.
+ crate::future::poll_fn(|cx| {
+ let res = Pin::new(&mut notified).poll(cx);
+ assert!(!res.is_ready());
+ Poll::Ready(())
+ })
+ .await;
+
+ if self.inner.semaphore.is_closed() {
+ return;
+ }
+ notified.await;
}
}
-impl<T, S> Clone for Tx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S> Clone for Tx<T, S> {
fn clone(&self) -> Tx<T, S> {
// Using a Relaxed ordering here is sufficient as the caller holds a
// strong ref to `self`, preventing a concurrent decrement to zero.
@@ -224,18 +183,12 @@ where
Tx {
inner: self.inner.clone(),
- permit: S::new_permit(),
}
}
}
-impl<T, S> Drop for Tx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S> Drop for Tx<T, S> {
fn drop(&mut self) {
- self.inner.semaphore.drop_permit(&mut self.permit);
-
if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
return;
}
@@ -244,16 +197,13 @@ where
self.inner.tx.close();
// Notify the receiver
- self.inner.rx_waker.wake();
+ self.wake_rx();
}
}
// ===== impl Rx =====
-impl<T, S> Rx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S: Semaphore> Rx<T, S> {
fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
Rx { inner: chan }
}
@@ -270,6 +220,7 @@ where
});
self.inner.semaphore.close();
+ self.inner.notify_rx_closed.notify_waiters();
}
/// Receive the next value
@@ -341,10 +292,7 @@ where
}
}
-impl<T, S> Drop for Rx<T, S>
-where
- S: Semaphore,
-{
+impl<T, S: Semaphore> Drop for Rx<T, S> {
fn drop(&mut self) {
use super::block::Read::Value;
@@ -362,25 +310,13 @@ where
// ===== impl Chan =====
-impl<T, S> Chan<T, S>
-where
- S: Semaphore,
-{
- fn try_send(&self, value: T, permit: &mut S::Permit) -> Result<(), (T, TrySendError)> {
- if let Err(e) = self.semaphore.try_acquire(permit) {
- return Err((value, e));
- }
-
+impl<T, S> Chan<T, S> {
+ fn send(&self, value: T) {
// Push the value
self.tx.push(value);
// Notify the rx task
self.rx_waker.wake();
-
- // Release the permit
- self.semaphore.forget(permit);
-
- Ok(())
}
}
@@ -399,72 +335,24 @@ impl<T, S> Drop for Chan<T, S> {
}
}
-use crate::sync::semaphore_ll::TryAcquireError;
-
-impl From<TryAcquireError> for TrySendError {
- fn from(src: TryAcquireError) -> TrySendError {
- if src.is_closed() {
- TrySendError::Closed
- } else if src.is_no_permits() {
- TrySendError::Full
- } else {
- unreachable!();
- }
- }
-}
-
// ===== impl Semaphore for (::Semaphore, capacity) =====
-use crate::sync::semaphore_ll::Permit;
-
-impl Semaphore for (crate::sync::semaphore_ll::Semaphore, usize) {
- type Permit = Permit;
-
- fn new_permit() -> Permit {
- Permit::new()
- }
-
- fn drop_permit(&self, permit: &mut Permit) {
- permit.release(1, &self.0);
- }
-
+impl Semaphore for (crate::sync::batch_semaphore::Semaphore, usize) {
fn add_permit(&self) {
- self.0.add_permits(1)
+ self.0.release(1)
}
fn is_idle(&self) -> bool {
self.0.available_permits() == self.1
}
- fn poll_acquire(
- &self,
- cx: &mut Context<'_>,
- permit: &mut Permit,
- ) -> Poll<Result<(), ClosedError>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- permit
- .poll_acquire(cx, 1, &self.0)
- .map_err(|_| ClosedError::new())
- .map(move |r| {
- coop.made_progress();
- r
- })
- }
-
- fn try_acquire(&self, permit: &mut Permit) -> Result<(), TrySendError> {
- permit.try_acquire(1, &self.0)?;
- Ok(())
- }
-
- fn forget(&self, permit: &mut Self::Permit) {
- permit.forget(1);
- }
-
fn close(&self) {
self.0.close();
}
+
+ fn is_closed(&self) -> bool {
+ self.0.is_closed()
+ }
}
// ===== impl Semaphore for AtomicUsize =====
@@ -473,12 +361,6 @@ use std::sync::atomic::Ordering::{Acquire, Release};
use std::usize;
impl Semaphore for AtomicUsize {
- type Permit = ();
-
- fn new_permit() {}
-
- fn drop_permit(&self, _permit: &mut ()) {}
-
fn add_permit(&self) {
let prev = self.fetch_sub(2, Release);
@@ -492,40 +374,11 @@ impl Semaphore for AtomicUsize {
self.load(Acquire) >> 1 == 0
}
- fn poll_acquire(
- &self,
- _cx: &mut Context<'_>,
- permit: &mut (),
- ) -> Poll<Result<(), ClosedError>> {
- Ready(self.try_acquire(permit).map_err(|_| ClosedError::new()))
- }
-
- fn try_acquire(&self, _permit: &mut ()) -> Result<(), TrySendError> {
- let mut curr = self.load(Acquire);
-
- loop {
- if curr & 1 == 1 {
- return Err(TrySendError::Closed);
- }
-
- if curr == usize::MAX ^ 1 {
- // Overflowed the ref count. There is no safe way to recover, so
- // abort the process. In practice, this should never happen.
- process::abort()
- }
-
- match self.compare_exchange(curr, curr + 2, AcqRel, Acquire) {
- Ok(_) => return Ok(()),
- Err(actual) => {
- curr = actual;
- }
- }
- }
- }
-
- fn forget(&self, _permit: &mut ()) {}
-
fn close(&self) {
self.fetch_or(1, Release);
}
+
+ fn is_closed(&self) -> bool {
+ self.load(Acquire) & 1 == 1
+ }
}
diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs
index 72c42aa..7705452 100644
--- a/src/sync/mpsc/error.rs
+++ b/src/sync/mpsc/error.rs
@@ -94,26 +94,6 @@ impl fmt::Display for TryRecvError {
impl Error for TryRecvError {}
-// ===== ClosedError =====
-
-/// Error returned by [`Sender::poll_ready`](super::Sender::poll_ready).
-#[derive(Debug)]
-pub struct ClosedError(());
-
-impl ClosedError {
- pub(crate) fn new() -> ClosedError {
- ClosedError(())
- }
-}
-
-impl fmt::Display for ClosedError {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(fmt, "channel closed")
- }
-}
-
-impl Error for ClosedError {}
-
cfg_time! {
// ===== SendTimeoutError =====
diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs
index 53f82a2..2f4c532 100644
--- a/src/sync/mpsc/list.rs
+++ b/src/sync/mpsc/list.rs
@@ -1,9 +1,7 @@
//! A concurrent, lock-free, FIFO list.
-use crate::loom::{
- sync::atomic::{AtomicPtr, AtomicUsize},
- thread,
-};
+use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
+use crate::loom::thread;
use crate::sync::mpsc::block::{self, Block};
use std::fmt;
diff --git a/src/sync/mpsc/mod.rs b/src/sync/mpsc/mod.rs
index c489c9f..a2bcf83 100644
--- a/src/sync/mpsc/mod.rs
+++ b/src/sync/mpsc/mod.rs
@@ -1,23 +1,29 @@
#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
-//! A multi-producer, single-consumer queue for sending values across
+//! A multi-producer, single-consumer queue for sending values between
//! asynchronous tasks.
//!
-//! Similar to `std`, channel creation provides [`Receiver`] and [`Sender`]
-//! handles. [`Receiver`] implements `Stream` and allows a task to read values
-//! out of the channel. If there is no message to read, the current task will be
-//! notified when a new value is sent. If the channel is at capacity, the send
-//! is rejected and the task will be notified when additional capacity is
-//! available. In other words, the channel provides backpressure.
-//!
//! This module provides two variants of the channel: bounded and unbounded. The
//! bounded variant has a limit on the number of messages that the channel can
//! store, and if this limit is reached, trying to send another message will
//! wait until a message is received from the channel. An unbounded channel has
-//! an infinite capacity, so the `send` method never does any kind of sleeping.
+//! an infinite capacity, so the `send` method will always complete immediately.
//! This makes the [`UnboundedSender`] usable from both synchronous and
//! asynchronous code.
//!
+//! Similar to the `mpsc` channels provided by `std`, the channel constructor
+//! functions provide separate send and receive handles, [`Sender`] and
+//! [`Receiver`] for the bounded channel, [`UnboundedSender`] and
+//! [`UnboundedReceiver`] for the unbounded channel. Both [`Receiver`] and
+//! [`UnboundedReceiver`] implement [`Stream`] and allow a task to read
+//! values out of the channel. If there is no message to read, the current task
+//! will be notified when a new value is sent. [`Sender`] and
+//! [`UnboundedSender`] allow sending values into the channel. If the bounded
+//! channel is at capacity, the send is rejected and the task will be notified
+//! when additional capacity is available. In other words, the channel provides
+//! backpressure.
+//!
+//!
//! # Disconnection
//!
//! When all [`Sender`] handles have been dropped, it is no longer
@@ -43,11 +49,10 @@
//! are two situations to consider:
//!
//! **Bounded channel**: If you need a bounded channel, you should use a bounded
-//! Tokio `mpsc` channel for both directions of communication. To call the async
-//! [`send`][bounded-send] or [`recv`][bounded-recv] methods in sync code, you
-//! will need to use [`Handle::block_on`], which allow you to execute an async
-//! method in synchronous code. This is necessary because a bounded channel may
-//! need to wait for additional capacity to become available.
+//! Tokio `mpsc` channel for both directions of communication. Instead of calling
+//! the async [`send`][bounded-send] or [`recv`][bounded-recv] methods, in
+//! synchronous code you will need to use the [`blocking_send`][blocking-send] or
+//! [`blocking_recv`][blocking-recv] methods.
//!
//! **Unbounded channel**: You should use the kind of channel that matches where
//! the receiver is. So for sending a message _from async to sync_, you should
@@ -57,9 +62,13 @@
//!
//! [`Sender`]: crate::sync::mpsc::Sender
//! [`Receiver`]: crate::sync::mpsc::Receiver
+//! [`Stream`]: crate::stream::Stream
//! [bounded-send]: crate::sync::mpsc::Sender::send()
//! [bounded-recv]: crate::sync::mpsc::Receiver::recv()
+//! [blocking-send]: crate::sync::mpsc::Sender::blocking_send()
+//! [blocking-recv]: crate::sync::mpsc::Receiver::blocking_recv()
//! [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender
+//! [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver
//! [`Handle::block_on`]: crate::runtime::Handle::block_on()
//! [std-unbounded]: std::sync::mpsc::channel
//! [crossbeam-unbounded]: https://docs.rs/crossbeam/*/crossbeam/channel/fn.unbounded.html
@@ -67,7 +76,7 @@
pub(super) mod block;
mod bounded;
-pub use self::bounded::{channel, Receiver, Sender};
+pub use self::bounded::{channel, Permit, Receiver, Sender};
mod chan;
diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs
index 1b2288a..fe882d5 100644
--- a/src/sync/mpsc/unbounded.rs
+++ b/src/sync/mpsc/unbounded.rs
@@ -47,7 +47,7 @@ impl<T> fmt::Debug for UnboundedReceiver<T> {
}
/// Creates an unbounded mpsc channel for communicating between asynchronous
-/// tasks.
+/// tasks without backpressure.
///
/// A `send` on this channel will always succeed as long as the receive half has
/// not been closed. If the receiver falls behind, messages will be arbitrarily
@@ -73,8 +73,7 @@ impl<T> UnboundedReceiver<T> {
UnboundedReceiver { chan }
}
- #[doc(hidden)] // TODO: doc
- pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
+ fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
self.chan.recv(cx)
}
@@ -174,7 +173,97 @@ impl<T> UnboundedSender<T> {
/// [`close`]: UnboundedReceiver::close
/// [`UnboundedReceiver`]: UnboundedReceiver
pub fn send(&self, message: T) -> Result<(), SendError<T>> {
- self.chan.send_unbounded(message)?;
+ if !self.inc_num_messages() {
+ return Err(SendError(message));
+ }
+
+ self.chan.send(message);
Ok(())
}
+
+ fn inc_num_messages(&self) -> bool {
+ use std::process;
+ use std::sync::atomic::Ordering::{AcqRel, Acquire};
+
+ let mut curr = self.chan.semaphore().load(Acquire);
+
+ loop {
+ if curr & 1 == 1 {
+ return false;
+ }
+
+ if curr == usize::MAX ^ 1 {
+ // Overflowed the ref count. There is no safe way to recover, so
+ // abort the process. In practice, this should never happen.
+ process::abort()
+ }
+
+ match self
+ .chan
+ .semaphore()
+ .compare_exchange(curr, curr + 2, AcqRel, Acquire)
+ {
+ Ok(_) => return true,
+ Err(actual) => {
+ curr = actual;
+ }
+ }
+ }
+ }
+
+ /// Completes when the receiver has dropped.
+ ///
+ /// This allows the producers to get notified when interest in the produced
+ /// values is canceled and immediately stop doing work.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::mpsc;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx1, rx) = mpsc::unbounded_channel::<()>();
+ /// let tx2 = tx1.clone();
+ /// let tx3 = tx1.clone();
+ /// let tx4 = tx1.clone();
+ /// let tx5 = tx1.clone();
+ /// tokio::spawn(async move {
+ /// drop(rx);
+ /// });
+ ///
+ /// futures::join!(
+ /// tx1.closed(),
+ /// tx2.closed(),
+ /// tx3.closed(),
+ /// tx4.closed(),
+ /// tx5.closed()
+ /// );
+ //// println!("Receiver dropped");
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ self.chan.closed().await
+ }
+ /// Checks if the channel has been closed. This happens when the
+ /// [`UnboundedReceiver`] is dropped, or when the
+ /// [`UnboundedReceiver::close`] method is called.
+ ///
+ /// [`UnboundedReceiver`]: crate::sync::mpsc::UnboundedReceiver
+ /// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close
+ ///
+ /// ```
+ /// let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<()>();
+ /// assert!(!tx.is_closed());
+ ///
+ /// let tx2 = tx.clone();
+ /// assert!(!tx2.is_closed());
+ ///
+ /// drop(rx);
+ /// assert!(tx.is_closed());
+ /// assert!(tx2.is_closed());
+ /// ```
+ pub fn is_closed(&self) -> bool {
+ self.chan.is_closed()
+ }
}
diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs
index 642058b..21e44ca 100644
--- a/src/sync/mutex.rs
+++ b/src/sync/mutex.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
+
use crate::sync::batch_semaphore as semaphore;
use std::cell::UnsafeCell;
@@ -115,7 +117,6 @@ use std::sync::Arc;
/// [`std::sync::Mutex`]: struct@std::sync::Mutex
/// [`Send`]: trait@std::marker::Send
/// [`lock`]: method@Mutex::lock
-#[derive(Debug)]
pub struct Mutex<T: ?Sized> {
s: semaphore::Semaphore,
c: UnsafeCell<T>,
@@ -220,6 +221,27 @@ impl<T: ?Sized> Mutex<T> {
}
}
+ /// Creates a new lock in an unlocked state ready for use.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Mutex;
+ ///
+ /// static LOCK: Mutex<i32> = Mutex::const_new(5);
+ /// ```
+ #[cfg(all(feature = "parking_lot", not(all(loom, test)),))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new(t: T) -> Self
+ where
+ T: Sized,
+ {
+ Self {
+ c: UnsafeCell::new(t),
+ s: semaphore::Semaphore::const_new(1),
+ }
+ }
+
/// Locks this mutex, causing the current task
/// to yield until the lock has been acquired.
/// When the lock has been acquired, function returns a [`MutexGuard`].
@@ -305,6 +327,30 @@ impl<T: ?Sized> Mutex<T> {
}
}
+ /// Returns a mutable reference to the underlying data.
+ ///
+ /// Since this call borrows the `Mutex` mutably, no actual locking needs to
+ /// take place -- the mutable borrow statically guarantees no locks exist.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Mutex;
+ ///
+ /// fn main() {
+ /// let mut mutex = Mutex::new(1);
+ ///
+ /// let n = mutex.get_mut();
+ /// *n = 2;
+ /// }
+ /// ```
+ pub fn get_mut(&mut self) -> &mut T {
+ unsafe {
+ // Safety: This is https://github.com/rust-lang/rust/pull/76936
+ &mut *self.c.get()
+ }
+ }
+
/// Attempts to acquire the lock, and returns [`TryLockError`] if the lock
/// is currently held somewhere else.
///
@@ -373,6 +419,20 @@ where
}
}
+impl<T> std::fmt::Debug for Mutex<T>
+where
+ T: std::fmt::Debug,
+{
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut d = f.debug_struct("Mutex");
+ match self.try_lock() {
+ Ok(inner) => d.field("data", &*inner),
+ Err(_) => d.field("data", &format_args!("<locked>")),
+ };
+ d.finish()
+ }
+}
+
// === impl MutexGuard ===
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
diff --git a/src/sync/notify.rs b/src/sync/notify.rs
index 5cb41e8..922f109 100644
--- a/src/sync/notify.rs
+++ b/src/sync/notify.rs
@@ -1,3 +1,10 @@
+// Allow `unreachable_pub` warnings when sync is not enabled
+// due to the usage of `Notify` within the `rt` feature set.
+// When this module is compiled with `sync` enabled we will warn on
+// this lint. When `rt` is enabled we use `pub(crate)` which
+// triggers this warning but it is safe to ignore in this case.
+#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
+
use crate::loom::sync::atomic::AtomicU8;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
@@ -10,6 +17,8 @@ use std::ptr::NonNull;
use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};
+type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
+
/// Notify a single task to wake up.
///
/// `Notify` provides a basic mechanism to notify a single task of an event.
@@ -17,20 +26,20 @@ use std::task::{Context, Poll, Waker};
/// another task to perform an operation.
///
/// `Notify` can be thought of as a [`Semaphore`] starting with 0 permits.
-/// [`notified().await`] waits for a permit to become available, and [`notify()`]
+/// [`notified().await`] waits for a permit to become available, and [`notify_one()`]
/// sets a permit **if there currently are no available permits**.
///
/// The synchronization details of `Notify` are similar to
/// [`thread::park`][park] and [`Thread::unpark`][unpark] from std. A [`Notify`]
/// value contains a single permit. [`notified().await`] waits for the permit to
-/// be made available, consumes the permit, and resumes. [`notify()`] sets the
+/// be made available, consumes the permit, and resumes. [`notify_one()`] sets the
/// permit, waking a pending task if there is one.
///
-/// If `notify()` is called **before** `notfied().await`, then the next call to
+/// If `notify_one()` is called **before** `notified().await`, then the next call to
/// `notified().await` will complete immediately, consuming the permit. Any
/// subsequent calls to `notified().await` will wait for a new permit.
///
-/// If `notify()` is called **multiple** times before `notified().await`, only a
+/// If `notify_one()` is called **multiple** times before `notified().await`, only a
/// **single** permit is stored. The next call to `notified().await` will
/// complete immediately, but the one after will wait for a new permit.
///
@@ -53,7 +62,7 @@ use std::task::{Context, Poll, Waker};
/// });
///
/// println!("sending notification");
-/// notify.notify();
+/// notify.notify_one();
/// }
/// ```
///
@@ -76,7 +85,7 @@ use std::task::{Context, Poll, Waker};
/// .push_back(value);
///
/// // Notify the consumer a value is available
-/// self.notify.notify();
+/// self.notify.notify_one();
/// }
///
/// pub async fn recv(&self) -> T {
@@ -96,12 +105,20 @@ use std::task::{Context, Poll, Waker};
/// [park]: std::thread::park
/// [unpark]: std::thread::Thread::unpark
/// [`notified().await`]: Notify::notified()
-/// [`notify()`]: Notify::notify()
+/// [`notify_one()`]: Notify::notify_one()
/// [`Semaphore`]: crate::sync::Semaphore
#[derive(Debug)]
pub struct Notify {
state: AtomicU8,
- waiters: Mutex<LinkedList<Waiter>>,
+ waiters: Mutex<WaitList>,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum NotificationType {
+ // Notification triggered by calling `notify_waiters`
+ AllWaiters,
+ // Notification triggered by calling `notify_one`
+ OneWaiter,
}
#[derive(Debug)]
@@ -113,7 +130,7 @@ struct Waiter {
waker: Option<Waker>,
/// `true` if the notification has been assigned to this waiter.
- notified: bool,
+ notified: Option<NotificationType>,
/// Should not be `Unpin`.
_p: PhantomPinned,
@@ -121,7 +138,7 @@ struct Waiter {
/// Future returned from `notified()`
#[derive(Debug)]
-struct Notified<'a> {
+pub struct Notified<'a> {
/// The `Notify` being received on.
notify: &'a Notify,
@@ -168,14 +185,38 @@ impl Notify {
}
}
+ /// Create a new `Notify`, initialized without a permit.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Notify;
+ ///
+ /// static NOTIFY: Notify = Notify::const_new();
+ /// ```
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new() -> Notify {
+ Notify {
+ state: AtomicU8::new(0),
+ waiters: Mutex::const_new(LinkedList::new()),
+ }
+ }
+
/// Wait for a notification.
///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn notified(&self);
+ /// ```
+ ///
/// Each `Notify` value holds a single permit. If a permit is available from
- /// an earlier call to [`notify()`], then `notified().await` will complete
+ /// an earlier call to [`notify_one()`], then `notified().await` will complete
/// immediately, consuming that permit. Otherwise, `notified().await` waits
- /// for a permit to be made available by the next call to `notify()`.
+ /// for a permit to be made available by the next call to `notify_one()`.
///
- /// [`notify()`]: Notify::notify
+ /// [`notify_one()`]: Notify::notify_one
///
/// # Examples
///
@@ -194,21 +235,20 @@ impl Notify {
/// });
///
/// println!("sending notification");
- /// notify.notify();
+ /// notify.notify_one();
/// }
/// ```
- pub async fn notified(&self) {
+ pub fn notified(&self) -> Notified<'_> {
Notified {
notify: self,
state: State::Init,
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
- notified: false,
+ notified: None,
_p: PhantomPinned,
}),
}
- .await
}
/// Notifies a waiting task
@@ -216,10 +256,10 @@ impl Notify {
/// If a task is currently waiting, that task is notified. Otherwise, a
/// permit is stored in this `Notify` value and the **next** call to
/// [`notified().await`] will complete immediately consuming the permit made
- /// available by this call to `notify()`.
+ /// available by this call to `notify_one()`.
///
/// At most one permit may be stored by `Notify`. Many sequential calls to
- /// `notify` will result in a single permit being stored. The next call to
+ /// `notify_one` will result in a single permit being stored. The next call to
/// `notified().await` will complete immediately, but the one after that
/// will wait.
///
@@ -242,10 +282,10 @@ impl Notify {
/// });
///
/// println!("sending notification");
- /// notify.notify();
+ /// notify.notify_one();
/// }
/// ```
- pub fn notify(&self) {
+ pub fn notify_one(&self) {
// Load the current state
let mut curr = self.state.load(SeqCst);
@@ -266,7 +306,7 @@ impl Notify {
}
// There are waiters, the lock must be acquired to notify.
- let mut waiters = self.waiters.lock().unwrap();
+ let mut waiters = self.waiters.lock();
// The state must be reloaded while the lock is held. The state may only
// transition out of WAITING while the lock is held.
@@ -277,6 +317,45 @@ impl Notify {
waker.wake();
}
}
+
+ /// Notifies all waiting tasks
+ pub(crate) fn notify_waiters(&self) {
+ // There are waiters, the lock must be acquired to notify.
+ let mut waiters = self.waiters.lock();
+
+ // The state must be reloaded while the lock is held. The state may only
+ // transition out of WAITING while the lock is held.
+ let curr = self.state.load(SeqCst);
+
+ if let EMPTY | NOTIFIED = curr {
+ // There are no waiting tasks. In this case, no synchronization is
+ // established between `notify` and `notified().await`.
+ return;
+ }
+
+ // At this point, it is guaranteed that the state will not
+ // concurrently change, as holding the lock is required to
+ // transition **out** of `WAITING`.
+ //
+ // Get pending waiters
+ while let Some(mut waiter) = waiters.pop_back() {
+ // Safety: `waiters` lock is still held.
+ let waiter = unsafe { waiter.as_mut() };
+
+ assert!(waiter.notified.is_none());
+
+ waiter.notified = Some(NotificationType::AllWaiters);
+
+ if let Some(waker) = waiter.waker.take() {
+ waker.wake();
+ }
+ }
+
+ // All waiters have been notified, the state must be transitioned to
+ // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
+ // held, a `store` is sufficient.
+ self.state.store(EMPTY, SeqCst);
+ }
}
impl Default for Notify {
@@ -285,7 +364,7 @@ impl Default for Notify {
}
}
-fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -> Option<Waker> {
+fn notify_locked(waiters: &mut WaitList, state: &AtomicU8, curr: u8) -> Option<Waker> {
loop {
match curr {
EMPTY | NOTIFIED => {
@@ -311,9 +390,9 @@ fn notify_locked(waiters: &mut LinkedList<Waiter>, state: &AtomicU8, curr: u8) -
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };
- assert!(!waiter.notified);
+ assert!(waiter.notified.is_none());
- waiter.notified = true;
+ waiter.notified = Some(NotificationType::OneWaiter);
let waker = waiter.waker.take();
if waiters.is_empty() {
@@ -373,7 +452,7 @@ impl Future for Notified<'_> {
// Acquire the lock and attempt to transition to the waiting
// state.
- let mut waiters = notify.waiters.lock().unwrap();
+ let mut waiters = notify.waiters.lock();
// Reload the state with the lock held
let mut curr = notify.state.load(SeqCst);
@@ -428,6 +507,8 @@ impl Future for Notified<'_> {
waiters.push_front(unsafe { NonNull::new_unchecked(waiter.get()) });
*state = Waiting;
+
+ return Poll::Pending;
}
Waiting => {
// Currently in the "Waiting" state, implying the caller has
@@ -435,16 +516,16 @@ impl Future for Notified<'_> {
// `notify.waiters`). In order to access the waker fields,
// we must hold the lock.
- let waiters = notify.waiters.lock().unwrap();
+ let waiters = notify.waiters.lock();
// Safety: called while locked
let w = unsafe { &mut *waiter.get() };
- if w.notified {
+ if w.notified.is_some() {
// Our waker has been notified. Reset the fields and
// remove it from the list.
w.waker = None;
- w.notified = false;
+ w.notified = None;
*state = Done;
} else {
@@ -483,12 +564,12 @@ impl Drop for Notified<'_> {
// longer stored in the linked list.
if let Waiting = *state {
let mut notify_state = WAITING;
- let mut waiters = notify.waiters.lock().unwrap();
+ let mut waiters = notify.waiters.lock();
// `Notify.state` may be in any of the three states (Empty, Waiting,
// Notified). It doesn't actually matter what the atomic is set to
// at this point. We hold the lock and will ensure the atomic is in
- // the correct state once th elock is dropped.
+ // the correct state once the lock is dropped.
//
// Because the atomic state is not checked, at first glance, it may
// seem like this routine does not handle the case where the
@@ -516,14 +597,13 @@ impl Drop for Notified<'_> {
notify.state.store(EMPTY, SeqCst);
}
- // See if the node was notified but not received. In this case, the
- // notification must be sent to another waiter.
+ // See if the node was notified but not received. In this case, if
+ // the notification was triggered via `notify_one`, it must be sent
+ // to the next waiter.
//
// Safety: with the entry removed from the linked list, there can be
// no concurrent access to the entry
- let notified = unsafe { (*waiter.get()).notified };
-
- if notified {
+ if let Some(NotificationType::OneWaiter) = unsafe { (*waiter.get()).notified } {
if let Some(waker) = notify_locked(&mut waiters, &notify.state, notify_state) {
drop(waiters);
waker.wake();
diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs
index 17767e7..951ab71 100644
--- a/src/sync/oneshot.rs
+++ b/src/sync/oneshot.rs
@@ -124,7 +124,6 @@ struct State(usize);
/// }
/// ```
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
- #[allow(deprecated)]
let inner = Arc::new(Inner {
state: AtomicUsize::new(State::new().as_usize()),
value: UnsafeCell::new(None),
@@ -197,8 +196,7 @@ impl<T> Sender<T> {
Ok(())
}
- #[doc(hidden)] // TODO: remove
- pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+ fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
// Keep track of task budget
let coop = ready!(crate::coop::poll_proceed(cx));
diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs
index 3d2a2f7..a84c4c1 100644
--- a/src/sync/rwlock.rs
+++ b/src/sync/rwlock.rs
@@ -1,5 +1,8 @@
-use crate::sync::batch_semaphore::{AcquireError, Semaphore};
+use crate::sync::batch_semaphore::Semaphore;
use std::cell::UnsafeCell;
+use std::fmt;
+use std::marker;
+use std::mem;
use std::ops;
#[cfg(not(loom))]
@@ -8,7 +11,7 @@ const MAX_READS: usize = 32;
#[cfg(loom)]
const MAX_READS: usize = 10;
-/// An asynchronous reader-writer lock
+/// An asynchronous reader-writer lock.
///
/// This type of lock allows a number of readers or at most one writer at any
/// point in time. The write portion of this lock typically allows modification
@@ -83,10 +86,140 @@ pub struct RwLock<T: ?Sized> {
/// [`RwLock`].
///
/// [`read`]: method@RwLock::read
-#[derive(Debug)]
+/// [`RwLock`]: struct@RwLock
pub struct RwLockReadGuard<'a, T: ?Sized> {
- permit: ReleasingPermit<'a, T>,
- lock: &'a RwLock<T>,
+ s: &'a Semaphore,
+ data: *const T,
+ marker: marker::PhantomData<&'a T>,
+}
+
+impl<'a, T> RwLockReadGuard<'a, T> {
+ /// Make a new `RwLockReadGuard` for a component of the locked data.
+ ///
+ /// This operation cannot fail as the `RwLockReadGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be
+ /// used as `RwLockReadGuard::map(...)`. A method would interfere with
+ /// methods of the same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockReadGuard::map`] from the
+ /// [`parking_lot` crate].
+ ///
+ /// [`RwLockReadGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockReadGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// let guard = lock.read().await;
+ /// let guard = RwLockReadGuard::map(guard, |f| &f.0);
+ ///
+ /// assert_eq!(1, *guard);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn map<F, U: ?Sized>(this: Self, f: F) -> RwLockReadGuard<'a, U>
+ where
+ F: FnOnce(&T) -> &U,
+ {
+ let data = f(&*this) as *const U;
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ RwLockReadGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ }
+ }
+
+ /// Attempts to make a new [`RwLockReadGuard`] for a component of the
+ /// locked data. The original guard is returned if the closure returns
+ /// `None`.
+ ///
+ /// This operation cannot fail as the `RwLockReadGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be used as
+ /// `RwLockReadGuard::try_map(..)`. A method would interfere with methods of the
+ /// same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockReadGuard::try_map`] from the
+ /// [`parking_lot` crate].
+ ///
+ /// [`RwLockReadGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockReadGuard.html#method.try_map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockReadGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// let guard = lock.read().await;
+ /// let guard = RwLockReadGuard::try_map(guard, |f| Some(&f.0)).expect("should not fail");
+ ///
+ /// assert_eq!(1, *guard);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn try_map<F, U: ?Sized>(this: Self, f: F) -> Result<RwLockReadGuard<'a, U>, Self>
+ where
+ F: FnOnce(&T) -> Option<&U>,
+ {
+ let data = match f(&*this) {
+ Some(data) => data as *const U,
+ None => return Err(this),
+ };
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ Ok(RwLockReadGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ })
+ }
+}
+
+impl<'a, T: ?Sized> fmt::Debug for RwLockReadGuard<'a, T>
+where
+ T: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Debug::fmt(&**self, f)
+ }
+}
+
+impl<'a, T: ?Sized> fmt::Display for RwLockReadGuard<'a, T>
+where
+ T: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Display::fmt(&**self, f)
+ }
+}
+
+impl<'a, T: ?Sized> Drop for RwLockReadGuard<'a, T> {
+ fn drop(&mut self) {
+ self.s.release(1);
+ }
}
/// RAII structure used to release the exclusive write access of a lock when
@@ -97,32 +230,195 @@ pub struct RwLockReadGuard<'a, T: ?Sized> {
///
/// [`write`]: method@RwLock::write
/// [`RwLock`]: struct@RwLock
-#[derive(Debug)]
pub struct RwLockWriteGuard<'a, T: ?Sized> {
- permit: ReleasingPermit<'a, T>,
- lock: &'a RwLock<T>,
+ s: &'a Semaphore,
+ data: *mut T,
+ marker: marker::PhantomData<&'a mut T>,
}
-// Wrapper arround Permit that releases on Drop
-#[derive(Debug)]
-struct ReleasingPermit<'a, T: ?Sized> {
- num_permits: u16,
- lock: &'a RwLock<T>,
+impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> {
+ /// Make a new `RwLockWriteGuard` for a component of the locked data.
+ ///
+ /// This operation cannot fail as the `RwLockWriteGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be used as
+ /// `RwLockWriteGuard::map(..)`. A method would interfere with methods of
+ /// the same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockWriteGuard::map`] from the
+ /// [`parking_lot` crate].
+ ///
+ /// [`RwLockWriteGuard::map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockWriteGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// {
+ /// let mut mapped = RwLockWriteGuard::map(lock.write().await, |f| &mut f.0);
+ /// *mapped = 2;
+ /// }
+ ///
+ /// assert_eq!(Foo(2), *lock.read().await);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn map<F, U: ?Sized>(mut this: Self, f: F) -> RwLockWriteGuard<'a, U>
+ where
+ F: FnOnce(&mut T) -> &mut U,
+ {
+ let data = f(&mut *this) as *mut U;
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ RwLockWriteGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ }
+ }
+
+ /// Attempts to make a new [`RwLockWriteGuard`] for a component of
+ /// the locked data. The original guard is returned if the closure returns
+ /// `None`.
+ ///
+ /// This operation cannot fail as the `RwLockWriteGuard` passed in already
+ /// locked the data.
+ ///
+ /// This is an associated function that needs to be
+ /// used as `RwLockWriteGuard::try_map(...)`. A method would interfere with
+ /// methods of the same name on the contents of the locked data.
+ ///
+ /// This is an asynchronous version of [`RwLockWriteGuard::try_map`] from
+ /// the [`parking_lot` crate].
+ ///
+ /// [`RwLockWriteGuard::try_map`]: https://docs.rs/lock_api/latest/lock_api/struct.RwLockWriteGuard.html#method.try_map
+ /// [`parking_lot` crate]: https://crates.io/crates/parking_lot
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{RwLock, RwLockWriteGuard};
+ ///
+ /// #[derive(Debug, Clone, Copy, PartialEq, Eq)]
+ /// struct Foo(u32);
+ ///
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = RwLock::new(Foo(1));
+ ///
+ /// {
+ /// let guard = lock.write().await;
+ /// let mut guard = RwLockWriteGuard::try_map(guard, |f| Some(&mut f.0)).expect("should not fail");
+ /// *guard = 2;
+ /// }
+ ///
+ /// assert_eq!(Foo(2), *lock.read().await);
+ /// # }
+ /// ```
+ #[inline]
+ pub fn try_map<F, U: ?Sized>(mut this: Self, f: F) -> Result<RwLockWriteGuard<'a, U>, Self>
+ where
+ F: FnOnce(&mut T) -> Option<&mut U>,
+ {
+ let data = match f(&mut *this) {
+ Some(data) => data as *mut U,
+ None => return Err(this),
+ };
+ let s = this.s;
+ // NB: Forget to avoid drop impl from being called.
+ mem::forget(this);
+ Ok(RwLockWriteGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ })
+ }
+
+ /// Atomically downgrades a write lock into a read lock without allowing
+ /// any writers to take exclusive access of the lock in the meantime.
+ ///
+ /// **Note:** This won't *necessarily* allow any additional readers to acquire
+ /// locks, since [`RwLock`] is fair and it is possible that a writer is next
+ /// in line.
+ ///
+ /// Returns an RAII guard which will drop the read access of this rwlock
+ /// when dropped.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::sync::RwLock;
+ /// # use std::sync::Arc;
+ /// #
+ /// # #[tokio::main]
+ /// # async fn main() {
+ /// let lock = Arc::new(RwLock::new(1));
+ ///
+ /// let n = lock.write().await;
+ ///
+ /// let cloned_lock = lock.clone();
+ /// let handle = tokio::spawn(async move {
+ /// *cloned_lock.write().await = 2;
+ /// });
+ ///
+ /// let n = n.downgrade();
+ /// assert_eq!(*n, 1, "downgrade is atomic");
+ ///
+ /// assert_eq!(*lock.read().await, 1, "additional readers can obtain locks");
+ ///
+ /// drop(n);
+ /// handle.await.unwrap();
+ /// assert_eq!(*lock.read().await, 2, "second writer obtained write lock");
+ /// # }
+ /// ```
+ ///
+ /// [`RwLock`]: struct@RwLock
+ pub fn downgrade(self) -> RwLockReadGuard<'a, T> {
+ let RwLockWriteGuard { s, data, .. } = self;
+
+ // Release all but one of the permits held by the write guard
+ s.release(MAX_READS - 1);
+
+ RwLockReadGuard {
+ s,
+ data,
+ marker: marker::PhantomData,
+ }
+ }
}
-impl<'a, T: ?Sized> ReleasingPermit<'a, T> {
- async fn acquire(
- lock: &'a RwLock<T>,
- num_permits: u16,
- ) -> Result<ReleasingPermit<'a, T>, AcquireError> {
- lock.s.acquire(num_permits.into()).await?;
- Ok(Self { num_permits, lock })
+impl<'a, T: ?Sized> fmt::Debug for RwLockWriteGuard<'a, T>
+where
+ T: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Debug::fmt(&**self, f)
}
}
-impl<T: ?Sized> Drop for ReleasingPermit<'_, T> {
+impl<'a, T: ?Sized> fmt::Display for RwLockWriteGuard<'a, T>
+where
+ T: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt::Display::fmt(&**self, f)
+ }
+}
+
+impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> {
fn drop(&mut self) {
- self.lock.s.release(self.num_permits as usize);
+ self.s.release(MAX_READS);
}
}
@@ -139,9 +435,11 @@ fn bounds() {
check_sync::<RwLock<u32>>();
check_unpin::<RwLock<u32>>();
+ check_send::<RwLockReadGuard<'_, u32>>();
check_sync::<RwLockReadGuard<'_, u32>>();
check_unpin::<RwLockReadGuard<'_, u32>>();
+ check_send::<RwLockWriteGuard<'_, u32>>();
check_sync::<RwLockWriteGuard<'_, u32>>();
check_unpin::<RwLockWriteGuard<'_, u32>>();
@@ -155,8 +453,17 @@ fn bounds() {
// RwLock<T>.
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}
+// NB: These impls need to be explicit since we're storing a raw pointer.
+// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over
+// `T` is `Send`.
+unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {}
unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {}
unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
+// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over
+// `T` is `Send` - but since this is also provides mutable access, we need to
+// make sure that `T` is `Send` since its value can be sent across thread
+// boundaries.
+unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
impl<T: ?Sized> RwLock<T> {
/// Creates a new instance of an `RwLock<T>` which is unlocked.
@@ -178,6 +485,27 @@ impl<T: ?Sized> RwLock<T> {
}
}
+ /// Creates a new instance of an `RwLock<T>` which is unlocked.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::RwLock;
+ ///
+ /// static LOCK: RwLock<i32> = RwLock::const_new(5);
+ /// ```
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new(value: T) -> RwLock<T>
+ where
+ T: Sized,
+ {
+ RwLock {
+ c: UnsafeCell::new(value),
+ s: Semaphore::const_new(MAX_READS),
+ }
+ }
+
/// Locks this rwlock with shared read access, causing the current task
/// to yield until the lock has been acquired.
///
@@ -210,12 +538,16 @@ impl<T: ?Sized> RwLock<T> {
///}
/// ```
pub async fn read(&self) -> RwLockReadGuard<'_, T> {
- let permit = ReleasingPermit::acquire(self, 1).await.unwrap_or_else(|_| {
+ self.s.acquire(1).await.unwrap_or_else(|_| {
// The semaphore was closed. but, we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen.
unreachable!()
});
- RwLockReadGuard { lock: self, permit }
+ RwLockReadGuard {
+ s: &self.s,
+ data: self.c.get(),
+ marker: marker::PhantomData,
+ }
}
/// Locks this rwlock with exclusive write access, causing the current task
@@ -241,15 +573,40 @@ impl<T: ?Sized> RwLock<T> {
///}
/// ```
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
- let permit = ReleasingPermit::acquire(self, MAX_READS as u16)
- .await
- .unwrap_or_else(|_| {
- // The semaphore was closed. but, we never explicitly close it, and we have a
- // handle to it through the Arc, which means that this can never happen.
- unreachable!()
- });
-
- RwLockWriteGuard { lock: self, permit }
+ self.s.acquire(MAX_READS as u32).await.unwrap_or_else(|_| {
+ // The semaphore was closed. but, we never explicitly close it, and we have a
+ // handle to it through the Arc, which means that this can never happen.
+ unreachable!()
+ });
+ RwLockWriteGuard {
+ s: &self.s,
+ data: self.c.get(),
+ marker: marker::PhantomData,
+ }
+ }
+
+ /// Returns a mutable reference to the underlying data.
+ ///
+ /// Since this call borrows the `RwLock` mutably, no actual locking needs to
+ /// take place -- the mutable borrow statically guarantees no locks exist.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::RwLock;
+ ///
+ /// fn main() {
+ /// let mut lock = RwLock::new(1);
+ ///
+ /// let n = lock.get_mut();
+ /// *n = 2;
+ /// }
+ /// ```
+ pub fn get_mut(&mut self) -> &mut T {
+ unsafe {
+ // Safety: This is https://github.com/rust-lang/rust/pull/76936
+ &mut *self.c.get()
+ }
}
/// Consumes the lock, returning the underlying data.
@@ -265,7 +622,7 @@ impl<T: ?Sized> ops::Deref for RwLockReadGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
- unsafe { &*self.lock.c.get() }
+ unsafe { &*self.data }
}
}
@@ -273,13 +630,13 @@ impl<T: ?Sized> ops::Deref for RwLockWriteGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
- unsafe { &*self.lock.c.get() }
+ unsafe { &*self.data }
}
}
impl<T: ?Sized> ops::DerefMut for RwLockWriteGuard<'_, T> {
fn deref_mut(&mut self) -> &mut T {
- unsafe { &mut *self.lock.c.get() }
+ unsafe { &mut *self.data }
}
}
@@ -289,7 +646,7 @@ impl<T> From<T> for RwLock<T> {
}
}
-impl<T> Default for RwLock<T>
+impl<T: ?Sized> Default for RwLock<T>
where
T: Default,
{
diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs
index 2489d34..43dd976 100644
--- a/src/sync/semaphore.rs
+++ b/src/sync/semaphore.rs
@@ -1,7 +1,7 @@
use super::batch_semaphore as ll; // low level implementation
use std::sync::Arc;
-/// Counting semaphore performing asynchronous permit aquisition.
+/// Counting semaphore performing asynchronous permit acquisition.
///
/// A semaphore maintains a set of permits. Permits are used to synchronize
/// access to a shared resource. A semaphore differs from a mutex in that it
@@ -74,6 +74,15 @@ impl Semaphore {
}
}
+ /// Creates a new semaphore with the initial number of permits.
+ #[cfg(all(feature = "parking_lot", not(all(loom, test))))]
+ #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
+ pub const fn const_new(permits: usize) -> Self {
+ Self {
+ ll_sem: ll::Semaphore::const_new(permits),
+ }
+ }
+
/// Returns the current number of available permits.
pub fn available_permits(&self) -> usize {
self.ll_sem.available_permits()
@@ -114,7 +123,7 @@ impl Semaphore {
pub async fn acquire_owned(self: Arc<Self>) -> OwnedSemaphorePermit {
self.ll_sem.acquire(1).await.unwrap();
OwnedSemaphorePermit {
- sem: self.clone(),
+ sem: self,
permits: 1,
}
}
@@ -127,7 +136,7 @@ impl Semaphore {
pub fn try_acquire_owned(self: Arc<Self>) -> Result<OwnedSemaphorePermit, TryAcquireError> {
match self.ll_sem.try_acquire(1) {
Ok(_) => Ok(OwnedSemaphorePermit {
- sem: self.clone(),
+ sem: self,
permits: 1,
}),
Err(_) => Err(TryAcquireError(())),
diff --git a/src/sync/semaphore_ll.rs b/src/sync/semaphore_ll.rs
deleted file mode 100644
index 25d25ac..0000000
--- a/src/sync/semaphore_ll.rs
+++ /dev/null
@@ -1,1221 +0,0 @@
-#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
-
-//! Thread-safe, asynchronous counting semaphore.
-//!
-//! A `Semaphore` instance holds a set of permits. Permits are used to
-//! synchronize access to a shared resource.
-//!
-//! Before accessing the shared resource, callers acquire a permit from the
-//! semaphore. Once the permit is acquired, the caller then enters the critical
-//! section. If no permits are available, then acquiring the semaphore returns
-//! `Pending`. The task is woken once a permit becomes available.
-
-use crate::loom::cell::UnsafeCell;
-use crate::loom::future::AtomicWaker;
-use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize};
-use crate::loom::thread;
-
-use std::cmp;
-use std::fmt;
-use std::ptr::{self, NonNull};
-use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
-use std::task::Poll::{Pending, Ready};
-use std::task::{Context, Poll};
-use std::usize;
-
-/// Futures-aware semaphore.
-pub(crate) struct Semaphore {
- /// Tracks both the waiter queue tail pointer and the number of remaining
- /// permits.
- state: AtomicUsize,
-
- /// waiter queue head pointer.
- head: UnsafeCell<NonNull<Waiter>>,
-
- /// Coordinates access to the queue head.
- rx_lock: AtomicUsize,
-
- /// Stub waiter node used as part of the MPSC channel algorithm.
- stub: Box<Waiter>,
-}
-
-/// A semaphore permit
-///
-/// Tracks the lifecycle of a semaphore permit.
-///
-/// An instance of `Permit` is intended to be used with a **single** instance of
-/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
-/// instances will result in unexpected behavior.
-///
-/// `Permit` does **not** release the permit back to the semaphore on drop. It
-/// is the user's responsibility to ensure that `Permit::release` is called
-/// before dropping the permit.
-#[derive(Debug)]
-pub(crate) struct Permit {
- waiter: Option<Box<Waiter>>,
- state: PermitState,
-}
-
-/// Error returned by `Permit::poll_acquire`.
-#[derive(Debug)]
-pub(crate) struct AcquireError(());
-
-/// Error returned by `Permit::try_acquire`.
-#[derive(Debug)]
-pub(crate) enum TryAcquireError {
- Closed,
- NoPermits,
-}
-
-/// Node used to notify the semaphore waiter when permit is available.
-#[derive(Debug)]
-struct Waiter {
- /// Stores waiter state.
- ///
- /// See `WaiterState` for more details.
- state: AtomicUsize,
-
- /// Task to wake when a permit is made available.
- waker: AtomicWaker,
-
- /// Next pointer in the queue of waiting senders.
- next: AtomicPtr<Waiter>,
-}
-
-/// Semaphore state
-///
-/// The 2 low bits track the modes.
-///
-/// - Closed
-/// - Full
-///
-/// When not full, the rest of the `usize` tracks the total number of messages
-/// in the channel. When full, the rest of the `usize` is a pointer to the tail
-/// of the "waiting senders" queue.
-#[derive(Copy, Clone)]
-struct SemState(usize);
-
-/// Permit state
-#[derive(Debug, Copy, Clone)]
-enum PermitState {
- /// Currently waiting for permits to be made available and assigned to the
- /// waiter.
- Waiting(u16),
-
- /// The number of acquired permits
- Acquired(u16),
-}
-
-/// State for an individual waker node
-#[derive(Debug, Copy, Clone)]
-struct WaiterState(usize);
-
-/// Waiter node is in the semaphore queue
-const QUEUED: usize = 0b001;
-
-/// Semaphore has been closed, no more permits will be issued.
-const CLOSED: usize = 0b10;
-
-/// The permit that owns the `Waiter` dropped.
-const DROPPED: usize = 0b100;
-
-/// Represents "one requested permit" in the waiter state
-const PERMIT_ONE: usize = 0b1000;
-
-/// Masks the waiter state to only contain bits tracking number of requested
-/// permits.
-const PERMIT_MASK: usize = usize::MAX - (PERMIT_ONE - 1);
-
-/// How much to shift a permit count to pack it into the waker state
-const PERMIT_SHIFT: u32 = PERMIT_ONE.trailing_zeros();
-
-/// Flag differentiating between available permits and waiter pointers.
-///
-/// If we assume pointers are properly aligned, then the least significant bit
-/// will always be zero. So, we use that bit to track if the value represents a
-/// number.
-const NUM_FLAG: usize = 0b01;
-
-/// Signal the semaphore is closed
-const CLOSED_FLAG: usize = 0b10;
-
-/// Maximum number of permits a semaphore can manage
-const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
-
-/// When representing "numbers", the state has to be shifted this much (to get
-/// rid of the flag bit).
-const NUM_SHIFT: usize = 2;
-
-// ===== impl Semaphore =====
-
-impl Semaphore {
- /// Creates a new semaphore with the initial number of permits
- ///
- /// # Panics
- ///
- /// Panics if `permits` is zero.
- pub(crate) fn new(permits: usize) -> Semaphore {
- let stub = Box::new(Waiter::new());
- let ptr = NonNull::from(&*stub);
-
- // Allocations are aligned
- debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);
-
- let state = SemState::new(permits, &stub);
-
- Semaphore {
- state: AtomicUsize::new(state.to_usize()),
- head: UnsafeCell::new(ptr),
- rx_lock: AtomicUsize::new(0),
- stub,
- }
- }
-
- /// Returns the current number of available permits
- pub(crate) fn available_permits(&self) -> usize {
- let curr = SemState(self.state.load(Acquire));
- curr.available_permits()
- }
-
- /// Tries to acquire the requested number of permits, registering the waiter
- /// if not enough permits are available.
- fn poll_acquire(
- &self,
- cx: &mut Context<'_>,
- num_permits: u16,
- permit: &mut Permit,
- ) -> Poll<Result<(), AcquireError>> {
- self.poll_acquire2(num_permits, || {
- let waiter = permit.waiter.get_or_insert_with(|| Box::new(Waiter::new()));
-
- waiter.waker.register_by_ref(cx.waker());
-
- Some(NonNull::from(&**waiter))
- })
- }
-
- fn try_acquire(&self, num_permits: u16) -> Result<(), TryAcquireError> {
- match self.poll_acquire2(num_permits, || None) {
- Poll::Ready(res) => res.map_err(to_try_acquire),
- Poll::Pending => Err(TryAcquireError::NoPermits),
- }
- }
-
- /// Polls for a permit
- ///
- /// Tries to acquire available permits first. If unable to acquire a
- /// sufficient number of permits, the caller's waiter is pushed onto the
- /// semaphore's wait queue.
- fn poll_acquire2<F>(
- &self,
- num_permits: u16,
- mut get_waiter: F,
- ) -> Poll<Result<(), AcquireError>>
- where
- F: FnMut() -> Option<NonNull<Waiter>>,
- {
- let num_permits = num_permits as usize;
-
- // Load the current state
- let mut curr = SemState(self.state.load(Acquire));
-
- // Saves a ref to the waiter node
- let mut maybe_waiter: Option<NonNull<Waiter>> = None;
-
- /// Used in branches where we attempt to push the waiter into the wait
- /// queue but fail due to permits becoming available or the wait queue
- /// transitioning to "closed". In this case, the waiter must be
- /// transitioned back to the "idle" state.
- macro_rules! revert_to_idle {
- () => {
- if let Some(waiter) = maybe_waiter {
- unsafe { waiter.as_ref() }.revert_to_idle();
- }
- };
- }
-
- loop {
- let mut next = curr;
-
- if curr.is_closed() {
- revert_to_idle!();
- return Ready(Err(AcquireError::closed()));
- }
-
- let acquired = next.acquire_permits(num_permits, &self.stub);
-
- if !acquired {
- // There are not enough available permits to satisfy the
- // request. The permit transitions to a waiting state.
- debug_assert!(curr.waiter().is_some() || curr.available_permits() < num_permits);
-
- if let Some(waiter) = maybe_waiter.as_ref() {
- // Safety: the caller owns the waiter.
- let w = unsafe { waiter.as_ref() };
- w.set_permits_to_acquire(num_permits - curr.available_permits());
- } else {
- // Get the waiter for the permit.
- if let Some(waiter) = get_waiter() {
- // Safety: the caller owns the waiter.
- let w = unsafe { waiter.as_ref() };
-
- // If there are any currently available permits, the
- // waiter acquires those immediately and waits for the
- // remaining permits to become available.
- if !w.to_queued(num_permits - curr.available_permits()) {
- // The node is alrady queued, there is no further work
- // to do.
- return Pending;
- }
-
- maybe_waiter = Some(waiter);
- } else {
- // No waiter, this indicates the caller does not wish to
- // "wait", so there is nothing left to do.
- return Pending;
- }
- }
-
- next.set_waiter(maybe_waiter.unwrap());
- }
-
- debug_assert_ne!(curr.0, 0);
- debug_assert_ne!(next.0, 0);
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => {
- if acquired {
- // Successfully acquire permits **without** queuing the
- // waiter node. The waiter node is not currently in the
- // queue.
- revert_to_idle!();
- return Ready(Ok(()));
- } else {
- // The node is pushed into the queue, the final step is
- // to set the node's "next" pointer to return the wait
- // queue into a consistent state.
-
- let prev_waiter =
- curr.waiter().unwrap_or_else(|| NonNull::from(&*self.stub));
-
- let waiter = maybe_waiter.unwrap();
-
- // Link the nodes.
- //
- // Safety: the mpsc algorithm guarantees the old tail of
- // the queue is not removed from the queue during the
- // push process.
- unsafe {
- prev_waiter.as_ref().store_next(waiter);
- }
-
- return Pending;
- }
- }
- Err(actual) => {
- curr = SemState(actual);
- }
- }
- }
- }
-
- /// Closes the semaphore. This prevents the semaphore from issuing new
- /// permits and notifies all pending waiters.
- pub(crate) fn close(&self) {
- // Acquire the `rx_lock`, setting the "closed" flag on the lock.
- let prev = self.rx_lock.fetch_or(1, AcqRel);
-
- if prev != 0 {
- // Another thread has the lock and will be responsible for notifying
- // pending waiters.
- return;
- }
-
- self.add_permits_locked(0, true);
- }
- /// Adds `n` new permits to the semaphore.
- ///
- /// The maximum number of permits is `usize::MAX >> 3`, and this function will panic if the limit is exceeded.
- pub(crate) fn add_permits(&self, n: usize) {
- if n == 0 {
- return;
- }
-
- // TODO: Handle overflow. A panic is not sufficient, the process must
- // abort.
- let prev = self.rx_lock.fetch_add(n << 1, AcqRel);
-
- if prev != 0 {
- // Another thread has the lock and will be responsible for notifying
- // pending waiters.
- return;
- }
-
- self.add_permits_locked(n, false);
- }
-
- fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
- while rem > 0 || closed {
- if closed {
- SemState::fetch_set_closed(&self.state, AcqRel);
- }
-
- // Release the permits and notify
- self.add_permits_locked2(rem, closed);
-
- let n = rem << 1;
-
- let actual = if closed {
- let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
- closed = false;
- actual
- } else {
- let actual = self.rx_lock.fetch_sub(n, AcqRel);
- closed = actual & 1 == 1;
- actual
- };
-
- rem = (actual >> 1) - rem;
- }
- }
-
- /// Releases a specific amount of permits to the semaphore
- ///
- /// This function is called by `add_permits` after the add lock has been
- /// acquired.
- fn add_permits_locked2(&self, mut n: usize, closed: bool) {
- // If closing the semaphore, we want to drain the entire queue. The
- // number of permits being assigned doesn't matter.
- if closed {
- n = usize::MAX;
- }
-
- 'outer: while n > 0 {
- unsafe {
- let mut head = self.head.with(|head| *head);
- let mut next_ptr = head.as_ref().next.load(Acquire);
-
- let stub = self.stub();
-
- if head == stub {
- // The stub node indicates an empty queue. Any remaining
- // permits get assigned back to the semaphore.
- let next = match NonNull::new(next_ptr) {
- Some(next) => next,
- None => {
- // This loop is not part of the standard intrusive mpsc
- // channel algorithm. This is where we atomically pop
- // the last task and add `n` to the remaining capacity.
- //
- // This modification to the pop algorithm works because,
- // at this point, we have not done any work (only done
- // reading). We have a *pretty* good idea that there is
- // no concurrent pusher.
- //
- // The capacity is then atomically added by doing an
- // AcqRel CAS on `state`. The `state` cell is the
- // linchpin of the algorithm.
- //
- // By successfully CASing `head` w/ AcqRel, we ensure
- // that, if any thread was racing and entered a push, we
- // see that and abort pop, retrying as it is
- // "inconsistent".
- let mut curr = SemState::load(&self.state, Acquire);
-
- loop {
- if curr.has_waiter(&self.stub) {
- // A waiter is being added concurrently.
- // This is the MPSC queue's "inconsistent"
- // state and we must loop and try again.
- thread::yield_now();
- continue 'outer;
- }
-
- // If closing, nothing more to do.
- if closed {
- debug_assert!(curr.is_closed(), "state = {:?}", curr);
- return;
- }
-
- let mut next = curr;
- next.release_permits(n, &self.stub);
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return,
- Err(actual) => {
- curr = SemState(actual);
- }
- }
- }
- }
- };
-
- self.head.with_mut(|head| *head = next);
- head = next;
- next_ptr = next.as_ref().next.load(Acquire);
- }
-
- // `head` points to a waiter assign permits to the waiter. If
- // all requested permits are satisfied, then we can continue,
- // otherwise the node stays in the wait queue.
- if !head.as_ref().assign_permits(&mut n, closed) {
- assert_eq!(n, 0);
- return;
- }
-
- if let Some(next) = NonNull::new(next_ptr) {
- self.head.with_mut(|head| *head = next);
-
- self.remove_queued(head, closed);
- continue 'outer;
- }
-
- let state = SemState::load(&self.state, Acquire);
-
- // This must always be a pointer as the wait list is not empty.
- let tail = state.waiter().unwrap();
-
- if tail != head {
- // Inconsistent
- thread::yield_now();
- continue 'outer;
- }
-
- self.push_stub(closed);
-
- next_ptr = head.as_ref().next.load(Acquire);
-
- if let Some(next) = NonNull::new(next_ptr) {
- self.head.with_mut(|head| *head = next);
-
- self.remove_queued(head, closed);
- continue 'outer;
- }
-
- // Inconsistent state, loop
- thread::yield_now();
- }
- }
- }
-
- /// The wait node has had all of its permits assigned and has been removed
- /// from the wait queue.
- ///
- /// Attempt to remove the QUEUED bit from the node. If additional permits
- /// are concurrently requested, the node must be pushed back into the wait
- /// queued.
- fn remove_queued(&self, waiter: NonNull<Waiter>, closed: bool) {
- let mut curr = WaiterState(unsafe { waiter.as_ref() }.state.load(Acquire));
-
- loop {
- if curr.is_dropped() {
- // The Permit dropped, it is on us to release the memory
- let _ = unsafe { Box::from_raw(waiter.as_ptr()) };
- return;
- }
-
- // The node is removed from the queue. We attempt to unset the
- // queued bit, but concurrently the waiter has requested more
- // permits. When the waiter requested more permits, it saw the
- // queued bit set so took no further action. This requires us to
- // push the node back into the queue.
- if curr.permits_to_acquire() > 0 {
- // More permits are requested. The waiter must be re-queued
- unsafe {
- self.push_waiter(waiter, closed);
- }
- return;
- }
-
- let mut next = curr;
- next.unset_queued();
-
- let w = unsafe { waiter.as_ref() };
-
- match w.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return,
- Err(actual) => {
- curr = WaiterState(actual);
- }
- }
- }
- }
-
- unsafe fn push_stub(&self, closed: bool) {
- self.push_waiter(self.stub(), closed);
- }
-
- unsafe fn push_waiter(&self, waiter: NonNull<Waiter>, closed: bool) {
- // Set the next pointer. This does not require an atomic operation as
- // this node is not accessible. The write will be flushed with the next
- // operation
- waiter.as_ref().next.store(ptr::null_mut(), Relaxed);
-
- // Update the tail to point to the new node. We need to see the previous
- // node in order to update the next pointer as well as release `task`
- // to any other threads calling `push`.
- let next = SemState::new_ptr(waiter, closed);
- let prev = SemState(self.state.swap(next.0, AcqRel));
-
- debug_assert_eq!(closed, prev.is_closed());
-
- // This function is only called when there are pending tasks. Because of
- // this, the state must *always* be in pointer mode.
- let prev = prev.waiter().unwrap();
-
- // No cycles plz
- debug_assert_ne!(prev, waiter);
-
- // Release `task` to the consume end.
- prev.as_ref().next.store(waiter.as_ptr(), Release);
- }
-
- fn stub(&self) -> NonNull<Waiter> {
- unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
- }
-}
-
-impl Drop for Semaphore {
- fn drop(&mut self) {
- self.close();
- }
-}
-
-impl fmt::Debug for Semaphore {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- fmt.debug_struct("Semaphore")
- .field("state", &SemState::load(&self.state, Relaxed))
- .field("head", &self.head.with(|ptr| ptr))
- .field("rx_lock", &self.rx_lock.load(Relaxed))
- .field("stub", &self.stub)
- .finish()
- }
-}
-
-unsafe impl Send for Semaphore {}
-unsafe impl Sync for Semaphore {}
-
-// ===== impl Permit =====
-
-impl Permit {
- /// Creates a new `Permit`.
- ///
- /// The permit begins in the "unacquired" state.
- pub(crate) fn new() -> Permit {
- use PermitState::Acquired;
-
- Permit {
- waiter: None,
- state: Acquired(0),
- }
- }
-
- /// Returns `true` if the permit has been acquired
- #[allow(dead_code)] // may be used later
- pub(crate) fn is_acquired(&self) -> bool {
- match self.state {
- PermitState::Acquired(num) if num > 0 => true,
- _ => false,
- }
- }
-
- /// Tries to acquire the permit. If no permits are available, the current task
- /// is notified once a new permit becomes available.
- pub(crate) fn poll_acquire(
- &mut self,
- cx: &mut Context<'_>,
- num_permits: u16,
- semaphore: &Semaphore,
- ) -> Poll<Result<(), AcquireError>> {
- use std::cmp::Ordering::*;
- use PermitState::*;
-
- match self.state {
- Waiting(requested) => {
- // There must be a waiter
- let waiter = self.waiter.as_ref().unwrap();
-
- match requested.cmp(&num_permits) {
- Less => {
- let delta = num_permits - requested;
-
- // Request additional permits. If the waiter has been
- // dequeued, it must be re-queued.
- if !waiter.try_inc_permits_to_acquire(delta as usize) {
- let waiter = NonNull::from(&**waiter);
-
- // Ignore the result. The check for
- // `permits_to_acquire()` will converge the state as
- // needed
- let _ = semaphore.poll_acquire2(delta, || Some(waiter))?;
- }
-
- self.state = Waiting(num_permits);
- }
- Greater => {
- let delta = requested - num_permits;
- let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
-
- semaphore.add_permits(to_release);
- self.state = Waiting(num_permits);
- }
- Equal => {}
- }
-
- if waiter.permits_to_acquire()? == 0 {
- self.state = Acquired(requested);
- return Ready(Ok(()));
- }
-
- waiter.waker.register_by_ref(cx.waker());
-
- if waiter.permits_to_acquire()? == 0 {
- self.state = Acquired(requested);
- return Ready(Ok(()));
- }
-
- Pending
- }
- Acquired(acquired) => {
- if acquired >= num_permits {
- Ready(Ok(()))
- } else {
- match semaphore.poll_acquire(cx, num_permits - acquired, self)? {
- Ready(()) => {
- self.state = Acquired(num_permits);
- Ready(Ok(()))
- }
- Pending => {
- self.state = Waiting(num_permits);
- Pending
- }
- }
- }
- }
- }
- }
-
- /// Tries to acquire the permit.
- pub(crate) fn try_acquire(
- &mut self,
- num_permits: u16,
- semaphore: &Semaphore,
- ) -> Result<(), TryAcquireError> {
- use PermitState::*;
-
- match self.state {
- Waiting(requested) => {
- // There must be a waiter
- let waiter = self.waiter.as_ref().unwrap();
-
- if requested > num_permits {
- let delta = requested - num_permits;
- let to_release = waiter.try_dec_permits_to_acquire(delta as usize);
-
- semaphore.add_permits(to_release);
- self.state = Waiting(num_permits);
- }
-
- let res = waiter.permits_to_acquire().map_err(to_try_acquire)?;
-
- if res == 0 {
- if requested < num_permits {
- // Try to acquire the additional permits
- semaphore.try_acquire(num_permits - requested)?;
- }
-
- self.state = Acquired(num_permits);
- Ok(())
- } else {
- Err(TryAcquireError::NoPermits)
- }
- }
- Acquired(acquired) => {
- if acquired < num_permits {
- semaphore.try_acquire(num_permits - acquired)?;
- self.state = Acquired(num_permits);
- }
-
- Ok(())
- }
- }
- }
-
- /// Releases a permit back to the semaphore
- pub(crate) fn release(&mut self, n: u16, semaphore: &Semaphore) {
- let n = self.forget(n);
- semaphore.add_permits(n as usize);
- }
-
- /// Forgets the permit **without** releasing it back to the semaphore.
- ///
- /// After calling `forget`, `poll_acquire` is able to acquire new permit
- /// from the semaphore.
- ///
- /// Repeatedly calling `forget` without associated calls to `add_permit`
- /// will result in the semaphore losing all permits.
- ///
- /// Will forget **at most** the number of acquired permits. This number is
- /// returned.
- pub(crate) fn forget(&mut self, n: u16) -> u16 {
- use PermitState::*;
-
- match self.state {
- Waiting(requested) => {
- let n = cmp::min(n, requested);
-
- // Decrement
- let acquired = self
- .waiter
- .as_ref()
- .unwrap()
- .try_dec_permits_to_acquire(n as usize) as u16;
-
- if n == requested {
- self.state = Acquired(0);
- } else if acquired == requested - n {
- self.state = Waiting(acquired);
- } else {
- self.state = Waiting(requested - n);
- }
-
- acquired
- }
- Acquired(acquired) => {
- let n = cmp::min(n, acquired);
- self.state = Acquired(acquired - n);
- n
- }
- }
- }
-}
-
-impl Default for Permit {
- fn default() -> Self {
- Self::new()
- }
-}
-
-impl Drop for Permit {
- fn drop(&mut self) {
- if let Some(waiter) = self.waiter.take() {
- // Set the dropped flag
- let state = WaiterState(waiter.state.fetch_or(DROPPED, AcqRel));
-
- if state.is_queued() {
- // The waiter is stored in the queue. The semaphore will drop it
- std::mem::forget(waiter);
- }
- }
- }
-}
-
-// ===== impl AcquireError ====
-
-impl AcquireError {
- fn closed() -> AcquireError {
- AcquireError(())
- }
-}
-
-fn to_try_acquire(_: AcquireError) -> TryAcquireError {
- TryAcquireError::Closed
-}
-
-impl fmt::Display for AcquireError {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(fmt, "semaphore closed")
- }
-}
-
-impl std::error::Error for AcquireError {}
-
-// ===== impl TryAcquireError =====
-
-impl TryAcquireError {
- /// Returns `true` if the error was caused by a closed semaphore.
- pub(crate) fn is_closed(&self) -> bool {
- match self {
- TryAcquireError::Closed => true,
- _ => false,
- }
- }
-
- /// Returns `true` if the error was caused by calling `try_acquire` on a
- /// semaphore with no available permits.
- pub(crate) fn is_no_permits(&self) -> bool {
- match self {
- TryAcquireError::NoPermits => true,
- _ => false,
- }
- }
-}
-
-impl fmt::Display for TryAcquireError {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- TryAcquireError::Closed => write!(fmt, "semaphore closed"),
- TryAcquireError::NoPermits => write!(fmt, "no permits available"),
- }
- }
-}
-
-impl std::error::Error for TryAcquireError {}
-
-// ===== impl Waiter =====
-
-impl Waiter {
- fn new() -> Waiter {
- Waiter {
- state: AtomicUsize::new(0),
- waker: AtomicWaker::new(),
- next: AtomicPtr::new(ptr::null_mut()),
- }
- }
-
- fn permits_to_acquire(&self) -> Result<usize, AcquireError> {
- let state = WaiterState(self.state.load(Acquire));
-
- if state.is_closed() {
- Err(AcquireError(()))
- } else {
- Ok(state.permits_to_acquire())
- }
- }
-
- /// Only increments the number of permits *if* the waiter is currently
- /// queued.
- ///
- /// # Returns
- ///
- /// `true` if the number of permits to acquire has been incremented. `false`
- /// otherwise. On `false`, the caller should use `Semaphore::poll_acquire`.
- fn try_inc_permits_to_acquire(&self, n: usize) -> bool {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- loop {
- if !curr.is_queued() {
- assert_eq!(0, curr.permits_to_acquire());
- return false;
- }
-
- let mut next = curr;
- next.set_permits_to_acquire(n + curr.permits_to_acquire());
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return true,
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- /// Try to decrement the number of permits to acquire. This returns the
- /// actual number of permits that were decremented. The delta betweeen `n`
- /// and the return has been assigned to the permit and the caller must
- /// assign these back to the semaphore.
- fn try_dec_permits_to_acquire(&self, n: usize) -> usize {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- loop {
- if !curr.is_queued() {
- assert_eq!(0, curr.permits_to_acquire());
- }
-
- let delta = cmp::min(n, curr.permits_to_acquire());
- let rem = curr.permits_to_acquire() - delta;
-
- let mut next = curr;
- next.set_permits_to_acquire(rem);
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => return n - delta,
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- /// Store the number of remaining permits needed to satisfy the waiter and
- /// transition to the "QUEUED" state.
- ///
- /// # Returns
- ///
- /// `true` if the `QUEUED` bit was set as part of the transition.
- fn to_queued(&self, num_permits: usize) -> bool {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- // The waiter should **not** be waiting for any permits.
- debug_assert_eq!(curr.permits_to_acquire(), 0);
-
- loop {
- let mut next = curr;
- next.set_permits_to_acquire(num_permits);
- next.set_queued();
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => {
- if curr.is_queued() {
- return false;
- } else {
- // Make sure the next pointer is null
- self.next.store(ptr::null_mut(), Relaxed);
- return true;
- }
- }
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- /// Set the number of permits to acquire.
- ///
- /// This function is only called when the waiter is being inserted into the
- /// wait queue. Because of this, there are no concurrent threads that can
- /// modify the state and using `store` is safe.
- fn set_permits_to_acquire(&self, num_permits: usize) {
- debug_assert!(WaiterState(self.state.load(Acquire)).is_queued());
-
- let mut state = WaiterState(QUEUED);
- state.set_permits_to_acquire(num_permits);
-
- self.state.store(state.0, Release);
- }
-
- /// Assign permits to the waiter.
- ///
- /// Returns `true` if the waiter should be removed from the queue
- fn assign_permits(&self, n: &mut usize, closed: bool) -> bool {
- let mut curr = WaiterState(self.state.load(Acquire));
-
- loop {
- let mut next = curr;
-
- // Number of permits to assign to this waiter
- let assign = cmp::min(curr.permits_to_acquire(), *n);
-
- // Assign the permits
- next.set_permits_to_acquire(curr.permits_to_acquire() - assign);
-
- if closed {
- next.set_closed();
- }
-
- match self.state.compare_exchange(curr.0, next.0, AcqRel, Acquire) {
- Ok(_) => {
- // Update `n`
- *n -= assign;
-
- if next.permits_to_acquire() == 0 {
- if curr.permits_to_acquire() > 0 {
- self.waker.wake();
- }
-
- return true;
- } else {
- return false;
- }
- }
- Err(actual) => curr = WaiterState(actual),
- }
- }
- }
-
- fn revert_to_idle(&self) {
- // An idle node is not waiting on any permits
- self.state.store(0, Relaxed);
- }
-
- fn store_next(&self, next: NonNull<Waiter>) {
- self.next.store(next.as_ptr(), Release);
- }
-}
-
-// ===== impl SemState =====
-
-impl SemState {
- /// Returns a new default `State` value.
- fn new(permits: usize, stub: &Waiter) -> SemState {
- assert!(permits <= MAX_PERMITS);
-
- if permits > 0 {
- SemState((permits << NUM_SHIFT) | NUM_FLAG)
- } else {
- SemState(stub as *const _ as usize)
- }
- }
-
- /// Returns a `State` tracking `ptr` as the tail of the queue.
- fn new_ptr(tail: NonNull<Waiter>, closed: bool) -> SemState {
- let mut val = tail.as_ptr() as usize;
-
- if closed {
- val |= CLOSED_FLAG;
- }
-
- SemState(val)
- }
-
- /// Returns the amount of remaining capacity
- fn available_permits(self) -> usize {
- if !self.has_available_permits() {
- return 0;
- }
-
- self.0 >> NUM_SHIFT
- }
-
- /// Returns `true` if the state has permits that can be claimed by a waiter.
- fn has_available_permits(self) -> bool {
- self.0 & NUM_FLAG == NUM_FLAG
- }
-
- fn has_waiter(self, stub: &Waiter) -> bool {
- !self.has_available_permits() && !self.is_stub(stub)
- }
-
- /// Tries to atomically acquire specified number of permits.
- ///
- /// # Return
- ///
- /// Returns `true` if the specified number of permits were acquired, `false`
- /// otherwise. Returning false does not mean that there are no more
- /// available permits.
- fn acquire_permits(&mut self, num: usize, stub: &Waiter) -> bool {
- debug_assert!(num > 0);
-
- if self.available_permits() < num {
- return false;
- }
-
- debug_assert!(self.waiter().is_none());
-
- self.0 -= num << NUM_SHIFT;
-
- if self.0 == NUM_FLAG {
- // Set the state to the stub pointer.
- self.0 = stub as *const _ as usize;
- }
-
- true
- }
-
- /// Releases permits
- ///
- /// Returns `true` if the permits were accepted.
- fn release_permits(&mut self, permits: usize, stub: &Waiter) {
- debug_assert!(permits > 0);
-
- if self.is_stub(stub) {
- self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
- return;
- }
-
- debug_assert!(self.has_available_permits());
-
- self.0 += permits << NUM_SHIFT;
- }
-
- fn is_waiter(self) -> bool {
- self.0 & NUM_FLAG == 0
- }
-
- /// Returns the waiter, if one is set.
- fn waiter(self) -> Option<NonNull<Waiter>> {
- if self.is_waiter() {
- let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
-
- Some(waiter)
- } else {
- None
- }
- }
-
- /// Assumes `self` represents a pointer
- fn as_ptr(self) -> *mut Waiter {
- (self.0 & !CLOSED_FLAG) as *mut Waiter
- }
-
- /// Sets to a pointer to a waiter.
- ///
- /// This can only be done from the full state.
- fn set_waiter(&mut self, waiter: NonNull<Waiter>) {
- let waiter = waiter.as_ptr() as usize;
- debug_assert!(!self.is_closed());
-
- self.0 = waiter;
- }
-
- fn is_stub(self, stub: &Waiter) -> bool {
- self.as_ptr() as usize == stub as *const _ as usize
- }
-
- /// Loads the state from an AtomicUsize.
- fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
- let value = cell.load(ordering);
- SemState(value)
- }
-
- fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
- let value = cell.fetch_or(CLOSED_FLAG, ordering);
- SemState(value)
- }
-
- fn is_closed(self) -> bool {
- self.0 & CLOSED_FLAG == CLOSED_FLAG
- }
-
- /// Converts the state into a `usize` representation.
- fn to_usize(self) -> usize {
- self.0
- }
-}
-
-impl fmt::Debug for SemState {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- let mut fmt = fmt.debug_struct("SemState");
-
- if self.is_waiter() {
- fmt.field("state", &"<waiter>");
- } else {
- fmt.field("permits", &self.available_permits());
- }
-
- fmt.finish()
- }
-}
-
-// ===== impl WaiterState =====
-
-impl WaiterState {
- fn permits_to_acquire(self) -> usize {
- self.0 >> PERMIT_SHIFT
- }
-
- fn set_permits_to_acquire(&mut self, val: usize) {
- self.0 = (val << PERMIT_SHIFT) | (self.0 & !PERMIT_MASK)
- }
-
- fn is_queued(self) -> bool {
- self.0 & QUEUED == QUEUED
- }
-
- fn set_queued(&mut self) {
- self.0 |= QUEUED;
- }
-
- fn is_closed(self) -> bool {
- self.0 & CLOSED == CLOSED
- }
-
- fn set_closed(&mut self) {
- self.0 |= CLOSED;
- }
-
- fn unset_queued(&mut self) {
- assert!(self.is_queued());
- self.0 -= QUEUED;
- }
-
- fn is_dropped(self) -> bool {
- self.0 & DROPPED == DROPPED
- }
-}
diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs
index 73b1745..ae4cac7 100644
--- a/src/sync/task/atomic_waker.rs
+++ b/src/sync/task/atomic_waker.rs
@@ -141,13 +141,12 @@ impl AtomicWaker {
}
}
+ /*
/// Registers the current waker to be notified on calls to `wake`.
- ///
- /// This is the same as calling `register_task` with `task::current()`.
- #[cfg(feature = "io-driver")]
pub(crate) fn register(&self, waker: Waker) {
self.do_register(waker);
}
+ */
/// Registers the provided waker to be notified on calls to `wake`.
///
diff --git a/src/sync/tests/loom_broadcast.rs b/src/sync/tests/loom_broadcast.rs
index da12fb9..4b1f034 100644
--- a/src/sync/tests/loom_broadcast.rs
+++ b/src/sync/tests/loom_broadcast.rs
@@ -1,5 +1,5 @@
use crate::sync::broadcast;
-use crate::sync::broadcast::RecvError::{Closed, Lagged};
+use crate::sync::broadcast::error::RecvError::{Closed, Lagged};
use loom::future::block_on;
use loom::sync::Arc;
diff --git a/src/sync/tests/loom_cancellation_token.rs b/src/sync/tests/loom_cancellation_token.rs
deleted file mode 100644
index e9c9f3d..0000000
--- a/src/sync/tests/loom_cancellation_token.rs
+++ /dev/null
@@ -1,155 +0,0 @@
-use crate::sync::CancellationToken;
-
-use loom::{future::block_on, thread};
-use tokio_test::assert_ok;
-
-#[test]
-fn cancel_token() {
- loom::model(|| {
- let token = CancellationToken::new();
- let token1 = token.clone();
-
- let th1 = thread::spawn(move || {
- block_on(async {
- token1.cancelled().await;
- });
- });
-
- let th2 = thread::spawn(move || {
- token.cancel();
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- });
-}
-
-#[test]
-fn cancel_with_child() {
- loom::model(|| {
- let token = CancellationToken::new();
- let token1 = token.clone();
- let token2 = token.clone();
- let child_token = token.child_token();
-
- let th1 = thread::spawn(move || {
- block_on(async {
- token1.cancelled().await;
- });
- });
-
- let th2 = thread::spawn(move || {
- token2.cancel();
- });
-
- let th3 = thread::spawn(move || {
- block_on(async {
- child_token.cancelled().await;
- });
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn drop_token_no_child() {
- loom::model(|| {
- let token = CancellationToken::new();
- let token1 = token.clone();
- let token2 = token.clone();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- drop(token2);
- });
-
- let th3 = thread::spawn(move || {
- drop(token);
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn drop_token_with_childs() {
- loom::model(|| {
- let token1 = CancellationToken::new();
- let child_token1 = token1.child_token();
- let child_token2 = token1.child_token();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- drop(child_token1);
- });
-
- let th3 = thread::spawn(move || {
- drop(child_token2);
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn drop_and_cancel_token() {
- loom::model(|| {
- let token1 = CancellationToken::new();
- let token2 = token1.clone();
- let child_token = token1.child_token();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- token2.cancel();
- });
-
- let th3 = thread::spawn(move || {
- drop(child_token);
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
-
-#[test]
-fn cancel_parent_and_child() {
- loom::model(|| {
- let token1 = CancellationToken::new();
- let token2 = token1.clone();
- let child_token = token1.child_token();
-
- let th1 = thread::spawn(move || {
- drop(token1);
- });
-
- let th2 = thread::spawn(move || {
- token2.cancel();
- });
-
- let th3 = thread::spawn(move || {
- child_token.cancel();
- });
-
- assert_ok!(th1.join());
- assert_ok!(th2.join());
- assert_ok!(th3.join());
- });
-}
diff --git a/src/sync/tests/loom_mpsc.rs b/src/sync/tests/loom_mpsc.rs
index 6a1a6ab..c12313b 100644
--- a/src/sync/tests/loom_mpsc.rs
+++ b/src/sync/tests/loom_mpsc.rs
@@ -2,22 +2,24 @@ use crate::sync::mpsc;
use futures::future::poll_fn;
use loom::future::block_on;
+use loom::sync::Arc;
use loom::thread;
+use tokio_test::assert_ok;
#[test]
fn closing_tx() {
loom::model(|| {
- let (mut tx, mut rx) = mpsc::channel(16);
+ let (tx, mut rx) = mpsc::channel(16);
thread::spawn(move || {
tx.try_send(()).unwrap();
drop(tx);
});
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_some());
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
@@ -32,15 +34,70 @@ fn closing_unbounded_tx() {
drop(tx);
});
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_some());
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
#[test]
+fn closing_bounded_rx() {
+ loom::model(|| {
+ let (tx1, rx) = mpsc::channel::<()>(16);
+ let tx2 = tx1.clone();
+ thread::spawn(move || {
+ drop(rx);
+ });
+
+ block_on(tx1.closed());
+ block_on(tx2.closed());
+ });
+}
+
+#[test]
+fn closing_and_sending() {
+ loom::model(|| {
+ let (tx1, mut rx) = mpsc::channel::<()>(16);
+ let tx1 = Arc::new(tx1);
+ let tx2 = tx1.clone();
+
+ let th1 = thread::spawn(move || {
+ tx1.try_send(()).unwrap();
+ });
+
+ let th2 = thread::spawn(move || {
+ block_on(tx2.closed());
+ });
+
+ let th3 = thread::spawn(move || {
+ let v = block_on(rx.recv());
+ assert!(v.is_some());
+ drop(rx);
+ });
+
+ assert_ok!(th1.join());
+ assert_ok!(th2.join());
+ assert_ok!(th3.join());
+ });
+}
+
+#[test]
+fn closing_unbounded_rx() {
+ loom::model(|| {
+ let (tx1, rx) = mpsc::unbounded_channel::<()>();
+ let tx2 = tx1.clone();
+ thread::spawn(move || {
+ drop(rx);
+ });
+
+ block_on(tx1.closed());
+ block_on(tx2.closed());
+ });
+}
+
+#[test]
fn dropping_tx() {
loom::model(|| {
let (tx, mut rx) = mpsc::channel::<()>(16);
@@ -53,7 +110,7 @@ fn dropping_tx() {
}
drop(tx);
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
@@ -71,7 +128,7 @@ fn dropping_unbounded_tx() {
}
drop(tx);
- let v = block_on(poll_fn(|cx| rx.poll_recv(cx)));
+ let v = block_on(rx.recv());
assert!(v.is_none());
});
}
diff --git a/src/sync/tests/loom_notify.rs b/src/sync/tests/loom_notify.rs
index 60981d4..79a5bf8 100644
--- a/src/sync/tests/loom_notify.rs
+++ b/src/sync/tests/loom_notify.rs
@@ -16,7 +16,7 @@ fn notify_one() {
});
});
- tx.notify();
+ tx.notify_one();
th.join().unwrap();
});
}
@@ -34,12 +34,12 @@ fn notify_multi() {
ths.push(thread::spawn(move || {
block_on(async {
notify.notified().await;
- notify.notify();
+ notify.notify_one();
})
}));
}
- notify.notify();
+ notify.notify_one();
for th in ths.drain(..) {
th.join().unwrap();
@@ -67,7 +67,7 @@ fn notify_drop() {
block_on(poll_fn(|cx| {
if recv.as_mut().poll(cx).is_ready() {
- rx1.notify();
+ rx1.notify_one();
}
Poll::Ready(())
}));
@@ -77,12 +77,12 @@ fn notify_drop() {
block_on(async {
rx2.notified().await;
// Trigger second notification
- rx2.notify();
+ rx2.notify_one();
rx2.notified().await;
});
});
- notify.notify();
+ notify.notify_one();
th1.join().unwrap();
th2.join().unwrap();
diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs
index dfa7459..9729cfb 100644
--- a/src/sync/tests/loom_oneshot.rs
+++ b/src/sync/tests/loom_oneshot.rs
@@ -75,8 +75,10 @@ impl Future for OnClose<'_> {
type Output = bool;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
- let res = self.get_mut().tx.poll_closed(cx);
- Ready(res.is_ready())
+ let fut = self.get_mut().tx.closed();
+ crate::pin!(fut);
+
+ Ready(fut.poll(cx).is_ready())
}
}
diff --git a/src/sync/tests/loom_semaphore_ll.rs b/src/sync/tests/loom_semaphore_ll.rs
deleted file mode 100644
index b5e5efb..0000000
--- a/src/sync/tests/loom_semaphore_ll.rs
+++ /dev/null
@@ -1,192 +0,0 @@
-use crate::sync::semaphore_ll::*;
-
-use futures::future::poll_fn;
-use loom::future::block_on;
-use loom::thread;
-use std::future::Future;
-use std::pin::Pin;
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering::SeqCst;
-use std::sync::Arc;
-use std::task::Poll::Ready;
-use std::task::{Context, Poll};
-
-#[test]
-fn basic_usage() {
- const NUM: usize = 2;
-
- struct Actor {
- waiter: Permit,
- shared: Arc<Shared>,
- }
-
- struct Shared {
- semaphore: Semaphore,
- active: AtomicUsize,
- }
-
- impl Future for Actor {
- type Output = ();
-
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
- let me = &mut *self;
-
- ready!(me.waiter.poll_acquire(cx, 1, &me.shared.semaphore)).unwrap();
-
- let actual = me.shared.active.fetch_add(1, SeqCst);
- assert!(actual <= NUM - 1);
-
- let actual = me.shared.active.fetch_sub(1, SeqCst);
- assert!(actual <= NUM);
-
- me.waiter.release(1, &me.shared.semaphore);
-
- Ready(())
- }
- }
-
- loom::model(|| {
- let shared = Arc::new(Shared {
- semaphore: Semaphore::new(NUM),
- active: AtomicUsize::new(0),
- });
-
- for _ in 0..NUM {
- let shared = shared.clone();
-
- thread::spawn(move || {
- block_on(Actor {
- waiter: Permit::new(),
- shared,
- });
- });
- }
-
- block_on(Actor {
- waiter: Permit::new(),
- shared,
- });
- });
-}
-
-#[test]
-fn release() {
- loom::model(|| {
- let semaphore = Arc::new(Semaphore::new(1));
-
- {
- let semaphore = semaphore.clone();
- thread::spawn(move || {
- let mut permit = Permit::new();
-
- block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap();
-
- permit.release(1, &semaphore);
- });
- }
-
- let mut permit = Permit::new();
-
- block_on(poll_fn(|cx| permit.poll_acquire(cx, 1, &semaphore))).unwrap();
-
- permit.release(1, &semaphore);
- });
-}
-
-#[test]
-fn basic_closing() {
- const NUM: usize = 2;
-
- loom::model(|| {
- let semaphore = Arc::new(Semaphore::new(1));
-
- for _ in 0..NUM {
- let semaphore = semaphore.clone();
-
- thread::spawn(move || {
- let mut permit = Permit::new();
-
- for _ in 0..2 {
- block_on(poll_fn(|cx| {
- permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ())
- }))?;
-
- permit.release(1, &semaphore);
- }
-
- Ok::<(), ()>(())
- });
- }
-
- semaphore.close();
- });
-}
-
-#[test]
-fn concurrent_close() {
- const NUM: usize = 3;
-
- loom::model(|| {
- let semaphore = Arc::new(Semaphore::new(1));
-
- for _ in 0..NUM {
- let semaphore = semaphore.clone();
-
- thread::spawn(move || {
- let mut permit = Permit::new();
-
- block_on(poll_fn(|cx| {
- permit.poll_acquire(cx, 1, &semaphore).map_err(|_| ())
- }))?;
-
- permit.release(1, &semaphore);
-
- semaphore.close();
-
- Ok::<(), ()>(())
- });
- }
- });
-}
-
-#[test]
-fn batch() {
- let mut b = loom::model::Builder::new();
- b.preemption_bound = Some(1);
-
- b.check(|| {
- let semaphore = Arc::new(Semaphore::new(10));
- let active = Arc::new(AtomicUsize::new(0));
- let mut ths = vec![];
-
- for _ in 0..2 {
- let semaphore = semaphore.clone();
- let active = active.clone();
-
- ths.push(thread::spawn(move || {
- let mut permit = Permit::new();
-
- for n in &[4, 10, 8] {
- block_on(poll_fn(|cx| permit.poll_acquire(cx, *n, &semaphore))).unwrap();
-
- active.fetch_add(*n as usize, SeqCst);
-
- let num_active = active.load(SeqCst);
- assert!(num_active <= 10);
-
- thread::yield_now();
-
- active.fetch_sub(*n as usize, SeqCst);
-
- permit.release(*n, &semaphore);
- }
- }));
- }
-
- for th in ths.into_iter() {
- th.join().unwrap();
- }
-
- assert_eq!(10, semaphore.available_permits());
- });
-}
diff --git a/src/sync/tests/loom_watch.rs b/src/sync/tests/loom_watch.rs
new file mode 100644
index 0000000..c575b5b
--- /dev/null
+++ b/src/sync/tests/loom_watch.rs
@@ -0,0 +1,36 @@
+use crate::sync::watch;
+
+use loom::future::block_on;
+use loom::thread;
+
+#[test]
+fn smoke() {
+ loom::model(|| {
+ let (tx, mut rx1) = watch::channel(1);
+ let mut rx2 = rx1.clone();
+ let mut rx3 = rx1.clone();
+ let mut rx4 = rx1.clone();
+ let mut rx5 = rx1.clone();
+
+ let th = thread::spawn(move || {
+ tx.send(2).unwrap();
+ });
+
+ block_on(rx1.changed()).unwrap();
+ assert_eq!(*rx1.borrow(), 2);
+
+ block_on(rx2.changed()).unwrap();
+ assert_eq!(*rx2.borrow(), 2);
+
+ block_on(rx3.changed()).unwrap();
+ assert_eq!(*rx3.borrow(), 2);
+
+ block_on(rx4.changed()).unwrap();
+ assert_eq!(*rx4.borrow(), 2);
+
+ block_on(rx5.changed()).unwrap();
+ assert_eq!(*rx5.borrow(), 2);
+
+ th.join().unwrap();
+ })
+}
diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs
index 6ba8c1f..a78be6f 100644
--- a/src/sync/tests/mod.rs
+++ b/src/sync/tests/mod.rs
@@ -1,18 +1,15 @@
cfg_not_loom! {
mod atomic_waker;
- mod semaphore_ll;
mod semaphore_batch;
}
cfg_loom! {
mod loom_atomic_waker;
mod loom_broadcast;
- #[cfg(tokio_unstable)]
- mod loom_cancellation_token;
mod loom_list;
mod loom_mpsc;
mod loom_notify;
mod loom_oneshot;
mod loom_semaphore_batch;
- mod loom_semaphore_ll;
+ mod loom_watch;
}
diff --git a/src/sync/tests/semaphore_ll.rs b/src/sync/tests/semaphore_ll.rs
deleted file mode 100644
index bfb0757..0000000
--- a/src/sync/tests/semaphore_ll.rs
+++ /dev/null
@@ -1,470 +0,0 @@
-use crate::sync::semaphore_ll::{Permit, Semaphore};
-use tokio_test::*;
-
-#[test]
-fn poll_acquire_one_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
- assert!(!permit.is_acquired());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn poll_acquire_many_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
- assert!(!permit.is_acquired());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- // Polling for a larger number of permits acquires more
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 8, &s)));
- assert_eq!(s.available_permits(), 92);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn try_acquire_one_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = Permit::new();
- assert!(!permit.is_acquired());
-
- assert_ok!(permit.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ok!(permit.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 99);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn try_acquire_many_available() {
- let s = Semaphore::new(100);
- assert_eq!(s.available_permits(), 100);
-
- // Polling for a permit succeeds immediately
- let mut permit = Permit::new();
- assert!(!permit.is_acquired());
-
- assert_ok!(permit.try_acquire(5, &s));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-
- // Polling again on the same waiter does not claim a new permit
- assert_ok!(permit.try_acquire(5, &s));
- assert_eq!(s.available_permits(), 95);
- assert!(permit.is_acquired());
-}
-
-#[test]
-fn poll_acquire_one_unavailable() {
- let s = Semaphore::new(1);
-
- let mut permit_1 = task::spawn(Permit::new());
- let mut permit_2 = task::spawn(Permit::new());
-
- // Acquire the first permit
- assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 0);
-
- permit_2.enter(|cx, mut p| {
- // Try to acquire the second permit
- assert_pending!(p.poll_acquire(cx, 1, &s));
- });
-
- permit_1.release(1, &s);
-
- assert_eq!(s.available_permits(), 0);
- assert!(permit_2.is_woken());
- assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 1);
-}
-
-#[test]
-fn forget_acquired() {
- let s = Semaphore::new(1);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(s.available_permits(), 0);
-
- permit.forget(1);
- assert_eq!(s.available_permits(), 0);
-}
-
-#[test]
-fn forget_waiting() {
- let s = Semaphore::new(0);
-
- // Polling for a permit succeeds immediately
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(s.available_permits(), 0);
-
- permit.forget(1);
-
- s.add_permits(1);
-
- assert!(!permit.is_woken());
- assert_eq!(s.available_permits(), 1);
-}
-
-#[test]
-fn poll_acquire_many_unavailable() {
- let s = Semaphore::new(5);
-
- let mut permit_1 = task::spawn(Permit::new());
- let mut permit_2 = task::spawn(Permit::new());
- let mut permit_3 = task::spawn(Permit::new());
-
- // Acquire the first permit
- assert_ready_ok!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 4);
-
- permit_2.enter(|cx, mut p| {
- // Try to acquire the second permit
- assert_pending!(p.poll_acquire(cx, 5, &s));
- });
-
- assert_eq!(s.available_permits(), 0);
-
- permit_3.enter(|cx, mut p| {
- // Try to acquire the third permit
- assert_pending!(p.poll_acquire(cx, 3, &s));
- });
-
- permit_1.release(1, &s);
-
- assert_eq!(s.available_permits(), 0);
- assert!(permit_2.is_woken());
- assert_ready_ok!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
-
- assert!(!permit_3.is_woken());
- assert_eq!(s.available_permits(), 0);
-
- permit_2.release(1, &s);
- assert!(!permit_3.is_woken());
- assert_eq!(s.available_permits(), 0);
-
- permit_2.release(2, &s);
- assert!(permit_3.is_woken());
-
- assert_ready_ok!(permit_3.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-}
-
-#[test]
-fn try_acquire_one_unavailable() {
- let s = Semaphore::new(1);
-
- let mut permit_1 = Permit::new();
- let mut permit_2 = Permit::new();
-
- // Acquire the first permit
- assert_ok!(permit_1.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 0);
-
- assert_err!(permit_2.try_acquire(1, &s));
-
- permit_1.release(1, &s);
-
- assert_eq!(s.available_permits(), 1);
- assert_ok!(permit_2.try_acquire(1, &s));
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 1);
-}
-
-#[test]
-fn try_acquire_many_unavailable() {
- let s = Semaphore::new(5);
-
- let mut permit_1 = Permit::new();
- let mut permit_2 = Permit::new();
-
- // Acquire the first permit
- assert_ok!(permit_1.try_acquire(1, &s));
- assert_eq!(s.available_permits(), 4);
-
- assert_err!(permit_2.try_acquire(5, &s));
-
- permit_1.release(1, &s);
- assert_eq!(s.available_permits(), 5);
-
- assert_ok!(permit_2.try_acquire(5, &s));
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 1);
-
- permit_2.release(1, &s);
- assert_eq!(s.available_permits(), 2);
-}
-
-#[test]
-fn poll_acquire_one_zero_permits() {
- let s = Semaphore::new(0);
- assert_eq!(s.available_permits(), 0);
-
- let mut permit = task::spawn(Permit::new());
-
- // Try to acquire the permit
- permit.enter(|cx, mut p| {
- assert_pending!(p.poll_acquire(cx, 1, &s));
- });
-
- s.add_permits(1);
-
- assert!(permit.is_woken());
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-}
-
-#[test]
-#[should_panic]
-fn validates_max_permits() {
- use std::usize;
- Semaphore::new((usize::MAX >> 2) + 1);
-}
-
-#[test]
-fn close_semaphore_prevents_acquire() {
- let s = Semaphore::new(5);
- s.close();
-
- assert_eq!(5, s.available_permits());
-
- let mut permit_1 = task::spawn(Permit::new());
- let mut permit_2 = task::spawn(Permit::new());
-
- assert_ready_err!(permit_1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(5, s.available_permits());
-
- assert_ready_err!(permit_2.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_eq!(5, s.available_permits());
-}
-
-#[test]
-fn close_semaphore_notifies_permit1() {
- let s = Semaphore::new(0);
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- s.close();
-
- assert!(permit.is_woken());
- assert_ready_err!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-}
-
-#[test]
-fn close_semaphore_notifies_permit2() {
- let s = Semaphore::new(2);
-
- let mut permit1 = task::spawn(Permit::new());
- let mut permit2 = task::spawn(Permit::new());
- let mut permit3 = task::spawn(Permit::new());
- let mut permit4 = task::spawn(Permit::new());
-
- // Acquire a couple of permits
- assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_pending!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_pending!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- s.close();
-
- assert!(permit3.is_woken());
- assert!(permit4.is_woken());
-
- assert_ready_err!(permit3.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_ready_err!(permit4.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(0, s.available_permits());
-
- permit1.release(1, &s);
-
- assert_eq!(1, s.available_permits());
-
- assert_ready_err!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- permit2.release(1, &s);
-
- assert_eq!(2, s.available_permits());
-}
-
-#[test]
-fn poll_acquire_additional_permits_while_waiting_before_assigned() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-
- s.add_permits(1);
- assert!(!permit.is_woken());
-
- s.add_permits(1);
- assert!(permit.is_woken());
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-}
-
-#[test]
-fn try_acquire_additional_permits_while_waiting_before_assigned() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
-
- assert_err!(permit.enter(|_, mut p| p.try_acquire(3, &s)));
-
- s.add_permits(1);
- assert!(permit.is_woken());
-
- assert_ok!(permit.enter(|_, mut p| p.try_acquire(2, &s)));
-}
-
-#[test]
-fn poll_acquire_additional_permits_while_waiting_after_assigned_success() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
-
- s.add_permits(2);
-
- assert!(permit.is_woken());
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-}
-
-#[test]
-fn poll_acquire_additional_permits_while_waiting_after_assigned_requeue() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
-
- s.add_permits(2);
-
- assert!(permit.is_woken());
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s)));
-
- s.add_permits(1);
-
- assert!(permit.is_woken());
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 4, &s)));
-}
-
-#[test]
-fn poll_acquire_fewer_permits_while_waiting() {
- let s = Semaphore::new(1);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_eq!(s.available_permits(), 0);
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
- assert_eq!(s.available_permits(), 0);
-}
-
-#[test]
-fn poll_acquire_fewer_permits_after_assigned() {
- let s = Semaphore::new(1);
-
- let mut permit1 = task::spawn(Permit::new());
- let mut permit2 = task::spawn(Permit::new());
-
- assert_pending!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 5, &s)));
- assert_eq!(s.available_permits(), 0);
-
- assert_pending!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- s.add_permits(4);
- assert!(permit1.is_woken());
- assert!(!permit2.is_woken());
-
- assert_ready_ok!(permit1.enter(|cx, mut p| p.poll_acquire(cx, 3, &s)));
-
- assert!(permit2.is_woken());
- assert_eq!(s.available_permits(), 1);
-
- assert_ready_ok!(permit2.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-}
-
-#[test]
-fn forget_partial_1() {
- let s = Semaphore::new(0);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- s.add_permits(1);
-
- assert_eq!(0, s.available_permits());
-
- permit.release(1, &s);
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 1, &s)));
-
- assert_eq!(s.available_permits(), 0);
-}
-
-#[test]
-fn forget_partial_2() {
- let s = Semaphore::new(0);
-
- let mut permit = task::spawn(Permit::new());
-
- assert_pending!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- s.add_permits(1);
-
- assert_eq!(0, s.available_permits());
-
- permit.release(1, &s);
-
- s.add_permits(1);
-
- assert_ready_ok!(permit.enter(|cx, mut p| p.poll_acquire(cx, 2, &s)));
- assert_eq!(s.available_permits(), 0);
-}
diff --git a/src/sync/watch.rs b/src/sync/watch.rs
index 13033d9..ec73832 100644
--- a/src/sync/watch.rs
+++ b/src/sync/watch.rs
@@ -6,13 +6,11 @@
//!
//! # Usage
//!
-//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are
-//! the producer and sender halves of the channel. The channel is
-//! created with an initial value. [`Receiver::recv`] will always
-//! be ready upon creation and will yield either this initial value or
-//! the latest value that has been sent by `Sender`.
-//!
-//! Calls to [`Receiver::recv`] will always yield the latest value.
+//! [`channel`] returns a [`Sender`] / [`Receiver`] pair. These are the producer
+//! and sender halves of the channel. The channel is created with an initial
+//! value. The **latest** value stored in the channel is accessed with
+//! [`Receiver::borrow()`]. Awaiting [`Receiver::changed()`] waits for a new
+//! value to sent by the [`Sender`] half.
//!
//! # Examples
//!
@@ -23,21 +21,21 @@
//! let (tx, mut rx) = watch::channel("hello");
//!
//! tokio::spawn(async move {
-//! while let Some(value) = rx.recv().await {
-//! println!("received = {:?}", value);
+//! while rx.changed().await.is_ok() {
+//! println!("received = {:?}", *rx.borrow());
//! }
//! });
//!
-//! tx.broadcast("world")?;
+//! tx.send("world")?;
//! # Ok(())
//! # }
//! ```
//!
//! # Closing
//!
-//! [`Sender::closed`] allows the producer to detect when all [`Receiver`]
-//! handles have been dropped. This indicates that there is no further interest
-//! in the values being produced and work can be stopped.
+//! [`Sender::is_closed`] and [`Sender::closed`] allow the producer to detect
+//! when all [`Receiver`] handles have been dropped. This indicates that there
+//! is no further interest in the values being produced and work can be stopped.
//!
//! # Thread safety
//!
@@ -47,20 +45,18 @@
//!
//! [`Sender`]: crate::sync::watch::Sender
//! [`Receiver`]: crate::sync::watch::Receiver
-//! [`Receiver::recv`]: crate::sync::watch::Receiver::recv
+//! [`Receiver::changed()`]: crate::sync::watch::Receiver::changed
+//! [`Receiver::borrow()`]: crate::sync::watch::Receiver::borrow
//! [`channel`]: crate::sync::watch::channel
+//! [`Sender::is_closed`]: crate::sync::watch::Sender::is_closed
//! [`Sender::closed`]: crate::sync::watch::Sender::closed
-use crate::future::poll_fn;
-use crate::sync::task::AtomicWaker;
+use crate::sync::Notify;
-use fnv::FnvHashSet;
use std::ops;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
-use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak};
-use std::task::Poll::{Pending, Ready};
-use std::task::{Context, Poll};
+use std::sync::{Arc, RwLock, RwLockReadGuard};
/// Receives values from the associated [`Sender`](struct@Sender).
///
@@ -70,8 +66,8 @@ pub struct Receiver<T> {
/// Pointer to the shared state
shared: Arc<Shared<T>>,
- /// Pointer to the watcher's internal state
- inner: Watcher,
+ /// Last observed version
+ version: usize,
}
/// Sends values to the associated [`Receiver`](struct@Receiver).
@@ -79,7 +75,7 @@ pub struct Receiver<T> {
/// Instances are created by the [`channel`](fn@channel) function.
#[derive(Debug)]
pub struct Sender<T> {
- shared: Weak<Shared<T>>,
+ shared: Arc<Shared<T>>,
}
/// Returns a reference to the inner value
@@ -92,6 +88,27 @@ pub struct Ref<'a, T> {
inner: RwLockReadGuard<'a, T>,
}
+#[derive(Debug)]
+struct Shared<T> {
+ /// The most recent value
+ value: RwLock<T>,
+
+ /// The current version
+ ///
+ /// The lowest bit represents a "closed" state. The rest of the bits
+ /// represent the current version.
+ version: AtomicUsize,
+
+ /// Tracks the number of `Receiver` instances
+ ref_count_rx: AtomicUsize,
+
+ /// Notifies waiting receivers that the value changed.
+ notify_rx: Notify,
+
+ /// Notifies any task listening for `Receiver` dropped events
+ notify_tx: Notify,
+}
+
pub mod error {
//! Watch error types
@@ -112,37 +129,20 @@ pub mod error {
}
impl<T: fmt::Debug> std::error::Error for SendError<T> {}
-}
-
-#[derive(Debug)]
-struct Shared<T> {
- /// The most recent value
- value: RwLock<T>,
- /// The current version
- ///
- /// The lowest bit represents a "closed" state. The rest of the bits
- /// represent the current version.
- version: AtomicUsize,
-
- /// All watchers
- watchers: Mutex<Watchers>,
-
- /// Task to notify when all watchers drop
- cancel: AtomicWaker,
-}
+ /// Error produced when receiving a change notification.
+ #[derive(Debug)]
+ pub struct RecvError(pub(super) ());
-type Watchers = FnvHashSet<Watcher>;
+ // ===== impl RecvError =====
-/// The watcher's ID is based on the Arc's pointer.
-#[derive(Clone, Debug)]
-struct Watcher(Arc<WatchInner>);
+ impl fmt::Display for RecvError {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(fmt, "channel closed")
+ }
+ }
-#[derive(Debug)]
-struct WatchInner {
- /// Last observed version
- version: AtomicUsize,
- waker: AtomicWaker,
+ impl std::error::Error for RecvError {}
}
const CLOSED: usize = 1;
@@ -162,41 +162,32 @@ const CLOSED: usize = 1;
/// let (tx, mut rx) = watch::channel("hello");
///
/// tokio::spawn(async move {
-/// while let Some(value) = rx.recv().await {
-/// println!("received = {:?}", value);
+/// while rx.changed().await.is_ok() {
+/// println!("received = {:?}", *rx.borrow());
/// }
/// });
///
-/// tx.broadcast("world")?;
+/// tx.send("world")?;
/// # Ok(())
/// # }
/// ```
///
/// [`Sender`]: struct@Sender
/// [`Receiver`]: struct@Receiver
-pub fn channel<T: Clone>(init: T) -> (Sender<T>, Receiver<T>) {
- const VERSION_0: usize = 0;
- const VERSION_1: usize = 2;
-
- // We don't start knowing VERSION_1
- let inner = Watcher::new_version(VERSION_0);
-
- // Insert the watcher
- let mut watchers = FnvHashSet::with_capacity_and_hasher(0, Default::default());
- watchers.insert(inner.clone());
-
+pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared {
value: RwLock::new(init),
- version: AtomicUsize::new(VERSION_1),
- watchers: Mutex::new(watchers),
- cancel: AtomicWaker::new(),
+ version: AtomicUsize::new(0),
+ ref_count_rx: AtomicUsize::new(1),
+ notify_rx: Notify::new(),
+ notify_tx: Notify::new(),
});
let tx = Sender {
- shared: Arc::downgrade(&shared),
+ shared: shared.clone(),
};
- let rx = Receiver { shared, inner };
+ let rx = Receiver { shared, version: 0 };
(tx, rx)
}
@@ -221,39 +212,13 @@ impl<T> Receiver<T> {
Ref { inner }
}
- // TODO: document
- #[doc(hidden)]
- pub fn poll_recv_ref<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll<Option<Ref<'a, T>>> {
- // Make sure the task is up to date
- self.inner.waker.register_by_ref(cx.waker());
-
- let state = self.shared.version.load(SeqCst);
- let version = state & !CLOSED;
-
- if self.inner.version.swap(version, Relaxed) != version {
- let inner = self.shared.value.read().unwrap();
-
- return Ready(Some(Ref { inner }));
- }
-
- if CLOSED == state & CLOSED {
- // The `Store` handle has been dropped.
- return Ready(None);
- }
-
- Pending
- }
-}
-
-impl<T: Clone> Receiver<T> {
- /// Attempts to clone the latest value sent via the channel.
+ /// Wait for a change notification
///
- /// If this is the first time the function is called on a `Receiver`
- /// instance, then the function completes immediately with the **current**
- /// value held by the channel. On the next call, the function waits until
- /// a new value is sent in the channel.
+ /// Returns when a new value has been sent by the [`Sender`] since the last
+ /// time `changed()` was called. When the `Sender` half is dropped, `Err` is
+ /// returned.
///
- /// `None` is returned if the `Sender` half is dropped.
+ /// [`Sender`]: struct@Sender
///
/// # Examples
///
@@ -264,118 +229,170 @@ impl<T: Clone> Receiver<T> {
/// async fn main() {
/// let (tx, mut rx) = watch::channel("hello");
///
- /// let v = rx.recv().await.unwrap();
- /// assert_eq!(v, "hello");
- ///
/// tokio::spawn(async move {
- /// tx.broadcast("goodbye").unwrap();
+ /// tx.send("goodbye").unwrap();
/// });
///
- /// // Waits for the new task to spawn and send the value.
- /// let v = rx.recv().await.unwrap();
- /// assert_eq!(v, "goodbye");
+ /// assert!(rx.changed().await.is_ok());
+ /// assert_eq!(*rx.borrow(), "goodbye");
///
- /// let v = rx.recv().await;
- /// assert!(v.is_none());
+ /// // The `tx` handle has been dropped
+ /// assert!(rx.changed().await.is_err());
/// }
/// ```
- pub async fn recv(&mut self) -> Option<T> {
- poll_fn(|cx| {
- let v_ref = ready!(self.poll_recv_ref(cx));
- Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone()))
+ pub async fn changed(&mut self) -> Result<(), error::RecvError> {
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::task::Poll;
+
+ // In order to avoid a race condition, we first request a notification,
+ // **then** check the current value's version. If a new version exists,
+ // the notification request is dropped. Requesting the notification
+ // requires polling the future once.
+ let notified = self.shared.notify_rx.notified();
+ pin!(notified);
+
+ // Polling the future once is guaranteed to return `Pending` as `watch`
+ // only notifies using `notify_waiters`.
+ crate::future::poll_fn(|cx| {
+ let res = Pin::new(&mut notified).poll(cx);
+ assert!(!res.is_ready());
+ Poll::Ready(())
})
- .await
+ .await;
+
+ if let Some(ret) = maybe_changed(&self.shared, &mut self.version) {
+ return ret;
+ }
+
+ notified.await;
+
+ maybe_changed(&self.shared, &mut self.version)
+ .expect("[bug] failed to observe change after notificaton.")
}
}
-#[cfg(feature = "stream")]
-impl<T: Clone> crate::stream::Stream for Receiver<T> {
- type Item = T;
-
- fn poll_next(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
- let v_ref = ready!(self.poll_recv_ref(cx));
+fn maybe_changed<T>(
+ shared: &Shared<T>,
+ version: &mut usize,
+) -> Option<Result<(), error::RecvError>> {
+ // Load the version from the state
+ let state = shared.version.load(SeqCst);
+ let new_version = state & !CLOSED;
+
+ if *version != new_version {
+ // Observe the new version and return
+ *version = new_version;
+ return Some(Ok(()));
+ }
- Poll::Ready(v_ref.map(|v_ref| (*v_ref).clone()))
+ if CLOSED == state & CLOSED {
+ // All receivers have dropped.
+ return Some(Err(error::RecvError(())));
}
+
+ None
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
- let ver = self.inner.version.load(Relaxed);
- let inner = Watcher::new_version(ver);
+ let version = self.version;
let shared = self.shared.clone();
- shared.watchers.lock().unwrap().insert(inner.clone());
+ // No synchronization necessary as this is only used as a counter and
+ // not memory access.
+ shared.ref_count_rx.fetch_add(1, Relaxed);
- Receiver { shared, inner }
+ Receiver { version, shared }
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
- self.shared.watchers.lock().unwrap().remove(&self.inner);
+ // No synchronization necessary as this is only used as a counter and
+ // not memory access.
+ if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) {
+ // This is the last `Receiver` handle, tasks waiting on `Sender::closed()`
+ self.shared.notify_tx.notify_waiters();
+ }
}
}
impl<T> Sender<T> {
- /// Broadcasts a new value via the channel, notifying all receivers.
- pub fn broadcast(&self, value: T) -> Result<(), error::SendError<T>> {
- let shared = match self.shared.upgrade() {
- Some(shared) => shared,
- // All `Watch` handles have been canceled
- None => return Err(error::SendError { inner: value }),
- };
-
- // Replace the value
- {
- let mut lock = shared.value.write().unwrap();
- *lock = value;
+ /// Sends a new value via the channel, notifying all receivers.
+ pub fn send(&self, value: T) -> Result<(), error::SendError<T>> {
+ // This is pretty much only useful as a hint anyway, so synchronization isn't critical.
+ if 0 == self.shared.ref_count_rx.load(Relaxed) {
+ return Err(error::SendError { inner: value });
}
+ *self.shared.value.write().unwrap() = value;
+
// Update the version. 2 is used so that the CLOSED bit is not set.
- shared.version.fetch_add(2, SeqCst);
+ self.shared.version.fetch_add(2, SeqCst);
// Notify all watchers
- notify_all(&*shared);
+ self.shared.notify_rx.notify_waiters();
Ok(())
}
+ /// Checks if the channel has been closed. This happens when all receivers
+ /// have dropped.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// let (tx, rx) = tokio::sync::watch::channel(());
+ /// assert!(!tx.is_closed());
+ ///
+ /// drop(rx);
+ /// assert!(tx.is_closed());
+ /// ```
+ pub fn is_closed(&self) -> bool {
+ self.shared.ref_count_rx.load(Relaxed) == 0
+ }
+
/// Completes when all receivers have dropped.
///
/// This allows the producer to get notified when interest in the produced
/// values is canceled and immediately stop doing work.
- pub async fn closed(&mut self) {
- poll_fn(|cx| self.poll_close(cx)).await
- }
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::watch;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx) = watch::channel("hello");
+ ///
+ /// tokio::spawn(async move {
+ /// // use `rx`
+ /// drop(rx);
+ /// });
+ ///
+ /// // Waits for `rx` to drop
+ /// tx.closed().await;
+ /// println!("the `rx` handles dropped")
+ /// }
+ /// ```
+ pub async fn closed(&self) {
+ let notified = self.shared.notify_tx.notified();
- fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> {
- match self.shared.upgrade() {
- Some(shared) => {
- shared.cancel.register_by_ref(cx.waker());
- Pending
- }
- None => Ready(()),
+ if self.shared.ref_count_rx.load(Relaxed) == 0 {
+ return;
}
- }
-}
-
-/// Notifies all watchers of a change
-fn notify_all<T>(shared: &Shared<T>) {
- let watchers = shared.watchers.lock().unwrap();
- for watcher in watchers.iter() {
- // Notify the task
- watcher.waker.wake();
+ notified.await;
+ debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed));
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
- if let Some(shared) = self.shared.upgrade() {
- shared.version.fetch_or(CLOSED, SeqCst);
- notify_all(&*shared);
- }
+ self.shared.version.fetch_or(CLOSED, SeqCst);
+ self.shared.notify_rx.notify_waiters();
}
}
@@ -388,44 +405,3 @@ impl<T> ops::Deref for Ref<'_, T> {
self.inner.deref()
}
}
-
-// ===== impl Shared =====
-
-impl<T> Drop for Shared<T> {
- fn drop(&mut self) {
- self.cancel.wake();
- }
-}
-
-// ===== impl Watcher =====
-
-impl Watcher {
- fn new_version(version: usize) -> Self {
- Watcher(Arc::new(WatchInner {
- version: AtomicUsize::new(version),
- waker: AtomicWaker::new(),
- }))
- }
-}
-
-impl std::cmp::PartialEq for Watcher {
- fn eq(&self, other: &Watcher) -> bool {
- Arc::ptr_eq(&self.0, &other.0)
- }
-}
-
-impl std::cmp::Eq for Watcher {}
-
-impl std::hash::Hash for Watcher {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- (&*self.0 as *const WatchInner).hash(state)
- }
-}
-
-impl std::ops::Deref for Watcher {
- type Target = WatchInner;
-
- fn deref(&self) -> &Self::Target {
- &self.0
- }
-}
diff --git a/src/task/blocking.rs b/src/task/blocking.rs
index ed60f4c..fc6632b 100644
--- a/src/task/blocking.rs
+++ b/src/task/blocking.rs
@@ -1,6 +1,6 @@
use crate::task::JoinHandle;
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
/// Runs the provided blocking function on the current thread without
/// blocking the executor.
///
@@ -17,7 +17,7 @@ cfg_rt_threaded! {
/// using the [`join!`] macro. To avoid this issue, use [`spawn_blocking`]
/// instead.
///
- /// Note that this function can only be used on the [threaded scheduler].
+ /// Note that this function can only be used when using the `multi_thread` runtime.
///
/// Code running behind `block_in_place` cannot be cancelled. When you shut
/// down the executor, it will wait indefinitely for all blocking operations
@@ -27,7 +27,6 @@ cfg_rt_threaded! {
/// returns.
///
/// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code
- /// [threaded scheduler]: fn@crate::runtime::Builder::threaded_scheduler
/// [`spawn_blocking`]: fn@crate::task::spawn_blocking
/// [`join!`]: macro@join
/// [`thread::spawn`]: fn@std::thread::spawn
@@ -44,7 +43,6 @@ cfg_rt_threaded! {
/// });
/// # }
/// ```
- #[cfg_attr(docsrs, doc(cfg(feature = "blocking")))]
pub fn block_in_place<F, R>(f: F) -> R
where
F: FnOnce() -> R,
@@ -53,80 +51,63 @@ cfg_rt_threaded! {
}
}
-cfg_blocking! {
- /// Runs the provided closure on a thread where blocking is acceptable.
- ///
- /// In general, issuing a blocking call or performing a lot of compute in a
- /// future without yielding is not okay, as it may prevent the executor from
- /// driving other futures forward. This function runs the provided closure
- /// on a thread dedicated to blocking operations. See the [CPU-bound tasks
- /// and blocking code][blocking] section for more information.
- ///
- /// Tokio will spawn more blocking threads when they are requested through
- /// this function until the upper limit configured on the [`Builder`] is
- /// reached. This limit is very large by default, because `spawn_blocking` is
- /// often used for various kinds of IO operations that cannot be performed
- /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you
- /// should keep this large upper limit in mind; to run your CPU-bound
- /// computations on only a few threads, you should use a separate thread
- /// pool such as [rayon] rather than configuring the number of blocking
- /// threads.
- ///
- /// This function is intended for non-async operations that eventually
- /// finish on their own. If you want to spawn an ordinary thread, you should
- /// use [`thread::spawn`] instead.
- ///
- /// Closures spawned using `spawn_blocking` cannot be cancelled. When you
- /// shut down the executor, it will wait indefinitely for all blocking
- /// operations to finish. You can use [`shutdown_timeout`] to stop waiting
- /// for them after a certain timeout. Be aware that this will still not
- /// cancel the tasks — they are simply allowed to keep running after the
- /// method returns.
- ///
- /// Note that if you are using the [basic scheduler], this function will
- /// still spawn additional threads for blocking operations. The basic
- /// scheduler's single thread is only used for asynchronous code.
- ///
- /// [`Builder`]: struct@crate::runtime::Builder
- /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code
- /// [rayon]: https://docs.rs/rayon
- /// [basic scheduler]: fn@crate::runtime::Builder::basic_scheduler
- /// [`thread::spawn`]: fn@std::thread::spawn
- /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::task;
- ///
- /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{
- /// let res = task::spawn_blocking(move || {
- /// // do some compute-heavy work or call synchronous code
- /// "done computing"
- /// }).await?;
- ///
- /// assert_eq!(res, "done computing");
- /// # Ok(())
- /// # }
- /// ```
- pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
- where
- F: FnOnce() -> R + Send + 'static,
- R: Send + 'static,
- {
- #[cfg(feature = "tracing")]
- let f = {
- let span = tracing::trace_span!(
- target: "tokio::task",
- "task",
- kind = %"blocking",
- function = %std::any::type_name::<F>(),
- );
- move || {
- let _g = span.enter();
- f()
- }
- };
- crate::runtime::spawn_blocking(f)
- }
+/// Runs the provided closure on a thread where blocking is acceptable.
+///
+/// In general, issuing a blocking call or performing a lot of compute in a
+/// future without yielding is problematic, as it may prevent the executor from
+/// driving other futures forward. This function runs the provided closure on a
+/// thread dedicated to blocking operations. See the [CPU-bound tasks and
+/// blocking code][blocking] section for more information.
+///
+/// Tokio will spawn more blocking threads when they are requested through this
+/// function until the upper limit configured on the [`Builder`] is reached.
+/// This limit is very large by default, because `spawn_blocking` is often used
+/// for various kinds of IO operations that cannot be performed asynchronously.
+/// When you run CPU-bound code using `spawn_blocking`, you should keep this
+/// large upper limit in mind. When running many CPU-bound computations, a
+/// semaphore or some other synchronization primitive should be used to limit
+/// the number of computation executed in parallel. Specialized CPU-bound
+/// executors, such as [rayon], may also be a good fit.
+///
+/// This function is intended for non-async operations that eventually finish on
+/// their own. If you want to spawn an ordinary thread, you should use
+/// [`thread::spawn`] instead.
+///
+/// Closures spawned using `spawn_blocking` cannot be cancelled. When you shut
+/// down the executor, it will wait indefinitely for all blocking operations to
+/// finish. You can use [`shutdown_timeout`] to stop waiting for them after a
+/// certain timeout. Be aware that this will still not cancel the tasks — they
+/// are simply allowed to keep running after the method returns.
+///
+/// Note that if you are using the single threaded runtime, this function will
+/// still spawn additional threads for blocking operations. The basic
+/// scheduler's single thread is only used for asynchronous code.
+///
+/// [`Builder`]: struct@crate::runtime::Builder
+/// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code
+/// [rayon]: https://docs.rs/rayon
+/// [`thread::spawn`]: fn@std::thread::spawn
+/// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout
+///
+/// # Examples
+///
+/// ```
+/// use tokio::task;
+///
+/// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{
+/// let res = task::spawn_blocking(move || {
+/// // do some compute-heavy work or call synchronous code
+/// "done computing"
+/// }).await?;
+///
+/// assert_eq!(res, "done computing");
+/// # Ok(())
+/// # }
+/// ```
+pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
+where
+ F: FnOnce() -> R + Send + 'static,
+ R: Send + 'static,
+{
+ crate::runtime::spawn_blocking(f)
}
diff --git a/src/task/local.rs b/src/task/local.rs
index 3c409ed..5896126 100644
--- a/src/task/local.rs
+++ b/src/task/local.rs
@@ -1,7 +1,7 @@
//! Runs `!Send` futures on the current thread.
use crate::runtime::task::{self, JoinHandle, Task};
use crate::sync::AtomicWaker;
-use crate::util::linked_list::LinkedList;
+use crate::util::linked_list::{Link, LinkedList};
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
@@ -14,7 +14,7 @@ use std::task::Poll;
use pin_project_lite::pin_project;
-cfg_rt_util! {
+cfg_rt! {
/// A set of tasks which are executed on the same thread.
///
/// In some cases, it is necessary to run one or more futures that do not
@@ -95,7 +95,7 @@ cfg_rt_util! {
/// });
///
/// local.spawn_local(async move {
- /// time::delay_for(time::Duration::from_millis(100)).await;
+ /// time::sleep(time::Duration::from_millis(100)).await;
/// println!("goodbye {}", unsend_data)
/// });
///
@@ -132,7 +132,7 @@ struct Context {
struct Tasks {
/// Collection of all active tasks spawned onto this executor.
- owned: LinkedList<Task<Arc<Shared>>>,
+ owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>,
/// Local run queue sender and receiver.
queue: VecDeque<task::Notified<Arc<Shared>>>,
@@ -158,7 +158,7 @@ pin_project! {
scoped_thread_local!(static CURRENT: Context);
-cfg_rt_util! {
+cfg_rt! {
/// Spawns a `!Send` future on the local task set.
///
/// The spawned future will be run on the same thread that called `spawn_local.`
@@ -312,9 +312,9 @@ impl LocalSet {
/// use tokio::runtime::Runtime;
/// use tokio::task;
///
- /// let mut rt = Runtime::new().unwrap();
+ /// let rt = Runtime::new().unwrap();
/// let local = task::LocalSet::new();
- /// local.block_on(&mut rt, async {
+ /// local.block_on(&rt, async {
/// let join = task::spawn_local(async {
/// let blocking_result = task::block_in_place(|| {
/// // ...
@@ -329,9 +329,9 @@ impl LocalSet {
/// use tokio::runtime::Runtime;
/// use tokio::task;
///
- /// let mut rt = Runtime::new().unwrap();
+ /// let rt = Runtime::new().unwrap();
/// let local = task::LocalSet::new();
- /// local.block_on(&mut rt, async {
+ /// local.block_on(&rt, async {
/// let join = task::spawn_local(async {
/// let blocking_result = task::spawn_blocking(|| {
/// // ...
@@ -346,7 +346,9 @@ impl LocalSet {
/// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on
/// [in-place blocking]: fn@crate::task::block_in_place
/// [`spawn_blocking`]: fn@crate::task::spawn_blocking
- pub fn block_on<F>(&self, rt: &mut crate::runtime::Runtime, future: F) -> F::Output
+ #[cfg(feature = "rt")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
+ pub fn block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output
where
F: Future,
{
diff --git a/src/task/mod.rs b/src/task/mod.rs
index 5c89393..5dc5e72 100644
--- a/src/task/mod.rs
+++ b/src/task/mod.rs
@@ -102,7 +102,7 @@
//! # }
//! ```
//!
-//! `spawn`, `JoinHandle`, and `JoinError` are present when the "rt-core"
+//! `spawn`, `JoinHandle`, and `JoinError` are present when the "rt"
//! feature flag is enabled.
//!
//! [`task::spawn`]: crate::task::spawn()
@@ -159,7 +159,7 @@
//!
//! #### block_in_place
//!
-//! When using the [threaded runtime][rt-threaded], the [`task::block_in_place`]
+//! When using the [multi-threaded runtime][rt-multi-thread], the [`task::block_in_place`]
//! function is also available. Like `task::spawn_blocking`, this function
//! allows running a blocking operation from an asynchronous context. Unlike
//! `spawn_blocking`, however, `block_in_place` works by transitioning the
@@ -211,29 +211,26 @@
//!
//! [`task::spawn_blocking`]: crate::task::spawn_blocking
//! [`task::block_in_place`]: crate::task::block_in_place
-//! [rt-threaded]: ../runtime/index.html#threaded-scheduler
+//! [rt-multi-thread]: ../runtime/index.html#threaded-scheduler
//! [`task::yield_now`]: crate::task::yield_now()
//! [`thread::yield_now`]: std::thread::yield_now
-cfg_blocking! {
- mod blocking;
- pub use blocking::spawn_blocking;
- cfg_rt_threaded! {
- pub use blocking::block_in_place;
- }
-}
-
-cfg_rt_core! {
+cfg_rt! {
pub use crate::runtime::task::{JoinError, JoinHandle};
+ mod blocking;
+ pub use blocking::spawn_blocking;
+
mod spawn;
pub use spawn::spawn;
+ cfg_rt_multi_thread! {
+ pub use blocking::block_in_place;
+ }
+
mod yield_now;
pub use yield_now::yield_now;
-}
-cfg_rt_util! {
mod local;
pub use local::{spawn_local, LocalSet};
diff --git a/src/task/spawn.rs b/src/task/spawn.rs
index d6e7711..77acb57 100644
--- a/src/task/spawn.rs
+++ b/src/task/spawn.rs
@@ -3,7 +3,7 @@ use crate::task::JoinHandle;
use std::future::Future;
-doc_rt_core! {
+cfg_rt! {
/// Spawns a new asynchronous task, returning a
/// [`JoinHandle`](super::JoinHandle) for it.
///
@@ -18,7 +18,7 @@ doc_rt_core! {
///
/// This function must be called from the context of a Tokio runtime. Tasks running on
/// the Tokio runtime are always inside its context, but you can also enter the context
- /// using the [`Handle::enter`](crate::runtime::Handle::enter()) method.
+ /// using the [`Runtime::enter`](crate::runtime::Runtime::enter()) method.
///
/// # Examples
///
@@ -37,7 +37,7 @@ doc_rt_core! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut listener = TcpListener::bind("127.0.0.1:8080").await?;
+ /// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// loop {
/// let (socket, _) = listener.accept().await?;
diff --git a/src/task/yield_now.rs b/src/task/yield_now.rs
index e0e2084..251cb93 100644
--- a/src/task/yield_now.rs
+++ b/src/task/yield_now.rs
@@ -2,7 +2,7 @@ use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
-doc_rt_core! {
+cfg_rt! {
/// Yields execution back to the Tokio runtime.
///
/// A task yields by awaiting on `yield_now()`, and may resume when that
diff --git a/src/time/clock.rs b/src/time/clock.rs
index bd67d7a..fab7eca 100644
--- a/src/time/clock.rs
+++ b/src/time/clock.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
//! Source of time abstraction.
//!
//! By default, `std::time::Instant::now()` is used. However, when the
@@ -36,7 +38,18 @@ cfg_not_test_util! {
cfg_test_util! {
use crate::time::{Duration, Instant};
use std::sync::{Arc, Mutex};
- use crate::runtime::context;
+
+ cfg_rt! {
+ fn clock() -> Option<Clock> {
+ crate::runtime::context::clock()
+ }
+ }
+
+ cfg_not_rt! {
+ fn clock() -> Option<Clock> {
+ None
+ }
+ }
/// A handle to a source of time.
#[derive(Debug, Clone)]
@@ -58,14 +71,14 @@ cfg_test_util! {
/// The current value of `Instant::now()` is saved and all subsequent calls
/// to `Instant::now()` until the timer wheel is checked again will return the saved value.
/// Once the timer wheel is checked, time will immediately advance to the next registered
- /// `Delay`. This is useful for running tests that depend on time.
+ /// `Sleep`. This is useful for running tests that depend on time.
///
/// # Panics
///
/// Panics if time is already frozen or if called from outside of the Tokio
/// runtime.
pub fn pause() {
- let clock = context::clock().expect("time cannot be frozen from outside the Tokio runtime");
+ let clock = clock().expect("time cannot be frozen from outside the Tokio runtime");
clock.pause();
}
@@ -79,7 +92,7 @@ cfg_test_util! {
/// Panics if time is not frozen or if called from outside of the Tokio
/// runtime.
pub fn resume() {
- let clock = context::clock().expect("time cannot be frozen from outside the Tokio runtime");
+ let clock = clock().expect("time cannot be frozen from outside the Tokio runtime");
let mut inner = clock.inner.lock().unwrap();
if inner.unfrozen.is_some() {
@@ -99,14 +112,27 @@ cfg_test_util! {
/// Panics if time is not frozen or if called from outside of the Tokio
/// runtime.
pub async fn advance(duration: Duration) {
- let clock = context::clock().expect("time cannot be frozen from outside the Tokio runtime");
+ use crate::future::poll_fn;
+ use std::task::Poll;
+
+ let clock = clock().expect("time cannot be frozen from outside the Tokio runtime");
clock.advance(duration);
- crate::task::yield_now().await;
+
+ let mut yielded = false;
+ poll_fn(|cx| {
+ if yielded {
+ Poll::Ready(())
+ } else {
+ yielded = true;
+ cx.waker().wake_by_ref();
+ Poll::Pending
+ }
+ }).await;
}
/// Return the current instant, factoring in frozen time.
pub(crate) fn now() -> Instant {
- if let Some(clock) = context::clock() {
+ if let Some(clock) = clock() {
clock.now()
} else {
Instant::from_std(std::time::Instant::now())
diff --git a/src/time/delay_queue.rs b/src/time/delay_queue.rs
deleted file mode 100644
index 55ec7cd..0000000
--- a/src/time/delay_queue.rs
+++ /dev/null
@@ -1,887 +0,0 @@
-//! A queue of delayed elements.
-//!
-//! See [`DelayQueue`] for more details.
-//!
-//! [`DelayQueue`]: struct@DelayQueue
-
-use crate::time::wheel::{self, Wheel};
-use crate::time::{delay_until, Delay, Duration, Error, Instant};
-
-use slab::Slab;
-use std::cmp;
-use std::future::Future;
-use std::marker::PhantomData;
-use std::pin::Pin;
-use std::task::{self, Poll};
-
-/// A queue of delayed elements.
-///
-/// Once an element is inserted into the `DelayQueue`, it is yielded once the
-/// specified deadline has been reached.
-///
-/// # Usage
-///
-/// Elements are inserted into `DelayQueue` using the [`insert`] or
-/// [`insert_at`] methods. A deadline is provided with the item and a [`Key`] is
-/// returned. The key is used to remove the entry or to change the deadline at
-/// which it should be yielded back.
-///
-/// Once delays have been configured, the `DelayQueue` is used via its
-/// [`Stream`] implementation. [`poll`] is called. If an entry has reached its
-/// deadline, it is returned. If not, `Poll::Pending` indicating that the
-/// current task will be notified once the deadline has been reached.
-///
-/// # `Stream` implementation
-///
-/// Items are retrieved from the queue via [`Stream::poll`]. If no delays have
-/// expired, no items are returned. In this case, `NotReady` is returned and the
-/// current task is registered to be notified once the next item's delay has
-/// expired.
-///
-/// If no items are in the queue, i.e. `is_empty()` returns `true`, then `poll`
-/// returns `Ready(None)`. This indicates that the stream has reached an end.
-/// However, if a new item is inserted *after*, `poll` will once again start
-/// returning items or `NotReady.
-///
-/// Items are returned ordered by their expirations. Items that are configured
-/// to expire first will be returned first. There are no ordering guarantees
-/// for items configured to expire the same instant. Also note that delays are
-/// rounded to the closest millisecond.
-///
-/// # Implementation
-///
-/// The [`DelayQueue`] is backed by a separate instance of the same timer wheel used internally by
-/// Tokio's standalone timer utilities such as [`delay_for`]. Because of this, it offers the same
-/// performance and scalability benefits.
-///
-/// State associated with each entry is stored in a [`slab`]. This amortizes the cost of allocation,
-/// and allows reuse of the memory allocated for expired entires.
-///
-/// Capacity can be checked using [`capacity`] and allocated preemptively by using
-/// the [`reserve`] method.
-///
-/// # Usage
-///
-/// Using `DelayQueue` to manage cache entries.
-///
-/// ```rust,no_run
-/// use tokio::time::{delay_queue, DelayQueue, Error};
-///
-/// use futures::ready;
-/// use std::collections::HashMap;
-/// use std::task::{Context, Poll};
-/// use std::time::Duration;
-/// # type CacheKey = String;
-/// # type Value = String;
-///
-/// struct Cache {
-/// entries: HashMap<CacheKey, (Value, delay_queue::Key)>,
-/// expirations: DelayQueue<CacheKey>,
-/// }
-///
-/// const TTL_SECS: u64 = 30;
-///
-/// impl Cache {
-/// fn insert(&mut self, key: CacheKey, value: Value) {
-/// let delay = self.expirations
-/// .insert(key.clone(), Duration::from_secs(TTL_SECS));
-///
-/// self.entries.insert(key, (value, delay));
-/// }
-///
-/// fn get(&self, key: &CacheKey) -> Option<&Value> {
-/// self.entries.get(key)
-/// .map(|&(ref v, _)| v)
-/// }
-///
-/// fn remove(&mut self, key: &CacheKey) {
-/// if let Some((_, cache_key)) = self.entries.remove(key) {
-/// self.expirations.remove(&cache_key);
-/// }
-/// }
-///
-/// fn poll_purge(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
-/// while let Some(res) = ready!(self.expirations.poll_expired(cx)) {
-/// let entry = res?;
-/// self.entries.remove(entry.get_ref());
-/// }
-///
-/// Poll::Ready(Ok(()))
-/// }
-/// }
-/// ```
-///
-/// [`insert`]: method@Self::insert
-/// [`insert_at`]: method@Self::insert_at
-/// [`Key`]: struct@Key
-/// [`Stream`]: https://docs.rs/futures/0.1/futures/stream/trait.Stream.html
-/// [`poll`]: method@Self::poll
-/// [`Stream::poll`]: method@Self::poll
-/// [`DelayQueue`]: struct@DelayQueue
-/// [`delay_for`]: fn@super::delay_for
-/// [`slab`]: slab
-/// [`capacity`]: method@Self::capacity
-/// [`reserve`]: method@Self::reserve
-#[derive(Debug)]
-pub struct DelayQueue<T> {
- /// Stores data associated with entries
- slab: Slab<Data<T>>,
-
- /// Lookup structure tracking all delays in the queue
- wheel: Wheel<Stack<T>>,
-
- /// Delays that were inserted when already expired. These cannot be stored
- /// in the wheel
- expired: Stack<T>,
-
- /// Delay expiring when the *first* item in the queue expires
- delay: Option<Delay>,
-
- /// Wheel polling state
- poll: wheel::Poll,
-
- /// Instant at which the timer starts
- start: Instant,
-}
-
-/// An entry in `DelayQueue` that has expired and removed.
-///
-/// Values are returned by [`DelayQueue::poll`].
-///
-/// [`DelayQueue::poll`]: method@DelayQueue::poll
-#[derive(Debug)]
-pub struct Expired<T> {
- /// The data stored in the queue
- data: T,
-
- /// The expiration time
- deadline: Instant,
-
- /// The key associated with the entry
- key: Key,
-}
-
-/// Token to a value stored in a `DelayQueue`.
-///
-/// Instances of `Key` are returned by [`DelayQueue::insert`]. See [`DelayQueue`]
-/// documentation for more details.
-///
-/// [`DelayQueue`]: struct@DelayQueue
-/// [`DelayQueue::insert`]: method@DelayQueue::insert
-#[derive(Debug, Clone)]
-pub struct Key {
- index: usize,
-}
-
-#[derive(Debug)]
-struct Stack<T> {
- /// Head of the stack
- head: Option<usize>,
- _p: PhantomData<fn() -> T>,
-}
-
-#[derive(Debug)]
-struct Data<T> {
- /// The data being stored in the queue and will be returned at the requested
- /// instant.
- inner: T,
-
- /// The instant at which the item is returned.
- when: u64,
-
- /// Set to true when stored in the `expired` queue
- expired: bool,
-
- /// Next entry in the stack
- next: Option<usize>,
-
- /// Previous entry in the stack
- prev: Option<usize>,
-}
-
-/// Maximum number of entries the queue can handle
-const MAX_ENTRIES: usize = (1 << 30) - 1;
-
-impl<T> DelayQueue<T> {
- /// Creates a new, empty, `DelayQueue`
- ///
- /// The queue will not allocate storage until items are inserted into it.
- ///
- /// # Examples
- ///
- /// ```rust
- /// # use tokio::time::DelayQueue;
- /// let delay_queue: DelayQueue<u32> = DelayQueue::new();
- /// ```
- pub fn new() -> DelayQueue<T> {
- DelayQueue::with_capacity(0)
- }
-
- /// Creates a new, empty, `DelayQueue` with the specified capacity.
- ///
- /// The queue will be able to hold at least `capacity` elements without
- /// reallocating. If `capacity` is 0, the queue will not allocate for
- /// storage.
- ///
- /// # Examples
- ///
- /// ```rust
- /// # use tokio::time::DelayQueue;
- /// # use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::with_capacity(10);
- ///
- /// // These insertions are done without further allocation
- /// for i in 0..10 {
- /// delay_queue.insert(i, Duration::from_secs(i));
- /// }
- ///
- /// // This will make the queue allocate additional storage
- /// delay_queue.insert(11, Duration::from_secs(11));
- /// # }
- /// ```
- pub fn with_capacity(capacity: usize) -> DelayQueue<T> {
- DelayQueue {
- wheel: Wheel::new(),
- slab: Slab::with_capacity(capacity),
- expired: Stack::default(),
- delay: None,
- poll: wheel::Poll::new(0),
- start: Instant::now(),
- }
- }
-
- /// Inserts `value` into the queue set to expire at a specific instant in
- /// time.
- ///
- /// This function is identical to `insert`, but takes an `Instant` instead
- /// of a `Duration`.
- ///
- /// `value` is stored in the queue until `when` is reached. At which point,
- /// `value` will be returned from [`poll`]. If `when` has already been
- /// reached, then `value` is immediately made available to poll.
- ///
- /// The return value represents the insertion and is used at an argument to
- /// [`remove`] and [`reset`]. Note that [`Key`] is token and is reused once
- /// `value` is removed from the queue either by calling [`poll`] after
- /// `when` is reached or by calling [`remove`]. At this point, the caller
- /// must take care to not use the returned [`Key`] again as it may reference
- /// a different item in the queue.
- ///
- /// See [type] level documentation for more details.
- ///
- /// # Panics
- ///
- /// This function panics if `when` is too far in the future.
- ///
- /// # Examples
- ///
- /// Basic usage
- ///
- /// ```rust
- /// use tokio::time::{DelayQueue, Duration, Instant};
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- /// let key = delay_queue.insert_at(
- /// "foo", Instant::now() + Duration::from_secs(5));
- ///
- /// // Remove the entry
- /// let item = delay_queue.remove(&key);
- /// assert_eq!(*item.get_ref(), "foo");
- /// # }
- /// ```
- ///
- /// [`poll`]: method@Self::poll
- /// [`remove`]: method@Self::remove
- /// [`reset`]: method@Self::reset
- /// [`Key`]: struct@Key
- /// [type]: #
- pub fn insert_at(&mut self, value: T, when: Instant) -> Key {
- assert!(self.slab.len() < MAX_ENTRIES, "max entries exceeded");
-
- // Normalize the deadline. Values cannot be set to expire in the past.
- let when = self.normalize_deadline(when);
-
- // Insert the value in the store
- let key = self.slab.insert(Data {
- inner: value,
- when,
- expired: false,
- next: None,
- prev: None,
- });
-
- self.insert_idx(when, key);
-
- // Set a new delay if the current's deadline is later than the one of the new item
- let should_set_delay = if let Some(ref delay) = self.delay {
- let current_exp = self.normalize_deadline(delay.deadline());
- current_exp > when
- } else {
- true
- };
-
- if should_set_delay {
- let delay_time = self.start + Duration::from_millis(when);
- if let Some(ref mut delay) = &mut self.delay {
- delay.reset(delay_time);
- } else {
- self.delay = Some(delay_until(delay_time));
- }
- }
-
- Key::new(key)
- }
-
- /// Attempts to pull out the next value of the delay queue, registering the
- /// current task for wakeup if the value is not yet available, and returning
- /// None if the queue is exhausted.
- pub fn poll_expired(
- &mut self,
- cx: &mut task::Context<'_>,
- ) -> Poll<Option<Result<Expired<T>, Error>>> {
- let item = ready!(self.poll_idx(cx));
- Poll::Ready(item.map(|result| {
- result.map(|idx| {
- let data = self.slab.remove(idx);
- debug_assert!(data.next.is_none());
- debug_assert!(data.prev.is_none());
-
- Expired {
- key: Key::new(idx),
- data: data.inner,
- deadline: self.start + Duration::from_millis(data.when),
- }
- })
- }))
- }
-
- /// Inserts `value` into the queue set to expire after the requested duration
- /// elapses.
- ///
- /// This function is identical to `insert_at`, but takes a `Duration`
- /// instead of an `Instant`.
- ///
- /// `value` is stored in the queue until `when` is reached. At which point,
- /// `value` will be returned from [`poll`]. If `when` has already been
- /// reached, then `value` is immediately made available to poll.
- ///
- /// The return value represents the insertion and is used at an argument to
- /// [`remove`] and [`reset`]. Note that [`Key`] is token and is reused once
- /// `value` is removed from the queue either by calling [`poll`] after
- /// `when` is reached or by calling [`remove`]. At this point, the caller
- /// must take care to not use the returned [`Key`] again as it may reference
- /// a different item in the queue.
- ///
- /// See [type] level documentation for more details.
- ///
- /// # Panics
- ///
- /// This function panics if `timeout` is greater than the maximum supported
- /// duration.
- ///
- /// # Examples
- ///
- /// Basic usage
- ///
- /// ```rust
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- /// let key = delay_queue.insert("foo", Duration::from_secs(5));
- ///
- /// // Remove the entry
- /// let item = delay_queue.remove(&key);
- /// assert_eq!(*item.get_ref(), "foo");
- /// # }
- /// ```
- ///
- /// [`poll`]: method@Self::poll
- /// [`remove`]: method@Self::remove
- /// [`reset`]: method@Self::reset
- /// [`Key`]: struct@Key
- /// [type]: #
- pub fn insert(&mut self, value: T, timeout: Duration) -> Key {
- self.insert_at(value, Instant::now() + timeout)
- }
-
- fn insert_idx(&mut self, when: u64, key: usize) {
- use self::wheel::{InsertError, Stack};
-
- // Register the deadline with the timer wheel
- match self.wheel.insert(when, key, &mut self.slab) {
- Ok(_) => {}
- Err((_, InsertError::Elapsed)) => {
- self.slab[key].expired = true;
- // The delay is already expired, store it in the expired queue
- self.expired.push(key, &mut self.slab);
- }
- Err((_, err)) => panic!("invalid deadline; err={:?}", err),
- }
- }
-
- /// Removes the item associated with `key` from the queue.
- ///
- /// There must be an item associated with `key`. The function returns the
- /// removed item as well as the `Instant` at which it will the delay will
- /// have expired.
- ///
- /// # Panics
- ///
- /// The function panics if `key` is not contained by the queue.
- ///
- /// # Examples
- ///
- /// Basic usage
- ///
- /// ```rust
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- /// let key = delay_queue.insert("foo", Duration::from_secs(5));
- ///
- /// // Remove the entry
- /// let item = delay_queue.remove(&key);
- /// assert_eq!(*item.get_ref(), "foo");
- /// # }
- /// ```
- pub fn remove(&mut self, key: &Key) -> Expired<T> {
- use crate::time::wheel::Stack;
-
- // Special case the `expired` queue
- if self.slab[key.index].expired {
- self.expired.remove(&key.index, &mut self.slab);
- } else {
- self.wheel.remove(&key.index, &mut self.slab);
- }
-
- let data = self.slab.remove(key.index);
-
- Expired {
- key: Key::new(key.index),
- data: data.inner,
- deadline: self.start + Duration::from_millis(data.when),
- }
- }
-
- /// Sets the delay of the item associated with `key` to expire at `when`.
- ///
- /// This function is identical to `reset` but takes an `Instant` instead of
- /// a `Duration`.
- ///
- /// The item remains in the queue but the delay is set to expire at `when`.
- /// If `when` is in the past, then the item is immediately made available to
- /// the caller.
- ///
- /// # Panics
- ///
- /// This function panics if `when` is too far in the future or if `key` is
- /// not contained by the queue.
- ///
- /// # Examples
- ///
- /// Basic usage
- ///
- /// ```rust
- /// use tokio::time::{DelayQueue, Duration, Instant};
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- /// let key = delay_queue.insert("foo", Duration::from_secs(5));
- ///
- /// // "foo" is scheduled to be returned in 5 seconds
- ///
- /// delay_queue.reset_at(&key, Instant::now() + Duration::from_secs(10));
- ///
- /// // "foo"is now scheduled to be returned in 10 seconds
- /// # }
- /// ```
- pub fn reset_at(&mut self, key: &Key, when: Instant) {
- self.wheel.remove(&key.index, &mut self.slab);
-
- // Normalize the deadline. Values cannot be set to expire in the past.
- let when = self.normalize_deadline(when);
-
- self.slab[key.index].when = when;
- self.insert_idx(when, key.index);
-
- let next_deadline = self.next_deadline();
- if let (Some(ref mut delay), Some(deadline)) = (&mut self.delay, next_deadline) {
- delay.reset(deadline);
- }
- }
-
- /// Returns the next time poll as determined by the wheel
- fn next_deadline(&mut self) -> Option<Instant> {
- self.wheel
- .poll_at()
- .map(|poll_at| self.start + Duration::from_millis(poll_at))
- }
-
- /// Sets the delay of the item associated with `key` to expire after
- /// `timeout`.
- ///
- /// This function is identical to `reset_at` but takes a `Duration` instead
- /// of an `Instant`.
- ///
- /// The item remains in the queue but the delay is set to expire after
- /// `timeout`. If `timeout` is zero, then the item is immediately made
- /// available to the caller.
- ///
- /// # Panics
- ///
- /// This function panics if `timeout` is greater than the maximum supported
- /// duration or if `key` is not contained by the queue.
- ///
- /// # Examples
- ///
- /// Basic usage
- ///
- /// ```rust
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- /// let key = delay_queue.insert("foo", Duration::from_secs(5));
- ///
- /// // "foo" is scheduled to be returned in 5 seconds
- ///
- /// delay_queue.reset(&key, Duration::from_secs(10));
- ///
- /// // "foo"is now scheduled to be returned in 10 seconds
- /// # }
- /// ```
- pub fn reset(&mut self, key: &Key, timeout: Duration) {
- self.reset_at(key, Instant::now() + timeout);
- }
-
- /// Clears the queue, removing all items.
- ///
- /// After calling `clear`, [`poll`] will return `Ok(Ready(None))`.
- ///
- /// Note that this method has no effect on the allocated capacity.
- ///
- /// [`poll`]: method@Self::poll
- ///
- /// # Examples
- ///
- /// ```rust
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- ///
- /// delay_queue.insert("foo", Duration::from_secs(5));
- ///
- /// assert!(!delay_queue.is_empty());
- ///
- /// delay_queue.clear();
- ///
- /// assert!(delay_queue.is_empty());
- /// # }
- /// ```
- pub fn clear(&mut self) {
- self.slab.clear();
- self.expired = Stack::default();
- self.wheel = Wheel::new();
- self.delay = None;
- }
-
- /// Returns the number of elements the queue can hold without reallocating.
- ///
- /// # Examples
- ///
- /// ```rust
- /// use tokio::time::DelayQueue;
- ///
- /// let delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
- /// assert_eq!(delay_queue.capacity(), 10);
- /// ```
- pub fn capacity(&self) -> usize {
- self.slab.capacity()
- }
-
- /// Returns the number of elements currently in the queue.
- ///
- /// # Examples
- ///
- /// ```rust
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue: DelayQueue<i32> = DelayQueue::with_capacity(10);
- /// assert_eq!(delay_queue.len(), 0);
- /// delay_queue.insert(3, Duration::from_secs(5));
- /// assert_eq!(delay_queue.len(), 1);
- /// # }
- /// ```
- pub fn len(&self) -> usize {
- self.slab.len()
- }
-
- /// Reserves capacity for at least `additional` more items to be queued
- /// without allocating.
- ///
- /// `reserve` does nothing if the queue already has sufficient capacity for
- /// `additional` more values. If more capacity is required, a new segment of
- /// memory will be allocated and all existing values will be copied into it.
- /// As such, if the queue is already very large, a call to `reserve` can end
- /// up being expensive.
- ///
- /// The queue may reserve more than `additional` extra space in order to
- /// avoid frequent reallocations.
- ///
- /// # Panics
- ///
- /// Panics if the new capacity exceeds the maximum number of entries the
- /// queue can contain.
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- ///
- /// delay_queue.insert("hello", Duration::from_secs(10));
- /// delay_queue.reserve(10);
- ///
- /// assert!(delay_queue.capacity() >= 11);
- /// # }
- /// ```
- pub fn reserve(&mut self, additional: usize) {
- self.slab.reserve(additional);
- }
-
- /// Returns `true` if there are no items in the queue.
- ///
- /// Note that this function returns `false` even if all items have not yet
- /// expired and a call to `poll` will return `NotReady`.
- ///
- /// # Examples
- ///
- /// ```
- /// use tokio::time::DelayQueue;
- /// use std::time::Duration;
- ///
- /// # #[tokio::main]
- /// # async fn main() {
- /// let mut delay_queue = DelayQueue::new();
- /// assert!(delay_queue.is_empty());
- ///
- /// delay_queue.insert("hello", Duration::from_secs(5));
- /// assert!(!delay_queue.is_empty());
- /// # }
- /// ```
- pub fn is_empty(&self) -> bool {
- self.slab.is_empty()
- }
-
- /// Polls the queue, returning the index of the next slot in the slab that
- /// should be returned.
- ///
- /// A slot should be returned when the associated deadline has been reached.
- fn poll_idx(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Result<usize, Error>>> {
- use self::wheel::Stack;
-
- let expired = self.expired.pop(&mut self.slab);
-
- if expired.is_some() {
- return Poll::Ready(expired.map(Ok));
- }
-
- loop {
- if let Some(ref mut delay) = self.delay {
- if !delay.is_elapsed() {
- ready!(Pin::new(&mut *delay).poll(cx));
- }
-
- let now = crate::time::ms(delay.deadline() - self.start, crate::time::Round::Down);
-
- self.poll = wheel::Poll::new(now);
- }
-
- // We poll the wheel to get the next value out before finding the next deadline.
- let wheel_idx = self.wheel.poll(&mut self.poll, &mut self.slab);
-
- self.delay = self.next_deadline().map(delay_until);
-
- if let Some(idx) = wheel_idx {
- return Poll::Ready(Some(Ok(idx)));
- }
-
- if self.delay.is_none() {
- return Poll::Ready(None);
- }
- }
- }
-
- fn normalize_deadline(&self, when: Instant) -> u64 {
- let when = if when < self.start {
- 0
- } else {
- crate::time::ms(when - self.start, crate::time::Round::Up)
- };
-
- cmp::max(when, self.wheel.elapsed())
- }
-}
-
-// We never put `T` in a `Pin`...
-impl<T> Unpin for DelayQueue<T> {}
-
-impl<T> Default for DelayQueue<T> {
- fn default() -> DelayQueue<T> {
- DelayQueue::new()
- }
-}
-
-#[cfg(feature = "stream")]
-impl<T> futures_core::Stream for DelayQueue<T> {
- // DelayQueue seems much more specific, where a user may care that it
- // has reached capacity, so return those errors instead of panicking.
- type Item = Result<Expired<T>, Error>;
-
- fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
- DelayQueue::poll_expired(self.get_mut(), cx)
- }
-}
-
-impl<T> wheel::Stack for Stack<T> {
- type Owned = usize;
- type Borrowed = usize;
- type Store = Slab<Data<T>>;
-
- fn is_empty(&self) -> bool {
- self.head.is_none()
- }
-
- fn push(&mut self, item: Self::Owned, store: &mut Self::Store) {
- // Ensure the entry is not already in a stack.
- debug_assert!(store[item].next.is_none());
- debug_assert!(store[item].prev.is_none());
-
- // Remove the old head entry
- let old = self.head.take();
-
- if let Some(idx) = old {
- store[idx].prev = Some(item);
- }
-
- store[item].next = old;
- self.head = Some(item)
- }
-
- fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned> {
- if let Some(idx) = self.head {
- self.head = store[idx].next;
-
- if let Some(idx) = self.head {
- store[idx].prev = None;
- }
-
- store[idx].next = None;
- debug_assert!(store[idx].prev.is_none());
-
- Some(idx)
- } else {
- None
- }
- }
-
- fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store) {
- assert!(store.contains(*item));
-
- // Ensure that the entry is in fact contained by the stack
- debug_assert!({
- // This walks the full linked list even if an entry is found.
- let mut next = self.head;
- let mut contains = false;
-
- while let Some(idx) = next {
- if idx == *item {
- debug_assert!(!contains);
- contains = true;
- }
-
- next = store[idx].next;
- }
-
- contains
- });
-
- if let Some(next) = store[*item].next {
- store[next].prev = store[*item].prev;
- }
-
- if let Some(prev) = store[*item].prev {
- store[prev].next = store[*item].next;
- } else {
- self.head = store[*item].next;
- }
-
- store[*item].next = None;
- store[*item].prev = None;
- }
-
- fn when(item: &Self::Borrowed, store: &Self::Store) -> u64 {
- store[*item].when
- }
-}
-
-impl<T> Default for Stack<T> {
- fn default() -> Stack<T> {
- Stack {
- head: None,
- _p: PhantomData,
- }
- }
-}
-
-impl Key {
- pub(crate) fn new(index: usize) -> Key {
- Key { index }
- }
-}
-
-impl<T> Expired<T> {
- /// Returns a reference to the inner value.
- pub fn get_ref(&self) -> &T {
- &self.data
- }
-
- /// Returns a mutable reference to the inner value.
- pub fn get_mut(&mut self) -> &mut T {
- &mut self.data
- }
-
- /// Consumes `self` and returns the inner value.
- pub fn into_inner(self) -> T {
- self.data
- }
-
- /// Returns the deadline that the expiration was set to.
- pub fn deadline(&self) -> Instant {
- self.deadline
- }
-}
diff --git a/src/time/driver/atomic_stack.rs b/src/time/driver/atomic_stack.rs
index 7e5a83f..5dcc472 100644
--- a/src/time/driver/atomic_stack.rs
+++ b/src/time/driver/atomic_stack.rs
@@ -1,5 +1,5 @@
use crate::time::driver::Entry;
-use crate::time::Error;
+use crate::time::error::Error;
use std::ptr;
use std::sync::atomic::AtomicPtr;
@@ -95,7 +95,7 @@ impl Iterator for AtomicStackEntries {
type Item = Arc<Entry>;
fn next(&mut self) -> Option<Self::Item> {
- if self.ptr.is_null() {
+ if self.ptr.is_null() || self.ptr == SHUTDOWN {
return None;
}
@@ -118,7 +118,7 @@ impl Drop for AtomicStackEntries {
fn drop(&mut self) {
for entry in self {
// Flag the entry as errored
- entry.error();
+ entry.error(Error::shutdown());
}
}
}
diff --git a/src/time/driver/entry.rs b/src/time/driver/entry.rs
index b375ee9..b40cae7 100644
--- a/src/time/driver/entry.rs
+++ b/src/time/driver/entry.rs
@@ -1,17 +1,17 @@
use crate::loom::sync::atomic::AtomicU64;
use crate::sync::AtomicWaker;
use crate::time::driver::{Handle, Inner};
-use crate::time::{Duration, Error, Instant};
+use crate::time::{error::Error, Duration, Instant};
use std::cell::UnsafeCell;
use std::ptr;
-use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::SeqCst;
+use std::sync::atomic::{AtomicBool, AtomicU8};
use std::sync::{Arc, Weak};
use std::task::{self, Poll};
use std::u64;
-/// Internal state shared between a `Delay` instance and the timer.
+/// Internal state shared between a `Sleep` instance and the timer.
///
/// This struct is used as a node in two intrusive data structures:
///
@@ -28,7 +28,7 @@ pub(crate) struct Entry {
time: CachePadded<UnsafeCell<Time>>,
/// Timer internals. Using a weak pointer allows the timer to shutdown
- /// without all `Delay` instances having completed.
+ /// without all `Sleep` instances having completed.
///
/// When empty, it means that the entry has not yet been linked with a
/// timer instance.
@@ -45,6 +45,11 @@ pub(crate) struct Entry {
/// instant, this value is changed.
state: AtomicU64,
+ /// Stores the actual error. If `state` indicates that an error occurred,
+ /// this is guaranteed to be a non-zero value representing the first error
+ /// that occurred. Otherwise its value is undefined.
+ error: AtomicU8,
+
/// Task to notify once the deadline is reached.
waker: AtomicWaker,
@@ -64,8 +69,8 @@ pub(crate) struct Entry {
/// When the entry expires, relative to the `start` of the timer
/// (Inner::start). This is only used by the timer.
///
- /// A `Delay` instance can be reset to a different deadline by the thread
- /// that owns the `Delay` instance. In this case, the timer thread will not
+ /// A `Sleep` instance can be reset to a different deadline by the thread
+ /// that owns the `Sleep` instance. In this case, the timer thread will not
/// immediately know that this has happened. The timer thread must know the
/// last deadline that it saw as it uses this value to locate the entry in
/// its wheel.
@@ -78,7 +83,7 @@ pub(crate) struct Entry {
/// Next entry in the State's linked list.
///
/// This is only accessed by the timer
- pub(super) next_stack: UnsafeCell<Option<Arc<Entry>>>,
+ pub(crate) next_stack: UnsafeCell<Option<Arc<Entry>>>,
/// Previous entry in the State's linked list.
///
@@ -86,10 +91,10 @@ pub(crate) struct Entry {
/// entry.
///
/// This is a weak reference.
- pub(super) prev_stack: UnsafeCell<*const Entry>,
+ pub(crate) prev_stack: UnsafeCell<*const Entry>,
}
-/// Stores the info for `Delay`.
+/// Stores the info for `Sleep`.
#[derive(Debug)]
pub(crate) struct Time {
pub(crate) deadline: Instant,
@@ -107,11 +112,12 @@ const ERROR: u64 = u64::MAX;
impl Entry {
pub(crate) fn new(handle: &Handle, deadline: Instant, duration: Duration) -> Arc<Entry> {
let inner = handle.inner().unwrap();
- let entry: Entry;
- // Increment the number of active timeouts
- if inner.increment().is_err() {
- entry = Entry::new2(deadline, duration, Weak::new(), ERROR)
+ // Attempt to increment the number of active timeouts
+ let entry = if let Err(err) = inner.increment() {
+ let entry = Entry::new2(deadline, duration, Weak::new(), ERROR);
+ entry.error(err);
+ entry
} else {
let when = inner.normalize_deadline(deadline);
let state = if when <= inner.elapsed() {
@@ -119,12 +125,12 @@ impl Entry {
} else {
when
};
- entry = Entry::new2(deadline, duration, Arc::downgrade(&inner), state);
- }
+ Entry::new2(deadline, duration, Arc::downgrade(&inner), state)
+ };
let entry = Arc::new(entry);
- if inner.queue(&entry).is_err() {
- entry.error();
+ if let Err(err) = inner.queue(&entry) {
+ entry.error(err);
}
entry
@@ -141,6 +147,10 @@ impl Entry {
&mut *self.time.0.get()
}
+ pub(crate) fn when(&self) -> u64 {
+ self.when_internal().expect("invalid internal state")
+ }
+
/// The current entry state as known by the timer. This is not the value of
/// `state`, but lets the timer know how to converge its state to `state`.
pub(crate) fn when_internal(&self) -> Option<u64> {
@@ -190,7 +200,12 @@ impl Entry {
self.waker.wake();
}
- pub(crate) fn error(&self) {
+ pub(crate) fn error(&self, error: Error) {
+ // Record the precise nature of the error, if there isn't already an
+ // error present. If we don't actually transition to the error state
+ // below, that's fine, as the error details we set here will be ignored.
+ self.error.compare_and_swap(0, error.as_u8(), SeqCst);
+
// Only transition to the error state if not currently elapsed
let mut curr = self.state.load(SeqCst);
@@ -235,7 +250,7 @@ impl Entry {
if is_elapsed(curr) {
return Poll::Ready(if curr == ERROR {
- Err(Error::shutdown())
+ Err(Error::from_u8(self.error.load(SeqCst)))
} else {
Ok(())
});
@@ -247,7 +262,7 @@ impl Entry {
if is_elapsed(curr) {
return Poll::Ready(if curr == ERROR {
- Err(Error::shutdown())
+ Err(Error::from_u8(self.error.load(SeqCst)))
} else {
Ok(())
});
@@ -310,6 +325,7 @@ impl Entry {
waker: AtomicWaker::new(),
state: AtomicU64::new(state),
queued: AtomicBool::new(false),
+ error: AtomicU8::new(0),
next_atomic: UnsafeCell::new(ptr::null_mut()),
when: UnsafeCell::new(None),
next_stack: UnsafeCell::new(None),
diff --git a/src/time/driver/handle.rs b/src/time/driver/handle.rs
index 38b1761..54b8a8b 100644
--- a/src/time/driver/handle.rs
+++ b/src/time/driver/handle.rs
@@ -1,4 +1,3 @@
-use crate::runtime::context;
use crate::time::driver::Inner;
use std::fmt;
use std::sync::{Arc, Weak};
@@ -15,22 +14,62 @@ impl Handle {
Handle { inner }
}
- /// Tries to get a handle to the current timer.
- ///
- /// # Panics
- ///
- /// This function panics if there is no current timer set.
- pub(crate) fn current() -> Self {
- context::time_handle()
- .expect("there is no timer running, must be called from the context of Tokio runtime")
- }
-
/// Tries to return a strong ref to the inner
pub(crate) fn inner(&self) -> Option<Arc<Inner>> {
self.inner.upgrade()
}
}
+cfg_rt! {
+ impl Handle {
+ /// Tries to get a handle to the current timer.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current timer set.
+ ///
+ /// It can be triggered when `Builder::enable_time()` or
+ /// `Builder::enable_all()` are not included in the builder.
+ ///
+ /// It can also panic whenever a timer is created outside of a Tokio
+ /// runtime. That is why `rt.block_on(delay_for(...))` will panic,
+ /// since the function is executed outside of the runtime.
+ /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't
+ /// panic. And this is because wrapping the function on an async makes it
+ /// lazy, and so gets executed inside the runtime successfuly without
+ /// panicking.
+ pub(crate) fn current() -> Self {
+ crate::runtime::context::time_handle()
+ .expect("there is no timer running, must be called from the context of Tokio runtime")
+ }
+ }
+}
+
+cfg_not_rt! {
+ impl Handle {
+ /// Tries to get a handle to the current timer.
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current timer set.
+ ///
+ /// It can be triggered when `Builder::enable_time()` or
+ /// `Builder::enable_all()` are not included in the builder.
+ ///
+ /// It can also panic whenever a timer is created outside of a Tokio
+ /// runtime. That is why `rt.block_on(delay_for(...))` will panic,
+ /// since the function is executed outside of the runtime.
+ /// Whereas `rt.block_on(async {delay_for(...).await})` doesn't
+ /// panic. And this is because wrapping the function on an async makes it
+ /// lazy, and so gets executed inside the runtime successfuly without
+ /// panicking.
+ pub(crate) fn current() -> Self {
+ panic!("there is no timer running, must be called from the context of Tokio runtime or \
+ `rt` is not enabled")
+ }
+ }
+}
+
impl fmt::Debug for Handle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Handle")
diff --git a/src/time/driver/mod.rs b/src/time/driver/mod.rs
index 554042f..8532c55 100644
--- a/src/time/driver/mod.rs
+++ b/src/time/driver/mod.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
//! Time driver
mod atomic_stack;
@@ -9,15 +11,9 @@ pub(super) use self::entry::Entry;
mod handle;
pub(crate) use self::handle::Handle;
-mod registration;
-pub(crate) use self::registration::Registration;
-
-mod stack;
-use self::stack::Stack;
-
use crate::loom::sync::atomic::{AtomicU64, AtomicUsize};
use crate::park::{Park, Unpark};
-use crate::time::{wheel, Error};
+use crate::time::{error::Error, wheel};
use crate::time::{Clock, Duration, Instant};
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
@@ -26,12 +22,12 @@ use std::sync::Arc;
use std::usize;
use std::{cmp, fmt};
-/// Time implementation that drives [`Delay`][delay], [`Interval`][interval], and [`Timeout`][timeout].
+/// Time implementation that drives [`Sleep`][sleep], [`Interval`][interval], and [`Timeout`][timeout].
///
/// A `Driver` instance tracks the state necessary for managing time and
-/// notifying the [`Delay`][delay] instances once their deadlines are reached.
+/// notifying the [`Sleep`][sleep] instances once their deadlines are reached.
///
-/// It is expected that a single instance manages many individual [`Delay`][delay]
+/// It is expected that a single instance manages many individual [`Sleep`][sleep]
/// instances. The `Driver` implementation is thread-safe and, as such, is able
/// to handle callers from across threads.
///
@@ -42,9 +38,9 @@ use std::{cmp, fmt};
/// The driver has a resolution of one millisecond. Any unit of time that falls
/// between milliseconds are rounded up to the next millisecond.
///
-/// When an instance is dropped, any outstanding [`Delay`][delay] instance that has not
+/// When an instance is dropped, any outstanding [`Sleep`][sleep] instance that has not
/// elapsed will be notified with an error. At this point, calling `poll` on the
-/// [`Delay`][delay] instance will result in panic.
+/// [`Sleep`][sleep] instance will result in panic.
///
/// # Implementation
///
@@ -71,29 +67,32 @@ use std::{cmp, fmt};
/// * Level 5: 64 x ~12 day slots.
///
/// When the timer processes entries at level zero, it will notify all the
-/// `Delay` instances as their deadlines have been reached. For all higher
+/// `Sleep` instances as their deadlines have been reached. For all higher
/// levels, all entries will be redistributed across the wheel at the next level
-/// down. Eventually, as time progresses, entries will [`Delay`][delay] instances will
+/// down. Eventually, as time progresses, entries with [`Sleep`][sleep] instances will
/// either be canceled (dropped) or their associated entries will reach level
/// zero and be notified.
///
/// [paper]: http://www.cs.columbia.edu/~nahum/w6998/papers/ton97-timing-wheels.pdf
-/// [delay]: crate::time::Delay
+/// [sleep]: crate::time::Sleep
/// [timeout]: crate::time::Timeout
/// [interval]: crate::time::Interval
#[derive(Debug)]
-pub(crate) struct Driver<T> {
+pub(crate) struct Driver<T: Park> {
/// Shared state
inner: Arc<Inner>,
/// Timer wheel
- wheel: wheel::Wheel<Stack>,
+ wheel: wheel::Wheel,
/// Thread parker. The `Driver` park implementation delegates to this.
park: T,
/// Source of "now" instances
clock: Clock,
+
+ /// True if the driver is being shutdown
+ is_shutdown: bool,
}
/// Timer state shared between `Driver`, `Handle`, and `Registration`.
@@ -135,12 +134,13 @@ where
wheel: wheel::Wheel::new(),
park,
clock,
+ is_shutdown: false,
}
}
/// Returns a handle to the timer.
///
- /// The `Handle` is how `Delay` instances are created. The `Delay` instances
+ /// The `Handle` is how `Sleep` instances are created. The `Sleep` instances
/// can either be created directly or the `Handle` instance can be passed to
/// `with_default`, setting the timer as the default timer for the execution
/// context.
@@ -159,9 +159,8 @@ where
self.clock.now() - self.inner.start,
crate::time::Round::Down,
);
- let mut poll = wheel::Poll::new(now);
- while let Some(entry) = self.wheel.poll(&mut poll, &mut ()) {
+ while let Some(entry) = self.wheel.poll(now) {
let when = entry.when_internal().expect("invalid internal entry state");
// Fire the entry
@@ -189,7 +188,7 @@ where
self.clear_entry(&entry);
}
(None, Some(when)) => {
- // Queue the entry
+ // Add the entry to the timer wheel
self.add_entry(entry, when);
}
(Some(_), Some(next)) => {
@@ -201,19 +200,17 @@ where
}
fn clear_entry(&mut self, entry: &Arc<Entry>) {
- self.wheel.remove(entry, &mut ());
+ self.wheel.remove(entry);
entry.set_when_internal(None);
}
/// Fires the entry if it needs to, otherwise queue it to be processed later.
- ///
- /// Returns `None` if the entry was fired.
fn add_entry(&mut self, entry: Arc<Entry>, when: u64) {
- use crate::time::wheel::InsertError;
+ use crate::time::error::InsertError;
entry.set_when_internal(Some(when));
- match self.wheel.insert(when, entry, &mut ()) {
+ match self.wheel.insert(when, entry) {
Ok(_) => {}
Err((entry, InsertError::Elapsed)) => {
// The entry's deadline has elapsed, so fire it and update the
@@ -225,7 +222,7 @@ where
// The entry's deadline is invalid, so error it and update the
// internal state accordingly.
entry.set_when_internal(None);
- entry.error();
+ entry.error(Error::invalid());
}
}
}
@@ -303,10 +300,12 @@ where
Ok(())
}
-}
-impl<T> Drop for Driver<T> {
- fn drop(&mut self) {
+ fn shutdown(&mut self) {
+ if self.is_shutdown {
+ return;
+ }
+
use std::u64;
// Shutdown the stack of entries to process, preventing any new entries
@@ -314,11 +313,24 @@ impl<T> Drop for Driver<T> {
self.inner.process.shutdown();
// Clear the wheel, using u64::MAX allows us to drain everything
- let mut poll = wheel::Poll::new(u64::MAX);
+ let end_of_time = u64::MAX;
- while let Some(entry) = self.wheel.poll(&mut poll, &mut ()) {
- entry.error();
+ while let Some(entry) = self.wheel.poll(end_of_time) {
+ entry.error(Error::shutdown());
}
+
+ self.park.shutdown();
+
+ self.is_shutdown = true;
+ }
+}
+
+impl<T> Drop for Driver<T>
+where
+ T: Park,
+{
+ fn drop(&mut self) {
+ self.shutdown();
}
}
@@ -368,6 +380,10 @@ impl Inner {
debug_assert!(prev <= MAX_TIMEOUTS);
}
+ /// add the entry to the "process queue". entries are not immediately
+ /// pushed into the timer wheel but are instead pushed into the
+ /// process queue and then moved from the process queue into the timer
+ /// wheel on next `process`
fn queue(&self, entry: &Arc<Entry>) -> Result<(), Error> {
if self.process.push(entry)? {
// The timer is notified so that it can process the timeout
diff --git a/src/time/driver/registration.rs b/src/time/driver/registration.rs
deleted file mode 100644
index 3a0b345..0000000
--- a/src/time/driver/registration.rs
+++ /dev/null
@@ -1,56 +0,0 @@
-use crate::time::driver::{Entry, Handle};
-use crate::time::{Duration, Error, Instant};
-
-use std::sync::Arc;
-use std::task::{self, Poll};
-
-/// Registration with a timer.
-///
-/// The association between a `Delay` instance and a timer is done lazily in
-/// `poll`
-#[derive(Debug)]
-pub(crate) struct Registration {
- entry: Arc<Entry>,
-}
-
-impl Registration {
- pub(crate) fn new(deadline: Instant, duration: Duration) -> Registration {
- let handle = Handle::current();
-
- Registration {
- entry: Entry::new(&handle, deadline, duration),
- }
- }
-
- pub(crate) fn deadline(&self) -> Instant {
- self.entry.time_ref().deadline
- }
-
- pub(crate) fn reset(&mut self, deadline: Instant) {
- unsafe {
- self.entry.time_mut().deadline = deadline;
- }
-
- Entry::reset(&mut self.entry);
- }
-
- pub(crate) fn is_elapsed(&self) -> bool {
- self.entry.is_elapsed()
- }
-
- pub(crate) fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- self.entry.poll_elapsed(cx).map(move |r| {
- coop.made_progress();
- r
- })
- }
-}
-
-impl Drop for Registration {
- fn drop(&mut self) {
- Entry::cancel(&self.entry);
- }
-}
diff --git a/src/time/driver/stack.rs b/src/time/driver/stack.rs
deleted file mode 100644
index 3e2924f..0000000
--- a/src/time/driver/stack.rs
+++ /dev/null
@@ -1,121 +0,0 @@
-use crate::time::driver::Entry;
-use crate::time::wheel;
-
-use std::ptr;
-use std::sync::Arc;
-
-/// A doubly linked stack
-#[derive(Debug)]
-pub(crate) struct Stack {
- head: Option<Arc<Entry>>,
-}
-
-impl Default for Stack {
- fn default() -> Stack {
- Stack { head: None }
- }
-}
-
-impl wheel::Stack for Stack {
- type Owned = Arc<Entry>;
- type Borrowed = Entry;
- type Store = ();
-
- fn is_empty(&self) -> bool {
- self.head.is_none()
- }
-
- fn push(&mut self, entry: Self::Owned, _: &mut Self::Store) {
- // Get a pointer to the entry to for the prev link
- let ptr: *const Entry = &*entry as *const _;
-
- // Remove the old head entry
- let old = self.head.take();
-
- unsafe {
- // Ensure the entry is not already in a stack.
- debug_assert!((*entry.next_stack.get()).is_none());
- debug_assert!((*entry.prev_stack.get()).is_null());
-
- if let Some(ref entry) = old.as_ref() {
- debug_assert!({
- // The head is not already set to the entry
- ptr != &***entry as *const _
- });
-
- // Set the previous link on the old head
- *entry.prev_stack.get() = ptr;
- }
-
- // Set this entry's next pointer
- *entry.next_stack.get() = old;
- }
-
- // Update the head pointer
- self.head = Some(entry);
- }
-
- /// Pops an item from the stack
- fn pop(&mut self, _: &mut ()) -> Option<Arc<Entry>> {
- let entry = self.head.take();
-
- unsafe {
- if let Some(entry) = entry.as_ref() {
- self.head = (*entry.next_stack.get()).take();
-
- if let Some(entry) = self.head.as_ref() {
- *entry.prev_stack.get() = ptr::null();
- }
-
- *entry.prev_stack.get() = ptr::null();
- }
- }
-
- entry
- }
-
- fn remove(&mut self, entry: &Entry, _: &mut ()) {
- unsafe {
- // Ensure that the entry is in fact contained by the stack
- debug_assert!({
- // This walks the full linked list even if an entry is found.
- let mut next = self.head.as_ref();
- let mut contains = false;
-
- while let Some(n) = next {
- if entry as *const _ == &**n as *const _ {
- debug_assert!(!contains);
- contains = true;
- }
-
- next = (*n.next_stack.get()).as_ref();
- }
-
- contains
- });
-
- // Unlink `entry` from the next node
- let next = (*entry.next_stack.get()).take();
-
- if let Some(next) = next.as_ref() {
- (*next.prev_stack.get()) = *entry.prev_stack.get();
- }
-
- // Unlink `entry` from the prev node
-
- if let Some(prev) = (*entry.prev_stack.get()).as_ref() {
- *prev.next_stack.get() = next;
- } else {
- // It is the head
- self.head = next;
- }
-
- // Unset the prev pointer
- *entry.prev_stack.get() = ptr::null();
- }
- }
-
- fn when(item: &Entry, _: &()) -> u64 {
- item.when_internal().expect("invalid internal state")
- }
-}
diff --git a/src/time/error.rs b/src/time/error.rs
index 0667b97..24395c4 100644
--- a/src/time/error.rs
+++ b/src/time/error.rs
@@ -1,3 +1,5 @@
+//! Time error types.
+
use self::Kind::*;
use std::error;
use std::fmt;
@@ -13,7 +15,7 @@ use std::fmt;
/// succeed in the future.
///
/// * `at_capacity` occurs when a timer operation is attempted, but the timer
-/// instance is currently handling its maximum number of outstanding delays.
+/// instance is currently handling its maximum number of outstanding sleep instances.
/// In this case, the operation is not able to be performed at the current
/// moment, and `at_capacity` is returned. This is a transient error, i.e., at
/// some point in the future, if the operation is attempted again, it might
@@ -24,12 +26,26 @@ use std::fmt;
#[derive(Debug)]
pub struct Error(Kind);
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy)]
+#[repr(u8)]
enum Kind {
- Shutdown,
- AtCapacity,
+ Shutdown = 1,
+ AtCapacity = 2,
+ Invalid = 3,
}
+/// Error returned by `Timeout`.
+#[derive(Debug, PartialEq)]
+pub struct Elapsed(());
+
+#[derive(Debug)]
+pub(crate) enum InsertError {
+ Elapsed,
+ Invalid,
+}
+
+// ===== impl Error =====
+
impl Error {
/// Creates an error representing a shutdown timer.
pub fn shutdown() -> Error {
@@ -38,10 +54,7 @@ impl Error {
/// Returns `true` if the error was caused by the timer being shutdown.
pub fn is_shutdown(&self) -> bool {
- match self.0 {
- Kind::Shutdown => true,
- _ => false,
- }
+ matches!(self.0, Kind::Shutdown)
}
/// Creates an error representing a timer at capacity.
@@ -51,10 +64,30 @@ impl Error {
/// Returns `true` if the error was caused by the timer being at capacity.
pub fn is_at_capacity(&self) -> bool {
- match self.0 {
- Kind::AtCapacity => true,
- _ => false,
- }
+ matches!(self.0, Kind::AtCapacity)
+ }
+
+ /// Create an error representing a misconfigured timer.
+ pub fn invalid() -> Error {
+ Error(Invalid)
+ }
+
+ /// Returns `true` if the error was caused by the timer being misconfigured.
+ pub fn is_invalid(&self) -> bool {
+ matches!(self.0, Kind::Invalid)
+ }
+
+ pub(crate) fn as_u8(&self) -> u8 {
+ self.0 as u8
+ }
+
+ pub(crate) fn from_u8(n: u8) -> Self {
+ Error(match n {
+ 1 => Shutdown,
+ 2 => AtCapacity,
+ 3 => Invalid,
+ _ => panic!("u8 does not correspond to any time error variant"),
+ })
}
}
@@ -66,7 +99,30 @@ impl fmt::Display for Error {
let descr = match self.0 {
Shutdown => "the timer is shutdown, must be called from the context of Tokio runtime",
AtCapacity => "timer is at capacity and cannot create a new entry",
+ Invalid => "timer duration exceeds maximum duration",
};
write!(fmt, "{}", descr)
}
}
+
+// ===== impl Elapsed =====
+
+impl Elapsed {
+ pub(crate) fn new() -> Self {
+ Elapsed(())
+ }
+}
+
+impl fmt::Display for Elapsed {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ "deadline has elapsed".fmt(fmt)
+ }
+}
+
+impl std::error::Error for Elapsed {}
+
+impl From<Elapsed> for std::io::Error {
+ fn from(_err: Elapsed) -> std::io::Error {
+ std::io::ErrorKind::TimedOut.into()
+ }
+}
diff --git a/src/time/instant.rs b/src/time/instant.rs
index f2cb4bc..f4d6eac 100644
--- a/src/time/instant.rs
+++ b/src/time/instant.rs
@@ -4,8 +4,32 @@ use std::fmt;
use std::ops;
use std::time::Duration;
-/// A measurement of the system clock, useful for talking to
-/// external entities like the file system or other processes.
+/// A measurement of a monotonically nondecreasing clock.
+/// Opaque and useful only with `Duration`.
+///
+/// Instants are always guaranteed to be no less than any previously measured
+/// instant when created, and are often useful for tasks such as measuring
+/// benchmarks or timing how long an operation takes.
+///
+/// Note, however, that instants are not guaranteed to be **steady**. In other
+/// words, each tick of the underlying clock may not be the same length (e.g.
+/// some seconds may be longer than others). An instant may jump forwards or
+/// experience time dilation (slow down or speed up), but it will never go
+/// backwards.
+///
+/// Instants are opaque types that can only be compared to one another. There is
+/// no method to get "the number of seconds" from an instant. Instead, it only
+/// allows measuring the duration between two instants (or comparing two
+/// instants).
+///
+/// The size of an `Instant` struct may vary depending on the target operating
+/// system.
+///
+/// # Note
+///
+/// This type wraps the inner `std` variant and is used to align the Tokio
+/// clock for uses of `now()`. This can be useful for testing where you can
+/// take advantage of `time::pause()` and `time::advance()`.
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct Instant {
std: std::time::Instant,
@@ -50,12 +74,12 @@ impl Instant {
/// # Examples
///
/// ```
- /// use tokio::time::{Duration, Instant, delay_for};
+ /// use tokio::time::{Duration, Instant, sleep};
///
/// #[tokio::main]
/// async fn main() {
/// let now = Instant::now();
- /// delay_for(Duration::new(1, 0)).await;
+ /// sleep(Duration::new(1, 0)).await;
/// let new_now = Instant::now();
/// println!("{:?}", new_now.checked_duration_since(now));
/// println!("{:?}", now.checked_duration_since(new_now)); // None
@@ -71,12 +95,12 @@ impl Instant {
/// # Examples
///
/// ```
- /// use tokio::time::{Duration, Instant, delay_for};
+ /// use tokio::time::{Duration, Instant, sleep};
///
/// #[tokio::main]
/// async fn main() {
/// let now = Instant::now();
- /// delay_for(Duration::new(1, 0)).await;
+ /// sleep(Duration::new(1, 0)).await;
/// let new_now = Instant::now();
/// println!("{:?}", new_now.saturating_duration_since(now));
/// println!("{:?}", now.saturating_duration_since(new_now)); // 0ns
@@ -97,13 +121,13 @@ impl Instant {
/// # Examples
///
/// ```
- /// use tokio::time::{Duration, Instant, delay_for};
+ /// use tokio::time::{Duration, Instant, sleep};
///
/// #[tokio::main]
/// async fn main() {
/// let instant = Instant::now();
/// let three_secs = Duration::from_secs(3);
- /// delay_for(three_secs).await;
+ /// sleep(three_secs).await;
/// assert!(instant.elapsed() >= three_secs);
/// }
/// ```
diff --git a/src/time/interval.rs b/src/time/interval.rs
index 1fa21e6..c7c58e1 100644
--- a/src/time/interval.rs
+++ b/src/time/interval.rs
@@ -1,5 +1,5 @@
use crate::future::poll_fn;
-use crate::time::{delay_until, Delay, Duration, Instant};
+use crate::time::{sleep_until, Duration, Instant, Sleep};
use std::future::Future;
use std::pin::Pin;
@@ -36,12 +36,12 @@ use std::task::{Context, Poll};
///
/// A simple example using `interval` to execute a task every two seconds.
///
-/// The difference between `interval` and [`delay_for`] is that an `interval`
+/// The difference between `interval` and [`sleep`] is that an `interval`
/// measures the time since the last tick, which means that `.tick().await`
/// may wait for a shorter time than the duration specified for the interval
/// if some time has passed between calls to `.tick().await`.
///
-/// If the tick in the example below was replaced with [`delay_for`], the task
+/// If the tick in the example below was replaced with [`sleep`], the task
/// would only be executed once every three seconds, and not every two
/// seconds.
///
@@ -50,7 +50,7 @@ use std::task::{Context, Poll};
///
/// async fn task_that_takes_a_second() {
/// println!("hello");
-/// time::delay_for(time::Duration::from_secs(1)).await
+/// time::sleep(time::Duration::from_secs(1)).await
/// }
///
/// #[tokio::main]
@@ -63,7 +63,7 @@ use std::task::{Context, Poll};
/// }
/// ```
///
-/// [`delay_for`]: crate::time::delay_for()
+/// [`sleep`]: crate::time::sleep()
pub fn interval(period: Duration) -> Interval {
assert!(period > Duration::new(0, 0), "`period` must be non-zero.");
@@ -71,7 +71,7 @@ pub fn interval(period: Duration) -> Interval {
}
/// Creates new `Interval` that yields with interval of `period` with the
-/// first tick completing at `at`.
+/// first tick completing at `start`.
///
/// An interval will tick indefinitely. At any time, the `Interval` value can be
/// dropped. This cancels the interval.
@@ -101,24 +101,28 @@ pub fn interval_at(start: Instant, period: Duration) -> Interval {
assert!(period > Duration::new(0, 0), "`period` must be non-zero.");
Interval {
- delay: delay_until(start),
+ delay: sleep_until(start),
period,
}
}
/// Stream returned by [`interval`](interval) and [`interval_at`](interval_at).
+///
+/// This type only implements the [`Stream`] trait if the "stream" feature is
+/// enabled.
+///
+/// [`Stream`]: trait@crate::stream::Stream
#[derive(Debug)]
pub struct Interval {
/// Future that completes the next time the `Interval` yields a value.
- delay: Delay,
+ delay: Sleep,
/// The duration between values yielded by `Interval`.
period: Duration,
}
impl Interval {
- #[doc(hidden)] // TODO: document
- pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> {
+ fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> {
// Wait for the delay to be done
ready!(Pin::new(&mut self.delay).poll(cx));
@@ -154,7 +158,6 @@ impl Interval {
/// // approximately 20ms have elapsed.
/// }
/// ```
- #[allow(clippy::should_implement_trait)] // TODO: rename (tokio-rs/tokio#1261)
pub async fn tick(&mut self) -> Instant {
poll_fn(|cx| self.poll_tick(cx)).await
}
diff --git a/src/time/mod.rs b/src/time/mod.rs
index c532b2c..29af717 100644
--- a/src/time/mod.rs
+++ b/src/time/mod.rs
@@ -3,7 +3,7 @@
//! This module provides a number of types for executing code after a set period
//! of time.
//!
-//! * `Delay` is a future that does no work and completes at a specific `Instant`
+//! * `Sleep` is a future that does no work and completes at a specific `Instant`
//! in time.
//!
//! * `Interval` is a stream yielding a value at a fixed period. It is
@@ -14,9 +14,6 @@
//! of time it is allowed to execute. If the future or stream does not
//! complete in time, then it is canceled and an error is returned.
//!
-//! * `DelayQueue`: A queue where items are returned once the requested delay
-//! has expired.
-//!
//! These types are sufficient for handling a large number of scenarios
//! involving time.
//!
@@ -27,14 +24,14 @@
//! Wait 100ms and print "100 ms have elapsed"
//!
//! ```
-//! use tokio::time::delay_for;
+//! use tokio::time::sleep;
//!
//! use std::time::Duration;
//!
//!
//! #[tokio::main]
//! async fn main() {
-//! delay_for(Duration::from_millis(100)).await;
+//! sleep(Duration::from_millis(100)).await;
//! println!("100 ms have elapsed");
//! }
//! ```
@@ -61,12 +58,12 @@
//!
//! A simple example using [`interval`] to execute a task every two seconds.
//!
-//! The difference between [`interval`] and [`delay_for`] is that an
+//! The difference between [`interval`] and [`sleep`] is that an
//! [`interval`] measures the time since the last tick, which means that
//! `.tick().await` may wait for a shorter time than the duration specified
//! for the interval if some time has passed between calls to `.tick().await`.
//!
-//! If the tick in the example below was replaced with [`delay_for`], the task
+//! If the tick in the example below was replaced with [`sleep`], the task
//! would only be executed once every three seconds, and not every two
//! seconds.
//!
@@ -75,7 +72,7 @@
//!
//! async fn task_that_takes_a_second() {
//! println!("hello");
-//! time::delay_for(time::Duration::from_secs(1)).await
+//! time::sleep(time::Duration::from_secs(1)).await
//! }
//!
//! #[tokio::main]
@@ -88,7 +85,7 @@
//! }
//! ```
//!
-//! [`delay_for`]: crate::time::delay_for()
+//! [`sleep`]: crate::time::sleep()
//! [`interval`]: crate::time::interval()
mod clock;
@@ -96,17 +93,12 @@ pub(crate) use self::clock::Clock;
#[cfg(feature = "test-util")]
pub use clock::{advance, pause, resume};
-pub mod delay_queue;
-#[doc(inline)]
-pub use delay_queue::DelayQueue;
-
-mod delay;
-pub use delay::{delay_for, delay_until, Delay};
+mod sleep;
+pub use sleep::{sleep, sleep_until, Sleep};
pub(crate) mod driver;
-mod error;
-pub use error::Error;
+pub mod error;
mod instant;
pub use self::instant::Instant;
@@ -116,12 +108,7 @@ pub use interval::{interval, interval_at, Interval};
mod timeout;
#[doc(inline)]
-pub use timeout::{timeout, timeout_at, Elapsed, Timeout};
-
-cfg_stream! {
- mod throttle;
- pub use throttle::{throttle, Throttle};
-}
+pub use timeout::{timeout, timeout_at, Timeout};
mod wheel;
@@ -130,6 +117,7 @@ mod wheel;
mod tests;
// Re-export for convenience
+#[doc(no_inline)]
pub use std::time::Duration;
// ===== Internal utils =====
diff --git a/src/time/delay.rs b/src/time/sleep.rs
index 744c7e1..d3234a1 100644
--- a/src/time/delay.rs
+++ b/src/time/sleep.rs
@@ -1,31 +1,31 @@
-use crate::time::driver::Registration;
-use crate::time::{Duration, Instant};
+use crate::time::driver::{Entry, Handle};
+use crate::time::{error::Error, Duration, Instant};
use std::future::Future;
use std::pin::Pin;
+use std::sync::Arc;
use std::task::{self, Poll};
/// Waits until `deadline` is reached.
///
-/// No work is performed while awaiting on the delay to complete. The delay
+/// No work is performed while awaiting on the sleep future to complete. `Sleep`
/// operates at millisecond granularity and should not be used for tasks that
/// require high-resolution timers.
///
/// # Cancellation
///
-/// Canceling a delay is done by dropping the returned future. No additional
+/// Canceling a sleep instance is done by dropping the returned future. No additional
/// cleanup work is required.
-pub fn delay_until(deadline: Instant) -> Delay {
- let registration = Registration::new(deadline, Duration::from_millis(0));
- Delay { registration }
+pub fn sleep_until(deadline: Instant) -> Sleep {
+ Sleep::new_timeout(deadline, Duration::from_millis(0))
}
/// Waits until `duration` has elapsed.
///
-/// Equivalent to `delay_until(Instant::now() + duration)`. An asynchronous
+/// Equivalent to `sleep_until(Instant::now() + duration)`. An asynchronous
/// analog to `std::thread::sleep`.
///
-/// No work is performed while awaiting on the delay to complete. The delay
+/// No work is performed while awaiting on the sleep future to complete. `Sleep`
/// operates at millisecond granularity and should not be used for tasks that
/// require high-resolution timers.
///
@@ -33,7 +33,7 @@ pub fn delay_until(deadline: Instant) -> Delay {
///
/// # Cancellation
///
-/// Canceling a delay is done by dropping the returned future. No additional
+/// Canceling a sleep instance is done by dropping the returned future. No additional
/// cleanup work is required.
///
/// # Examples
@@ -41,78 +41,99 @@ pub fn delay_until(deadline: Instant) -> Delay {
/// Wait 100ms and print "100 ms have elapsed".
///
/// ```
-/// use tokio::time::{delay_for, Duration};
+/// use tokio::time::{sleep, Duration};
///
/// #[tokio::main]
/// async fn main() {
-/// delay_for(Duration::from_millis(100)).await;
+/// sleep(Duration::from_millis(100)).await;
/// println!("100 ms have elapsed");
/// }
/// ```
///
/// [`interval`]: crate::time::interval()
-#[cfg_attr(docsrs, doc(alias = "sleep"))]
-pub fn delay_for(duration: Duration) -> Delay {
- delay_until(Instant::now() + duration)
+pub fn sleep(duration: Duration) -> Sleep {
+ sleep_until(Instant::now() + duration)
}
-/// Future returned by [`delay_until`](delay_until) and
-/// [`delay_for`](delay_for).
+/// Future returned by [`sleep`](sleep) and
+/// [`sleep_until`](sleep_until).
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct Delay {
- /// The link between the `Delay` instance and the timer that drives it.
+pub struct Sleep {
+ /// The link between the `Sleep` instance and the timer that drives it.
///
/// This also stores the `deadline` value.
- registration: Registration,
+ entry: Arc<Entry>,
}
-impl Delay {
- pub(crate) fn new_timeout(deadline: Instant, duration: Duration) -> Delay {
- let registration = Registration::new(deadline, duration);
- Delay { registration }
+impl Sleep {
+ pub(crate) fn new_timeout(deadline: Instant, duration: Duration) -> Sleep {
+ let handle = Handle::current();
+ let entry = Entry::new(&handle, deadline, duration);
+
+ Sleep { entry }
}
/// Returns the instant at which the future will complete.
pub fn deadline(&self) -> Instant {
- self.registration.deadline()
+ self.entry.time_ref().deadline
}
- /// Returns `true` if the `Delay` has elapsed
+ /// Returns `true` if `Sleep` has elapsed.
///
- /// A `Delay` is elapsed when the requested duration has elapsed.
+ /// A `Sleep` instance is elapsed when the requested duration has elapsed.
pub fn is_elapsed(&self) -> bool {
- self.registration.is_elapsed()
+ self.entry.is_elapsed()
}
- /// Resets the `Delay` instance to a new deadline.
+ /// Resets the `Sleep` instance to a new deadline.
///
- /// Calling this function allows changing the instant at which the `Delay`
+ /// Calling this function allows changing the instant at which the `Sleep`
/// future completes without having to create new associated state.
///
/// This function can be called both before and after the future has
/// completed.
pub fn reset(&mut self, deadline: Instant) {
- self.registration.reset(deadline);
+ unsafe {
+ self.entry.time_mut().deadline = deadline;
+ }
+
+ Entry::reset(&mut self.entry);
+ }
+
+ fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
+ // Keep track of task budget
+ let coop = ready!(crate::coop::poll_proceed(cx));
+
+ self.entry.poll_elapsed(cx).map(move |r| {
+ coop.made_progress();
+ r
+ })
}
}
-impl Future for Delay {
+impl Future for Sleep {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
// `poll_elapsed` can return an error in two cases:
//
- // - AtCapacity: this is a pathlogical case where far too many
- // delays have been scheduled.
+ // - AtCapacity: this is a pathological case where far too many
+ // sleep instances have been scheduled.
// - Shutdown: No timer has been setup, which is a mis-use error.
//
// Both cases are extremely rare, and pretty accurately fit into
// "logic errors", so we just panic in this case. A user couldn't
// really do much better if we passed the error onwards.
- match ready!(self.registration.poll_elapsed(cx)) {
+ match ready!(self.poll_elapsed(cx)) {
Ok(()) => Poll::Ready(()),
Err(e) => panic!("timer error: {}", e),
}
}
}
+
+impl Drop for Sleep {
+ fn drop(&mut self) {
+ Entry::cancel(&self.entry);
+ }
+}
diff --git a/src/time/tests/mod.rs b/src/time/tests/mod.rs
index 4710d47..fae67da 100644
--- a/src/time/tests/mod.rs
+++ b/src/time/tests/mod.rs
@@ -1,4 +1,4 @@
-mod test_delay;
+mod test_sleep;
use crate::time::{self, Instant};
use std::time::Duration;
@@ -8,15 +8,15 @@ fn assert_sync<T: Sync>() {}
#[test]
fn registration_is_send_and_sync() {
- use crate::time::driver::Registration;
+ use crate::time::sleep::Sleep;
- assert_send::<Registration>();
- assert_sync::<Registration>();
+ assert_send::<Sleep>();
+ assert_sync::<Sleep>();
}
#[test]
#[should_panic]
-fn delay_is_eager() {
+fn sleep_is_eager() {
let when = Instant::now() + Duration::from_millis(100);
- let _ = time::delay_until(when);
+ let _ = time::sleep_until(when);
}
diff --git a/src/time/tests/test_delay.rs b/src/time/tests/test_sleep.rs
index f843434..c8d931a 100644
--- a/src/time/tests/test_delay.rs
+++ b/src/time/tests/test_sleep.rs
@@ -1,5 +1,3 @@
-#![warn(rust_2018_idioms)]
-
use crate::park::{Park, Unpark};
use crate::time::driver::{Driver, Entry, Handle};
use crate::time::Clock;
@@ -27,12 +25,12 @@ fn frozen_utility_returns_correct_advanced_duration() {
}
#[test]
-fn immediate_delay() {
+fn immediate_sleep() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
let when = clock.now();
- let mut e = task::spawn(delay_until(&handle, when));
+ let mut e = task::spawn(sleep_until(&handle, when));
assert_ready_ok!(poll!(e));
@@ -43,15 +41,15 @@ fn immediate_delay() {
}
#[test]
-fn delayed_delay_level_0() {
+fn delayed_sleep_level_0() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
for &i in &[1, 10, 60] {
- // Create a `Delay` that elapses in the future
- let mut e = task::spawn(delay_until(&handle, start + ms(i)));
+ // Create a `Sleep` that elapses in the future
+ let mut e = task::spawn(sleep_until(&handle, start + ms(i)));
- // The delay has not elapsed.
+ // The sleep instance has not elapsed.
assert_pending!(poll!(e));
assert_ok!(driver.park());
@@ -62,13 +60,13 @@ fn delayed_delay_level_0() {
}
#[test]
-fn sub_ms_delayed_delay() {
+fn sub_ms_delayed_sleep() {
let (mut driver, clock, handle) = setup();
for _ in 0..5 {
let deadline = clock.now() + ms(1) + Duration::new(0, 1);
- let mut e = task::spawn(delay_until(&handle, deadline));
+ let mut e = task::spawn(sleep_until(&handle, deadline));
assert_pending!(poll!(e));
@@ -82,14 +80,14 @@ fn sub_ms_delayed_delay() {
}
#[test]
-fn delayed_delay_wrapping_level_0() {
+fn delayed_sleep_wrapping_level_0() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
assert_ok!(driver.park_timeout(ms(5)));
assert_eq!(clock.now() - start, ms(5));
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(60)));
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(60)));
assert_pending!(poll!(e));
@@ -108,15 +106,15 @@ fn timer_wrapping_with_higher_levels() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Set delay to hit level 1
- let mut e1 = task::spawn(delay_until(&handle, clock.now() + ms(64)));
+ // Set sleep to hit level 1
+ let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(64)));
assert_pending!(poll!(e1));
// Turn a bit
assert_ok!(driver.park_timeout(ms(5)));
// Set timeout such that it will hit level 0, but wrap
- let mut e2 = task::spawn(delay_until(&handle, clock.now() + ms(60)));
+ let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(60)));
assert_pending!(poll!(e2));
// This should result in s1 firing
@@ -133,14 +131,14 @@ fn timer_wrapping_with_higher_levels() {
}
#[test]
-fn delay_with_deadline_in_past() {
+fn sleep_with_deadline_in_past() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Create `Delay` that elapsed immediately.
- let mut e = task::spawn(delay_until(&handle, clock.now() - ms(100)));
+ // Create `Sleep` that elapsed immediately.
+ let mut e = task::spawn(sleep_until(&handle, clock.now() - ms(100)));
- // Even though the delay expires in the past, it is not ready yet
+ // Even though the `Sleep` expires in the past, it is not ready yet
// because the timer must observe it.
assert_ready_ok!(poll!(e));
@@ -152,37 +150,37 @@ fn delay_with_deadline_in_past() {
}
#[test]
-fn delayed_delay_level_1() {
+fn delayed_sleep_level_1() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Create a `Delay` that elapses in the future
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(234)));
+ // Create a `Sleep` that elapses in the future
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(234)));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
// Turn the timer, this will wake up to cascade the timer down.
assert_ok!(driver.park_timeout(ms(1000)));
assert_eq!(clock.now() - start, ms(192));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
// Turn the timer again
assert_ok!(driver.park_timeout(ms(1000)));
assert_eq!(clock.now() - start, ms(234));
- // The delay has elapsed.
+ // The sleep has elapsed.
assert_ready_ok!(poll!(e));
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Create a `Delay` that elapses in the future
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(234)));
+ // Create a `Sleep` that elapses in the future
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(234)));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
// Turn the timer with a smaller timeout than the cascade.
@@ -195,14 +193,14 @@ fn delayed_delay_level_1() {
assert_ok!(driver.park_timeout(ms(1000)));
assert_eq!(clock.now() - start, ms(192));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
// Turn the timer again
assert_ok!(driver.park_timeout(ms(1000)));
assert_eq!(clock.now() - start, ms(234));
- // The delay has elapsed.
+ // The sleep has elapsed.
assert_ready_ok!(poll!(e));
}
@@ -211,22 +209,22 @@ fn concurrently_set_two_timers_second_one_shorter() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
- let mut e1 = task::spawn(delay_until(&handle, clock.now() + ms(500)));
- let mut e2 = task::spawn(delay_until(&handle, clock.now() + ms(200)));
+ let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(500)));
+ let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(200)));
- // The delay has not elapsed
+ // The sleep has not elapsed
assert_pending!(poll!(e1));
assert_pending!(poll!(e2));
- // Delay until a cascade
+ // Sleep until a cascade
assert_ok!(driver.park());
assert_eq!(clock.now() - start, ms(192));
- // Delay until the second timer.
+ // Sleep until the second timer.
assert_ok!(driver.park());
assert_eq!(clock.now() - start, ms(200));
- // The shorter delay fires
+ // The shorter sleep fires
assert_ready_ok!(poll!(e2));
assert_pending!(poll!(e1));
@@ -235,7 +233,7 @@ fn concurrently_set_two_timers_second_one_shorter() {
assert_pending!(poll!(e1));
- // Turn again, this time the time will advance to the second delay
+ // Turn again, this time the time will advance to the second sleep
assert_ok!(driver.park());
assert_eq!(clock.now() - start, ms(500));
@@ -243,37 +241,37 @@ fn concurrently_set_two_timers_second_one_shorter() {
}
#[test]
-fn short_delay() {
+fn short_sleep() {
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Create a `Delay` that elapses in the future
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(1)));
+ // Create a `Sleep` that elapses in the future
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(1)));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
// Turn the timer, but not enough time will go by.
assert_ok!(driver.park());
- // The delay has elapsed.
+ // The sleep has elapsed.
assert_ready_ok!(poll!(e));
- // The time has advanced to the point of the delay elapsing.
+ // The time has advanced to the point of the sleep elapsing.
assert_eq!(clock.now() - start, ms(1));
}
#[test]
-fn sorta_long_delay_until() {
+fn sorta_long_sleep_until() {
const MIN_5: u64 = 5 * 60 * 1000;
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Create a `Delay` that elapses in the future
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(MIN_5)));
+ // Create a `Sleep` that elapses in the future
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(MIN_5)));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
let cascades = &[262_144, 262_144 + 9 * 4096, 262_144 + 9 * 4096 + 15 * 64];
@@ -288,21 +286,21 @@ fn sorta_long_delay_until() {
assert_ok!(driver.park());
assert_eq!(clock.now() - start, ms(MIN_5));
- // The delay has elapsed.
+ // The sleep has elapsed.
assert_ready_ok!(poll!(e));
}
#[test]
-fn very_long_delay() {
+fn very_long_sleep() {
const MO_5: u64 = 5 * 30 * 24 * 60 * 60 * 1000;
let (mut driver, clock, handle) = setup();
let start = clock.now();
- // Create a `Delay` that elapses in the future
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(MO_5)));
+ // Create a `Sleep` that elapses in the future
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(MO_5)));
- // The delay has not elapsed.
+ // The sleep has not elapsed.
assert_pending!(poll!(e));
let cascades = &[
@@ -322,10 +320,10 @@ fn very_long_delay() {
// Turn the timer, but not enough time will go by.
assert_ok!(driver.park());
- // The time has advanced to the point of the delay elapsing.
+ // The time has advanced to the point of the sleep elapsing.
assert_eq!(clock.now() - start, ms(MO_5));
- // The delay has elapsed.
+ // The sleep has elapsed.
assert_ready_ok!(poll!(e));
}
@@ -353,6 +351,8 @@ fn unpark_is_delayed() {
self.0.advance(ms(436));
Ok(())
}
+
+ fn shutdown(&mut self) {}
}
impl Unpark for MockUnpark {
@@ -365,9 +365,9 @@ fn unpark_is_delayed() {
let mut driver = Driver::new(MockPark(clock.clone()), clock.clone());
let handle = driver.handle();
- let mut e1 = task::spawn(delay_until(&handle, clock.now() + ms(100)));
- let mut e2 = task::spawn(delay_until(&handle, clock.now() + ms(101)));
- let mut e3 = task::spawn(delay_until(&handle, clock.now() + ms(200)));
+ let mut e1 = task::spawn(sleep_until(&handle, clock.now() + ms(100)));
+ let mut e2 = task::spawn(sleep_until(&handle, clock.now() + ms(101)));
+ let mut e3 = task::spawn(sleep_until(&handle, clock.now() + ms(200)));
assert_pending!(poll!(e1));
assert_pending!(poll!(e2));
@@ -394,7 +394,7 @@ fn set_timeout_at_deadline_greater_than_max_timer() {
assert_ok!(driver.park_timeout(ms(YR_1)));
}
- let mut e = task::spawn(delay_until(&handle, clock.now() + ms(1)));
+ let mut e = task::spawn(sleep_until(&handle, clock.now() + ms(1)));
assert_pending!(poll!(e));
assert_ok!(driver.park_timeout(ms(1000)));
@@ -412,7 +412,7 @@ fn setup() -> (Driver<MockPark>, Clock, Handle) {
(driver, clock, handle)
}
-fn delay_until(handle: &Handle, when: Instant) -> Arc<Entry> {
+fn sleep_until(handle: &Handle, when: Instant) -> Arc<Entry> {
Entry::new(&handle, when, ms(0))
}
@@ -436,6 +436,8 @@ impl Park for MockPark {
self.0.advance(duration);
Ok(())
}
+
+ fn shutdown(&mut self) {}
}
impl Unpark for MockUnpark {
diff --git a/src/time/timeout.rs b/src/time/timeout.rs
index efc3dc5..cf09b07 100644
--- a/src/time/timeout.rs
+++ b/src/time/timeout.rs
@@ -4,10 +4,9 @@
//!
//! [`Timeout`]: struct@Timeout
-use crate::time::{delay_until, Delay, Duration, Instant};
+use crate::time::{error::Elapsed, sleep_until, Duration, Instant, Sleep};
use pin_project_lite::pin_project;
-use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{self, Poll};
@@ -50,7 +49,7 @@ pub fn timeout<T>(duration: Duration, future: T) -> Timeout<T>
where
T: Future,
{
- let delay = Delay::new_timeout(Instant::now() + duration, duration);
+ let delay = Sleep::new_timeout(Instant::now() + duration, duration);
Timeout::new_with_delay(future, delay)
}
@@ -92,7 +91,7 @@ pub fn timeout_at<T>(deadline: Instant, future: T) -> Timeout<T>
where
T: Future,
{
- let delay = delay_until(deadline);
+ let delay = sleep_until(deadline);
Timeout {
value: future,
@@ -108,24 +107,12 @@ pin_project! {
#[pin]
value: T,
#[pin]
- delay: Delay,
- }
-}
-
-/// Error returned by `Timeout`.
-#[derive(Debug, PartialEq)]
-pub struct Elapsed(());
-
-impl Elapsed {
- // Used on StreamExt::timeout
- #[allow(unused)]
- pub(crate) fn new() -> Self {
- Elapsed(())
+ delay: Sleep,
}
}
impl<T> Timeout<T> {
- pub(crate) fn new_with_delay(value: T, delay: Delay) -> Timeout<T> {
+ pub(crate) fn new_with_delay(value: T, delay: Sleep) -> Timeout<T> {
Timeout { value, delay }
}
@@ -161,24 +148,8 @@ where
// Now check the timer
match me.delay.poll(cx) {
- Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))),
+ Poll::Ready(()) => Poll::Ready(Err(Elapsed::new())),
Poll::Pending => Poll::Pending,
}
}
}
-
-// ===== impl Elapsed =====
-
-impl fmt::Display for Elapsed {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- "deadline has elapsed".fmt(fmt)
- }
-}
-
-impl std::error::Error for Elapsed {}
-
-impl From<Elapsed> for std::io::Error {
- fn from(_err: Elapsed) -> std::io::Error {
- std::io::ErrorKind::TimedOut.into()
- }
-}
diff --git a/src/time/wheel/level.rs b/src/time/wheel/level.rs
index 49f9bfb..d51d26a 100644
--- a/src/time/wheel/level.rs
+++ b/src/time/wheel/level.rs
@@ -1,9 +1,10 @@
+use super::{Item, OwnedItem};
use crate::time::wheel::Stack;
use std::fmt;
/// Wheel for a single level in the timer. This wheel contains 64 slots.
-pub(crate) struct Level<T> {
+pub(crate) struct Level {
level: usize,
/// Bit field tracking which slots currently contain entries.
@@ -16,7 +17,7 @@ pub(crate) struct Level<T> {
occupied: u64,
/// Slots
- slot: [T; LEVEL_MULT],
+ slot: [Stack; LEVEL_MULT],
}
/// Indicates when a slot must be processed next.
@@ -37,87 +38,90 @@ pub(crate) struct Expiration {
/// Being a power of 2 is very important.
const LEVEL_MULT: usize = 64;
-impl<T: Stack> Level<T> {
- pub(crate) fn new(level: usize) -> Level<T> {
- // Rust's derived implementations for arrays require that the value
- // contained by the array be `Copy`. So, here we have to manually
- // initialize every single slot.
- macro_rules! s {
- () => {
- T::default()
- };
- };
+impl Level {
+ pub(crate) fn new(level: usize) -> Level {
+ // A value has to be Copy in order to use syntax like:
+ // let stack = Stack::default();
+ // ...
+ // slots: [stack; 64],
+ //
+ // Alternatively, since Stack is Default one can
+ // use syntax like:
+ // let slots: [Stack; 64] = Default::default();
+ //
+ // However, that is only supported for arrays of size
+ // 32 or fewer. So in our case we have to explicitly
+ // invoke the constructor for each array element.
+ let ctor = Stack::default;
Level {
level,
occupied: 0,
slot: [
- // It does not look like the necessary traits are
- // derived for [T; 64].
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
- s!(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
+ ctor(),
],
}
}
@@ -173,17 +177,17 @@ impl<T: Stack> Level<T> {
Some(slot)
}
- pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) {
+ pub(crate) fn add_entry(&mut self, when: u64, item: OwnedItem) {
let slot = slot_for(when, self.level);
- self.slot[slot].push(item, store);
+ self.slot[slot].push(item);
self.occupied |= occupied_bit(slot);
}
- pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) {
+ pub(crate) fn remove_entry(&mut self, when: u64, item: &Item) {
let slot = slot_for(when, self.level);
- self.slot[slot].remove(item, store);
+ self.slot[slot].remove(item);
if self.slot[slot].is_empty() {
// The bit is currently set
@@ -194,8 +198,8 @@ impl<T: Stack> Level<T> {
}
}
- pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> {
- let ret = self.slot[slot].pop(store);
+ pub(crate) fn pop_entry_slot(&mut self, slot: usize) -> Option<OwnedItem> {
+ let ret = self.slot[slot].pop();
if ret.is_some() && self.slot[slot].is_empty() {
// The bit is currently set
@@ -208,7 +212,7 @@ impl<T: Stack> Level<T> {
}
}
-impl<T> fmt::Debug for Level<T> {
+impl fmt::Debug for Level {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Level")
.field("occupied", &self.occupied)
diff --git a/src/time/wheel/mod.rs b/src/time/wheel/mod.rs
index a2ef27f..85ed2f1 100644
--- a/src/time/wheel/mod.rs
+++ b/src/time/wheel/mod.rs
@@ -1,3 +1,5 @@
+use crate::time::{driver::Entry, error::InsertError};
+
mod level;
pub(crate) use self::level::Expiration;
use self::level::Level;
@@ -5,9 +7,12 @@ use self::level::Level;
mod stack;
pub(crate) use self::stack::Stack;
-use std::borrow::Borrow;
+use std::sync::Arc;
use std::usize;
+pub(super) type Item = Entry;
+pub(super) type OwnedItem = Arc<Item>;
+
/// Timing wheel implementation.
///
/// This type provides the hashed timing wheel implementation that backs `Timer`
@@ -20,7 +25,7 @@ use std::usize;
///
/// See `Timer` documentation for some implementation notes.
#[derive(Debug)]
-pub(crate) struct Wheel<T> {
+pub(crate) struct Wheel {
/// The number of milliseconds elapsed since the wheel started.
elapsed: u64,
@@ -34,7 +39,7 @@ pub(crate) struct Wheel<T> {
/// * ~ 4 min slots / ~ 4 hr range
/// * ~ 4 hr slots / ~ 12 day range
/// * ~ 12 day slots / ~ 2 yr range
- levels: Vec<Level<T>>,
+ levels: Vec<Level>,
}
/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots
@@ -42,28 +47,12 @@ pub(crate) struct Wheel<T> {
/// precision of 1 millisecond.
const NUM_LEVELS: usize = 6;
-/// The maximum duration of a delay
+/// The maximum duration of a `Sleep`
const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1;
-#[derive(Debug)]
-pub(crate) enum InsertError {
- Elapsed,
- Invalid,
-}
-
-/// Poll expirations from the wheel
-#[derive(Debug, Default)]
-pub(crate) struct Poll {
- now: u64,
- expiration: Option<Expiration>,
-}
-
-impl<T> Wheel<T>
-where
- T: Stack,
-{
+impl Wheel {
/// Create a new timing wheel
- pub(crate) fn new() -> Wheel<T> {
+ pub(crate) fn new() -> Wheel {
let levels = (0..NUM_LEVELS).map(Level::new).collect();
Wheel { elapsed: 0, levels }
@@ -99,9 +88,8 @@ where
pub(crate) fn insert(
&mut self,
when: u64,
- item: T::Owned,
- store: &mut T::Store,
- ) -> Result<(), (T::Owned, InsertError)> {
+ item: OwnedItem,
+ ) -> Result<(), (OwnedItem, InsertError)> {
if when <= self.elapsed {
return Err((item, InsertError::Elapsed));
} else if when - self.elapsed > MAX_DURATION {
@@ -111,7 +99,7 @@ where
// Get the level at which the entry should be stored
let level = self.level_for(when);
- self.levels[level].add_entry(when, item, store);
+ self.levels[level].add_entry(when, item);
debug_assert!({
self.levels[level]
@@ -124,11 +112,11 @@ where
}
/// Remove `item` from thee timing wheel.
- pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) {
- let when = T::when(item, store);
+ pub(crate) fn remove(&mut self, item: &Item) {
+ let when = item.when();
let level = self.level_for(when);
- self.levels[level].remove_entry(when, item, store);
+ self.levels[level].remove_entry(when, item);
}
/// Instant at which to poll
@@ -136,33 +124,35 @@ where
self.next_expiration().map(|expiration| expiration.deadline)
}
- pub(crate) fn poll(&mut self, poll: &mut Poll, store: &mut T::Store) -> Option<T::Owned> {
+ /// Advances the timer up to the instant represented by `now`.
+ pub(crate) fn poll(&mut self, now: u64) -> Option<OwnedItem> {
loop {
- if poll.expiration.is_none() {
- poll.expiration = self.next_expiration().and_then(|expiration| {
- if expiration.deadline > poll.now {
- None
- } else {
- Some(expiration)
- }
- });
- }
+ // under what circumstances is poll.expiration Some vs. None?
+ let expiration = self.next_expiration().and_then(|expiration| {
+ if expiration.deadline > now {
+ None
+ } else {
+ Some(expiration)
+ }
+ });
- match poll.expiration {
+ match expiration {
Some(ref expiration) => {
- if let Some(item) = self.poll_expiration(expiration, store) {
+ if let Some(item) = self.poll_expiration(expiration) {
return Some(item);
}
self.set_elapsed(expiration.deadline);
}
None => {
- self.set_elapsed(poll.now);
+ // in this case the poll did not indicate an expiration
+ // _and_ we were not able to find a next expiration in
+ // the current list of timers. advance to the poll's
+ // current time and do nothing else.
+ self.set_elapsed(now);
return None;
}
}
-
- poll.expiration = None;
}
}
@@ -197,22 +187,22 @@ where
res
}
- pub(crate) fn poll_expiration(
- &mut self,
- expiration: &Expiration,
- store: &mut T::Store,
- ) -> Option<T::Owned> {
- while let Some(item) = self.pop_entry(expiration, store) {
+ /// iteratively find entries that are between the wheel's current
+ /// time and the expiration time. for each in that population either
+ /// return it for notification (in the case of the last level) or tier
+ /// it down to the next level (in all other cases).
+ pub(crate) fn poll_expiration(&mut self, expiration: &Expiration) -> Option<OwnedItem> {
+ while let Some(item) = self.pop_entry(expiration) {
if expiration.level == 0 {
- debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline);
+ debug_assert_eq!(item.when(), expiration.deadline);
return Some(item);
} else {
- let when = T::when(item.borrow(), store);
+ let when = item.when();
let next_level = expiration.level - 1;
- self.levels[next_level].add_entry(when, item, store);
+ self.levels[next_level].add_entry(when, item);
}
}
@@ -232,8 +222,8 @@ where
}
}
- fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> {
- self.levels[expiration.level].pop_entry_slot(expiration.slot, store)
+ fn pop_entry(&mut self, expiration: &Expiration) -> Option<OwnedItem> {
+ self.levels[expiration.level].pop_entry_slot(expiration.slot)
}
fn level_for(&self, when: u64) -> usize {
@@ -251,15 +241,6 @@ fn level_for(elapsed: u64, when: u64) -> usize {
significant / 6
}
-impl Poll {
- pub(crate) fn new(now: u64) -> Poll {
- Poll {
- now,
- expiration: None,
- }
- }
-}
-
#[cfg(all(test, not(loom)))]
mod test {
use super::*;
diff --git a/src/time/wheel/stack.rs b/src/time/wheel/stack.rs
index 6e55c38..e7ed137 100644
--- a/src/time/wheel/stack.rs
+++ b/src/time/wheel/stack.rs
@@ -1,26 +1,112 @@
-use std::borrow::Borrow;
+use super::{Item, OwnedItem};
+use crate::time::driver::Entry;
-/// Abstracts the stack operations needed to track timeouts.
-pub(crate) trait Stack: Default {
- /// Type of the item stored in the stack
- type Owned: Borrow<Self::Borrowed>;
+use std::ptr;
- /// Borrowed item
- type Borrowed;
+/// A doubly linked stack
+#[derive(Debug)]
+pub(crate) struct Stack {
+ head: Option<OwnedItem>,
+}
+
+impl Default for Stack {
+ fn default() -> Stack {
+ Stack { head: None }
+ }
+}
+
+impl Stack {
+ pub(crate) fn is_empty(&self) -> bool {
+ self.head.is_none()
+ }
+
+ pub(crate) fn push(&mut self, entry: OwnedItem) {
+ // Get a pointer to the entry to for the prev link
+ let ptr: *const Entry = &*entry as *const _;
+
+ // Remove the old head entry
+ let old = self.head.take();
+
+ unsafe {
+ // Ensure the entry is not already in a stack.
+ debug_assert!((*entry.next_stack.get()).is_none());
+ debug_assert!((*entry.prev_stack.get()).is_null());
+
+ if let Some(ref entry) = old.as_ref() {
+ debug_assert!({
+ // The head is not already set to the entry
+ ptr != &***entry as *const _
+ });
+
+ // Set the previous link on the old head
+ *entry.prev_stack.get() = ptr;
+ }
+
+ // Set this entry's next pointer
+ *entry.next_stack.get() = old;
+ }
+
+ // Update the head pointer
+ self.head = Some(entry);
+ }
+
+ /// Pops an item from the stack
+ pub(crate) fn pop(&mut self) -> Option<OwnedItem> {
+ let entry = self.head.take();
+
+ unsafe {
+ if let Some(entry) = entry.as_ref() {
+ self.head = (*entry.next_stack.get()).take();
+
+ if let Some(entry) = self.head.as_ref() {
+ *entry.prev_stack.get() = ptr::null();
+ }
+
+ *entry.prev_stack.get() = ptr::null();
+ }
+ }
+
+ entry
+ }
+
+ pub(crate) fn remove(&mut self, entry: &Item) {
+ unsafe {
+ // Ensure that the entry is in fact contained by the stack
+ debug_assert!({
+ // This walks the full linked list even if an entry is found.
+ let mut next = self.head.as_ref();
+ let mut contains = false;
+
+ while let Some(n) = next {
+ if entry as *const _ == &**n as *const _ {
+ debug_assert!(!contains);
+ contains = true;
+ }
+
+ next = (*n.next_stack.get()).as_ref();
+ }
- /// Item storage, this allows a slab to be used instead of just the heap
- type Store;
+ contains
+ });
- /// Returns `true` if the stack is empty
- fn is_empty(&self) -> bool;
+ // Unlink `entry` from the next node
+ let next = (*entry.next_stack.get()).take();
- /// Push an item onto the stack
- fn push(&mut self, item: Self::Owned, store: &mut Self::Store);
+ if let Some(next) = next.as_ref() {
+ (*next.prev_stack.get()) = *entry.prev_stack.get();
+ }
- /// Pop an item from the stack
- fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>;
+ // Unlink `entry` from the prev node
- fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);
+ if let Some(prev) = (*entry.prev_stack.get()).as_ref() {
+ *prev.next_stack.get() = next;
+ } else {
+ // It is the head
+ self.head = next;
+ }
- fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
+ // Unset the prev pointer
+ *entry.prev_stack.get() = ptr::null();
+ }
+ }
}
diff --git a/src/util/bit.rs b/src/util/bit.rs
index e61ac21..392a0e8 100644
--- a/src/util/bit.rs
+++ b/src/util/bit.rs
@@ -1,22 +1,12 @@
use std::fmt;
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, PartialEq)]
pub(crate) struct Pack {
mask: usize,
shift: u32,
}
impl Pack {
- /// Value is packed in the `width` most-significant bits.
- pub(crate) const fn most_significant(width: u32) -> Pack {
- let mask = mask_for(width).reverse_bits();
-
- Pack {
- mask,
- shift: mask.trailing_zeros(),
- }
- }
-
/// Value is packed in the `width` least-significant bits.
pub(crate) const fn least_significant(width: u32) -> Pack {
let mask = mask_for(width);
@@ -32,12 +22,6 @@ impl Pack {
Pack { mask, shift }
}
- /// Mask used to unpack value
- #[cfg(all(test, loom))]
- pub(crate) const fn mask(&self) -> usize {
- self.mask
- }
-
/// Width, in bits, dedicated to storing the value.
pub(crate) const fn width(&self) -> u32 {
pointer_width() - (self.mask >> self.shift).leading_zeros()
@@ -53,6 +37,14 @@ impl Pack {
(base & !self.mask) | (value << self.shift)
}
+ /// Packs the value with `base`, losing any bits of `value` that fit.
+ ///
+ /// If `value` is larger than the max value that can be represented by the
+ /// allotted width, the most significant bits are truncated.
+ pub(crate) fn pack_lossy(&self, value: usize, base: usize) -> usize {
+ self.pack(value & self.max_value(), base)
+ }
+
pub(crate) fn unpack(&self, src: usize) -> usize {
unpack(src, self.mask, self.shift)
}
diff --git a/src/util/intrusive_double_linked_list.rs b/src/util/intrusive_double_linked_list.rs
deleted file mode 100644
index 083fa31..0000000
--- a/src/util/intrusive_double_linked_list.rs
+++ /dev/null
@@ -1,788 +0,0 @@
-//! An intrusive double linked list of data
-
-#![allow(dead_code, unreachable_pub)]
-
-use core::{
- marker::PhantomPinned,
- ops::{Deref, DerefMut},
- ptr::NonNull,
-};
-
-/// A node which carries data of type `T` and is stored in an intrusive list
-#[derive(Debug)]
-pub struct ListNode<T> {
- /// The previous node in the list. `None` if there is no previous node.
- prev: Option<NonNull<ListNode<T>>>,
- /// The next node in the list. `None` if there is no previous node.
- next: Option<NonNull<ListNode<T>>>,
- /// The data which is associated to this list item
- data: T,
- /// Prevents `ListNode`s from being `Unpin`. They may never be moved, since
- /// the list semantics require addresses to be stable.
- _pin: PhantomPinned,
-}
-
-impl<T> ListNode<T> {
- /// Creates a new node with the associated data
- pub fn new(data: T) -> ListNode<T> {
- Self {
- prev: None,
- next: None,
- data,
- _pin: PhantomPinned,
- }
- }
-}
-
-impl<T> Deref for ListNode<T> {
- type Target = T;
-
- fn deref(&self) -> &T {
- &self.data
- }
-}
-
-impl<T> DerefMut for ListNode<T> {
- fn deref_mut(&mut self) -> &mut T {
- &mut self.data
- }
-}
-
-/// An intrusive linked list of nodes, where each node carries associated data
-/// of type `T`.
-#[derive(Debug)]
-pub struct LinkedList<T> {
- head: Option<NonNull<ListNode<T>>>,
- tail: Option<NonNull<ListNode<T>>>,
-}
-
-impl<T> LinkedList<T> {
- /// Creates an empty linked list
- pub fn new() -> Self {
- LinkedList::<T> {
- head: None,
- tail: None,
- }
- }
-
- /// Adds a node at the front of the linked list.
- /// Safety: This function is only safe as long as `node` is guaranteed to
- /// get removed from the list before it gets moved or dropped.
- /// In addition to this `node` may not be added to another other list before
- /// it is removed from the current one.
- pub unsafe fn add_front(&mut self, node: &mut ListNode<T>) {
- node.next = self.head;
- node.prev = None;
- if let Some(mut head) = self.head {
- head.as_mut().prev = Some(node.into())
- };
- self.head = Some(node.into());
- if self.tail.is_none() {
- self.tail = Some(node.into());
- }
- }
-
- /// Inserts a node into the list in a way that the list keeps being sorted.
- /// Safety: This function is only safe as long as `node` is guaranteed to
- /// get removed from the list before it gets moved or dropped.
- /// In addition to this `node` may not be added to another other list before
- /// it is removed from the current one.
- pub unsafe fn add_sorted(&mut self, node: &mut ListNode<T>)
- where
- T: PartialOrd,
- {
- if self.head.is_none() {
- // First node in the list
- self.head = Some(node.into());
- self.tail = Some(node.into());
- return;
- }
-
- let mut prev: Option<NonNull<ListNode<T>>> = None;
- let mut current = self.head;
-
- while let Some(mut current_node) = current {
- if node.data < current_node.as_ref().data {
- // Need to insert before the current node
- current_node.as_mut().prev = Some(node.into());
- match prev {
- Some(mut prev) => {
- prev.as_mut().next = Some(node.into());
- }
- None => {
- // We are inserting at the beginning of the list
- self.head = Some(node.into());
- }
- }
- node.next = current;
- node.prev = prev;
- return;
- }
- prev = current;
- current = current_node.as_ref().next;
- }
-
- // We looped through the whole list and the nodes data is bigger or equal
- // than everything we found up to now.
- // Insert at the end. Since we checked before that the list isn't empty,
- // tail always has a value.
- node.prev = self.tail;
- node.next = None;
- self.tail.as_mut().unwrap().as_mut().next = Some(node.into());
- self.tail = Some(node.into());
- }
-
- /// Returns the first node in the linked list without removing it from the list
- /// The function is only safe as long as valid pointers are stored inside
- /// the linked list.
- /// The returned pointer is only guaranteed to be valid as long as the list
- /// is not mutated
- pub fn peek_first(&self) -> Option<&mut ListNode<T>> {
- // Safety: When the node was inserted it was promised that it is alive
- // until it gets removed from the list.
- // The returned node has a pointer which constrains it to the lifetime
- // of the list. This is ok, since the Node is supposed to outlive
- // its insertion in the list.
- unsafe {
- self.head
- .map(|mut node| &mut *(node.as_mut() as *mut ListNode<T>))
- }
- }
-
- /// Returns the last node in the linked list without removing it from the list
- /// The function is only safe as long as valid pointers are stored inside
- /// the linked list.
- /// The returned pointer is only guaranteed to be valid as long as the list
- /// is not mutated
- pub fn peek_last(&self) -> Option<&mut ListNode<T>> {
- // Safety: When the node was inserted it was promised that it is alive
- // until it gets removed from the list.
- // The returned node has a pointer which constrains it to the lifetime
- // of the list. This is ok, since the Node is supposed to outlive
- // its insertion in the list.
- unsafe {
- self.tail
- .map(|mut node| &mut *(node.as_mut() as *mut ListNode<T>))
- }
- }
-
- /// Removes the first node from the linked list
- pub fn remove_first(&mut self) -> Option<&mut ListNode<T>> {
- #![allow(clippy::debug_assert_with_mut_call)]
-
- // Safety: When the node was inserted it was promised that it is alive
- // until it gets removed from the list
- unsafe {
- let mut head = self.head?;
- self.head = head.as_mut().next;
-
- let first_ref = head.as_mut();
- match first_ref.next {
- None => {
- // This was the only node in the list
- debug_assert_eq!(Some(first_ref.into()), self.tail);
- self.tail = None;
- }
- Some(mut next) => {
- next.as_mut().prev = None;
- }
- }
-
- first_ref.prev = None;
- first_ref.next = None;
- Some(&mut *(first_ref as *mut ListNode<T>))
- }
- }
-
- /// Removes the last node from the linked list and returns it
- pub fn remove_last(&mut self) -> Option<&mut ListNode<T>> {
- #![allow(clippy::debug_assert_with_mut_call)]
-
- // Safety: When the node was inserted it was promised that it is alive
- // until it gets removed from the list
- unsafe {
- let mut tail = self.tail?;
- self.tail = tail.as_mut().prev;
-
- let last_ref = tail.as_mut();
- match last_ref.prev {
- None => {
- // This was the last node in the list
- debug_assert_eq!(Some(last_ref.into()), self.head);
- self.head = None;
- }
- Some(mut prev) => {
- prev.as_mut().next = None;
- }
- }
-
- last_ref.prev = None;
- last_ref.next = None;
- Some(&mut *(last_ref as *mut ListNode<T>))
- }
- }
-
- /// Returns whether the linked list doesn not contain any node
- pub fn is_empty(&self) -> bool {
- if self.head.is_some() {
- return false;
- }
-
- debug_assert!(self.tail.is_none());
- true
- }
-
- /// Removes the given `node` from the linked list.
- /// Returns whether the `node` was removed.
- /// It is also only safe if it is known that the `node` is either part of this
- /// list, or of no list at all. If `node` is part of another list, the
- /// behavior is undefined.
- pub unsafe fn remove(&mut self, node: &mut ListNode<T>) -> bool {
- #![allow(clippy::debug_assert_with_mut_call)]
-
- match node.prev {
- None => {
- // This might be the first node in the list. If it is not, the
- // node is not in the list at all. Since our precondition is that
- // the node must either be in this list or in no list, we check that
- // the node is really in no list.
- if self.head != Some(node.into()) {
- debug_assert!(node.next.is_none());
- return false;
- }
- self.head = node.next;
- }
- Some(mut prev) => {
- debug_assert_eq!(prev.as_ref().next, Some(node.into()));
- prev.as_mut().next = node.next;
- }
- }
-
- match node.next {
- None => {
- // This must be the last node in our list. Otherwise the list
- // is inconsistent.
- debug_assert_eq!(self.tail, Some(node.into()));
- self.tail = node.prev;
- }
- Some(mut next) => {
- debug_assert_eq!(next.as_mut().prev, Some(node.into()));
- next.as_mut().prev = node.prev;
- }
- }
-
- node.next = None;
- node.prev = None;
-
- true
- }
-
- /// Drains the list iby calling a callback on each list node
- ///
- /// The method does not return an iterator since stopping or deferring
- /// draining the list is not permitted. If the method would push nodes to
- /// an iterator we could not guarantee that the nodes do not get utilized
- /// after having been removed from the list anymore.
- pub fn drain<F>(&mut self, mut func: F)
- where
- F: FnMut(&mut ListNode<T>),
- {
- let mut current = self.head;
- self.head = None;
- self.tail = None;
-
- while let Some(mut node) = current {
- // Safety: The nodes have not been removed from the list yet and must
- // therefore contain valid data. The nodes can also not be added to
- // the list again during iteration, since the list is mutably borrowed.
- unsafe {
- let node_ref = node.as_mut();
- current = node_ref.next;
-
- node_ref.next = None;
- node_ref.prev = None;
-
- // Note: We do not reset the pointers from the next element in the
- // list to the current one since we will iterate over the whole
- // list anyway, and therefore clean up all pointers.
-
- func(node_ref);
- }
- }
- }
-
- /// Drains the list in reverse order by calling a callback on each list node
- ///
- /// The method does not return an iterator since stopping or deferring
- /// draining the list is not permitted. If the method would push nodes to
- /// an iterator we could not guarantee that the nodes do not get utilized
- /// after having been removed from the list anymore.
- pub fn reverse_drain<F>(&mut self, mut func: F)
- where
- F: FnMut(&mut ListNode<T>),
- {
- let mut current = self.tail;
- self.head = None;
- self.tail = None;
-
- while let Some(mut node) = current {
- // Safety: The nodes have not been removed from the list yet and must
- // therefore contain valid data. The nodes can also not be added to
- // the list again during iteration, since the list is mutably borrowed.
- unsafe {
- let node_ref = node.as_mut();
- current = node_ref.prev;
-
- node_ref.next = None;
- node_ref.prev = None;
-
- // Note: We do not reset the pointers from the next element in the
- // list to the current one since we will iterate over the whole
- // list anyway, and therefore clean up all pointers.
-
- func(node_ref);
- }
- }
- }
-}
-
-#[cfg(all(test, feature = "std"))] // Tests make use of Vec at the moment
-mod tests {
- use super::*;
-
- fn collect_list<T: Copy>(mut list: LinkedList<T>) -> Vec<T> {
- let mut result = Vec::new();
- list.drain(|node| {
- result.push(**node);
- });
- result
- }
-
- fn collect_reverse_list<T: Copy>(mut list: LinkedList<T>) -> Vec<T> {
- let mut result = Vec::new();
- list.reverse_drain(|node| {
- result.push(**node);
- });
- result
- }
-
- unsafe fn add_nodes(list: &mut LinkedList<i32>, nodes: &mut [&mut ListNode<i32>]) {
- for node in nodes.iter_mut() {
- list.add_front(node);
- }
- }
-
- unsafe fn assert_clean<T>(node: &mut ListNode<T>) {
- assert!(node.next.is_none());
- assert!(node.prev.is_none());
- }
-
- #[test]
- fn insert_and_iterate() {
- unsafe {
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
-
- let mut setup = |list: &mut LinkedList<i32>| {
- assert_eq!(true, list.is_empty());
- list.add_front(&mut c);
- assert_eq!(31, **list.peek_first().unwrap());
- assert_eq!(false, list.is_empty());
- list.add_front(&mut b);
- assert_eq!(7, **list.peek_first().unwrap());
- list.add_front(&mut a);
- assert_eq!(5, **list.peek_first().unwrap());
- };
-
- let mut list = LinkedList::new();
- setup(&mut list);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7, 31].to_vec(), items);
-
- let mut list = LinkedList::new();
- setup(&mut list);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([31, 7, 5].to_vec(), items);
- }
- }
-
- #[test]
- fn add_sorted() {
- unsafe {
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
- let mut d = ListNode::new(99);
-
- let mut list = LinkedList::new();
- list.add_sorted(&mut a);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5].to_vec(), items);
-
- let mut list = LinkedList::new();
- list.add_sorted(&mut a);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]);
- list.add_sorted(&mut a);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7, 31, 99].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]);
- list.add_sorted(&mut a);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([99, 31, 7, 5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]);
- list.add_sorted(&mut b);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7, 31, 99].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]);
- list.add_sorted(&mut b);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([99, 31, 7, 5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]);
- list.add_sorted(&mut c);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7, 31, 99].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]);
- list.add_sorted(&mut c);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([99, 31, 7, 5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- list.add_sorted(&mut d);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7, 31, 99].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- list.add_sorted(&mut d);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([99, 31, 7, 5].to_vec(), items);
- }
- }
-
- #[test]
- fn drain_and_collect() {
- unsafe {
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
-
- let taken_items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7, 31].to_vec(), taken_items);
- }
- }
-
- #[test]
- fn peek_last() {
- unsafe {
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
-
- let last = list.peek_last();
- assert_eq!(31, **last.unwrap());
- list.remove_last();
-
- let last = list.peek_last();
- assert_eq!(7, **last.unwrap());
- list.remove_last();
-
- let last = list.peek_last();
- assert_eq!(5, **last.unwrap());
- list.remove_last();
-
- let last = list.peek_last();
- assert!(last.is_none());
- }
- }
-
- #[test]
- fn remove_first() {
- unsafe {
- // We iterate forward and backwards through the manipulated lists
- // to make sure pointers in both directions are still ok.
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- let removed = list.remove_first().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_list(list);
- assert_eq!([7, 31].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- let removed = list.remove_first().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([31, 7].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- let removed = list.remove_first().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_list(list);
- assert_eq!([7].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- let removed = list.remove_first().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([7].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut a]);
- let removed = list.remove_first().unwrap();
- assert_clean(removed);
- assert!(list.is_empty());
- let items: Vec<i32> = collect_list(list);
- assert!(items.is_empty());
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut a]);
- let removed = list.remove_first().unwrap();
- assert_clean(removed);
- assert!(list.is_empty());
- let items: Vec<i32> = collect_reverse_list(list);
- assert!(items.is_empty());
- }
- }
-
- #[test]
- fn remove_last() {
- unsafe {
- // We iterate forward and backwards through the manipulated lists
- // to make sure pointers in both directions are still ok.
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- let removed = list.remove_last().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- let removed = list.remove_last().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([7, 5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- let removed = list.remove_last().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- let removed = list.remove_last().unwrap();
- assert_clean(removed);
- assert!(!list.is_empty());
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut a]);
- let removed = list.remove_last().unwrap();
- assert_clean(removed);
- assert!(list.is_empty());
- let items: Vec<i32> = collect_list(list);
- assert!(items.is_empty());
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut a]);
- let removed = list.remove_last().unwrap();
- assert_clean(removed);
- assert!(list.is_empty());
- let items: Vec<i32> = collect_reverse_list(list);
- assert!(items.is_empty());
- }
- }
-
- #[test]
- fn remove_by_address() {
- unsafe {
- let mut a = ListNode::new(5);
- let mut b = ListNode::new(7);
- let mut c = ListNode::new(31);
-
- {
- // Remove first
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- assert_eq!(true, list.remove(&mut a));
- assert_clean((&mut a).into());
- // a should be no longer there and can't be removed twice
- assert_eq!(false, list.remove(&mut a));
- assert_eq!(Some((&mut b).into()), list.head);
- assert_eq!(Some((&mut c).into()), b.next);
- assert_eq!(Some((&mut b).into()), c.prev);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([7, 31].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- assert_eq!(true, list.remove(&mut a));
- assert_clean((&mut a).into());
- // a should be no longer there and can't be removed twice
- assert_eq!(false, list.remove(&mut a));
- assert_eq!(Some((&mut c).into()), b.next);
- assert_eq!(Some((&mut b).into()), c.prev);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([31, 7].to_vec(), items);
- }
-
- {
- // Remove middle
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- assert_eq!(true, list.remove(&mut b));
- assert_clean((&mut b).into());
- assert_eq!(Some((&mut c).into()), a.next);
- assert_eq!(Some((&mut a).into()), c.prev);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 31].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- assert_eq!(true, list.remove(&mut b));
- assert_clean((&mut b).into());
- assert_eq!(Some((&mut c).into()), a.next);
- assert_eq!(Some((&mut a).into()), c.prev);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([31, 5].to_vec(), items);
- }
-
- {
- // Remove last
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- assert_eq!(true, list.remove(&mut c));
- assert_clean((&mut c).into());
- assert!(b.next.is_none());
- assert_eq!(Some((&mut b).into()), list.tail);
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5, 7].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]);
- assert_eq!(true, list.remove(&mut c));
- assert_clean((&mut c).into());
- assert!(b.next.is_none());
- assert_eq!(Some((&mut b).into()), list.tail);
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([7, 5].to_vec(), items);
- }
-
- {
- // Remove first of two
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- assert_eq!(true, list.remove(&mut a));
- assert_clean((&mut a).into());
- // a should be no longer there and can't be removed twice
- assert_eq!(false, list.remove(&mut a));
- assert_eq!(Some((&mut b).into()), list.head);
- assert_eq!(Some((&mut b).into()), list.tail);
- assert!(b.next.is_none());
- assert!(b.prev.is_none());
- let items: Vec<i32> = collect_list(list);
- assert_eq!([7].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- assert_eq!(true, list.remove(&mut a));
- assert_clean((&mut a).into());
- // a should be no longer there and can't be removed twice
- assert_eq!(false, list.remove(&mut a));
- assert_eq!(Some((&mut b).into()), list.head);
- assert_eq!(Some((&mut b).into()), list.tail);
- assert!(b.next.is_none());
- assert!(b.prev.is_none());
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([7].to_vec(), items);
- }
-
- {
- // Remove last of two
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- assert_eq!(true, list.remove(&mut b));
- assert_clean((&mut b).into());
- assert_eq!(Some((&mut a).into()), list.head);
- assert_eq!(Some((&mut a).into()), list.tail);
- assert!(a.next.is_none());
- assert!(a.prev.is_none());
- let items: Vec<i32> = collect_list(list);
- assert_eq!([5].to_vec(), items);
-
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut b, &mut a]);
- assert_eq!(true, list.remove(&mut b));
- assert_clean((&mut b).into());
- assert_eq!(Some((&mut a).into()), list.head);
- assert_eq!(Some((&mut a).into()), list.tail);
- assert!(a.next.is_none());
- assert!(a.prev.is_none());
- let items: Vec<i32> = collect_reverse_list(list);
- assert_eq!([5].to_vec(), items);
- }
-
- {
- // Remove last item
- let mut list = LinkedList::new();
- add_nodes(&mut list, &mut [&mut a]);
- assert_eq!(true, list.remove(&mut a));
- assert_clean((&mut a).into());
- assert!(list.head.is_none());
- assert!(list.tail.is_none());
- let items: Vec<i32> = collect_list(list);
- assert!(items.is_empty());
- }
-
- {
- // Remove missing
- let mut list = LinkedList::new();
- list.add_front(&mut b);
- list.add_front(&mut a);
- assert_eq!(false, list.remove(&mut c));
- }
- }
- }
-}
diff --git a/src/util/linked_list.rs b/src/util/linked_list.rs
index aa3ce77..4681276 100644
--- a/src/util/linked_list.rs
+++ b/src/util/linked_list.rs
@@ -1,3 +1,5 @@
+#![cfg_attr(not(feature = "full"), allow(dead_code))]
+
//! An intrusive double linked list of data
//!
//! The data structure supports tracking pinned nodes. Most of the data
@@ -5,6 +7,7 @@
//! specified node is actually contained by the list.
use core::fmt;
+use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::ptr::NonNull;
@@ -12,16 +15,19 @@ use core::ptr::NonNull;
///
/// Currently, the list is not emptied on drop. It is the caller's
/// responsibility to ensure the list is empty before dropping it.
-pub(crate) struct LinkedList<T: Link> {
+pub(crate) struct LinkedList<L, T> {
/// Linked list head
- head: Option<NonNull<T::Target>>,
+ head: Option<NonNull<T>>,
/// Linked list tail
- tail: Option<NonNull<T::Target>>,
+ tail: Option<NonNull<T>>,
+
+ /// Node type marker.
+ _marker: PhantomData<*const L>,
}
-unsafe impl<T: Link> Send for LinkedList<T> where T::Target: Send {}
-unsafe impl<T: Link> Sync for LinkedList<T> where T::Target: Sync {}
+unsafe impl<L: Link> Send for LinkedList<L, L::Target> where L::Target: Send {}
+unsafe impl<L: Link> Sync for LinkedList<L, L::Target> where L::Target: Sync {}
/// Defines how a type is tracked within a linked list.
///
@@ -66,27 +72,30 @@ unsafe impl<T: Sync> Sync for Pointers<T> {}
// ===== impl LinkedList =====
-impl<T: Link> LinkedList<T> {
- /// Creates an empty linked list
- pub(crate) fn new() -> LinkedList<T> {
+impl<L, T> LinkedList<L, T> {
+ /// Creates an empty linked list.
+ pub(crate) const fn new() -> LinkedList<L, T> {
LinkedList {
head: None,
tail: None,
+ _marker: PhantomData,
}
}
+}
+impl<L: Link> LinkedList<L, L::Target> {
/// Adds an element first in the list.
- pub(crate) fn push_front(&mut self, val: T::Handle) {
+ pub(crate) fn push_front(&mut self, val: L::Handle) {
// The value should not be dropped, it is being inserted into the list
let val = ManuallyDrop::new(val);
- let ptr = T::as_raw(&*val);
+ let ptr = L::as_raw(&*val);
assert_ne!(self.head, Some(ptr));
unsafe {
- T::pointers(ptr).as_mut().next = self.head;
- T::pointers(ptr).as_mut().prev = None;
+ L::pointers(ptr).as_mut().next = self.head;
+ L::pointers(ptr).as_mut().prev = None;
if let Some(head) = self.head {
- T::pointers(head).as_mut().prev = Some(ptr);
+ L::pointers(head).as_mut().prev = Some(ptr);
}
self.head = Some(ptr);
@@ -99,21 +108,21 @@ impl<T: Link> LinkedList<T> {
/// Removes the last element from a list and returns it, or None if it is
/// empty.
- pub(crate) fn pop_back(&mut self) -> Option<T::Handle> {
+ pub(crate) fn pop_back(&mut self) -> Option<L::Handle> {
unsafe {
let last = self.tail?;
- self.tail = T::pointers(last).as_ref().prev;
+ self.tail = L::pointers(last).as_ref().prev;
- if let Some(prev) = T::pointers(last).as_ref().prev {
- T::pointers(prev).as_mut().next = None;
+ if let Some(prev) = L::pointers(last).as_ref().prev {
+ L::pointers(prev).as_mut().next = None;
} else {
self.head = None
}
- T::pointers(last).as_mut().prev = None;
- T::pointers(last).as_mut().next = None;
+ L::pointers(last).as_mut().prev = None;
+ L::pointers(last).as_mut().next = None;
- Some(T::from_raw(last))
+ Some(L::from_raw(last))
}
}
@@ -133,38 +142,38 @@ impl<T: Link> LinkedList<T> {
///
/// The caller **must** ensure that `node` is currently contained by
/// `self` or not contained by any other list.
- pub(crate) unsafe fn remove(&mut self, node: NonNull<T::Target>) -> Option<T::Handle> {
- if let Some(prev) = T::pointers(node).as_ref().prev {
- debug_assert_eq!(T::pointers(prev).as_ref().next, Some(node));
- T::pointers(prev).as_mut().next = T::pointers(node).as_ref().next;
+ pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
+ if let Some(prev) = L::pointers(node).as_ref().prev {
+ debug_assert_eq!(L::pointers(prev).as_ref().next, Some(node));
+ L::pointers(prev).as_mut().next = L::pointers(node).as_ref().next;
} else {
if self.head != Some(node) {
return None;
}
- self.head = T::pointers(node).as_ref().next;
+ self.head = L::pointers(node).as_ref().next;
}
- if let Some(next) = T::pointers(node).as_ref().next {
- debug_assert_eq!(T::pointers(next).as_ref().prev, Some(node));
- T::pointers(next).as_mut().prev = T::pointers(node).as_ref().prev;
+ if let Some(next) = L::pointers(node).as_ref().next {
+ debug_assert_eq!(L::pointers(next).as_ref().prev, Some(node));
+ L::pointers(next).as_mut().prev = L::pointers(node).as_ref().prev;
} else {
// This might be the last item in the list
if self.tail != Some(node) {
return None;
}
- self.tail = T::pointers(node).as_ref().prev;
+ self.tail = L::pointers(node).as_ref().prev;
}
- T::pointers(node).as_mut().next = None;
- T::pointers(node).as_mut().prev = None;
+ L::pointers(node).as_mut().next = None;
+ L::pointers(node).as_mut().prev = None;
- Some(T::from_raw(node))
+ Some(L::from_raw(node))
}
}
-impl<T: Link> fmt::Debug for LinkedList<T> {
+impl<L: Link> fmt::Debug for LinkedList<L, L::Target> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LinkedList")
.field("head", &self.head)
@@ -173,27 +182,35 @@ impl<T: Link> fmt::Debug for LinkedList<T> {
}
}
-cfg_sync! {
- impl<T: Link> LinkedList<T> {
- pub(crate) fn last(&self) -> Option<&T::Target> {
- let tail = self.tail.as_ref()?;
- unsafe {
- Some(&*tail.as_ptr())
- }
- }
+#[cfg(any(
+ feature = "fs",
+ all(unix, feature = "process"),
+ feature = "signal",
+ feature = "sync",
+))]
+impl<L: Link> LinkedList<L, L::Target> {
+ pub(crate) fn last(&self) -> Option<&L::Target> {
+ let tail = self.tail.as_ref()?;
+ unsafe { Some(&*tail.as_ptr()) }
+ }
+}
+
+impl<L: Link> Default for LinkedList<L, L::Target> {
+ fn default() -> Self {
+ Self::new()
}
}
// ===== impl Iter =====
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
pub(crate) struct Iter<'a, T: Link> {
curr: Option<NonNull<T::Target>>,
_p: core::marker::PhantomData<&'a T>,
}
- impl<T: Link> LinkedList<T> {
- pub(crate) fn iter(&self) -> Iter<'_, T> {
+ impl<L: Link> LinkedList<L, L::Target> {
+ pub(crate) fn iter(&self) -> Iter<'_, L> {
Iter {
curr: self.head,
_p: core::marker::PhantomData,
@@ -215,6 +232,52 @@ cfg_rt_threaded! {
}
}
+// ===== impl DrainFilter =====
+
+cfg_io_readiness! {
+ pub(crate) struct DrainFilter<'a, T: Link, F> {
+ list: &'a mut LinkedList<T, T::Target>,
+ filter: F,
+ curr: Option<NonNull<T::Target>>,
+ }
+
+ impl<T: Link> LinkedList<T, T::Target> {
+ pub(crate) fn drain_filter<F>(&mut self, filter: F) -> DrainFilter<'_, T, F>
+ where
+ F: FnMut(&mut T::Target) -> bool,
+ {
+ let curr = self.head;
+ DrainFilter {
+ curr,
+ filter,
+ list: self,
+ }
+ }
+ }
+
+ impl<'a, T, F> Iterator for DrainFilter<'a, T, F>
+ where
+ T: Link,
+ F: FnMut(&mut T::Target) -> bool,
+ {
+ type Item = T::Handle;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ while let Some(curr) = self.curr {
+ // safety: the pointer references data contained by the list
+ self.curr = unsafe { T::pointers(curr).as_ref() }.next;
+
+ // safety: the value is still owned by the linked list.
+ if (self.filter)(unsafe { &mut *curr.as_ptr() }) {
+ return unsafe { self.list.remove(curr) };
+ }
+ }
+
+ None
+ }
+ }
+}
+
// ===== impl Pointers =====
impl<T> Pointers<T> {
@@ -277,7 +340,7 @@ mod tests {
r.as_ref().get_ref().into()
}
- fn collect_list(list: &mut LinkedList<&'_ Entry>) -> Vec<i32> {
+ fn collect_list(list: &mut LinkedList<&'_ Entry, <&'_ Entry as Link>::Target>) -> Vec<i32> {
let mut ret = vec![];
while let Some(entry) = list.pop_back() {
@@ -287,7 +350,10 @@ mod tests {
ret
}
- fn push_all<'a>(list: &mut LinkedList<&'a Entry>, entries: &[Pin<&'a Entry>]) {
+ fn push_all<'a>(
+ list: &mut LinkedList<&'a Entry, <&'_ Entry as Link>::Target>,
+ entries: &[Pin<&'a Entry>],
+ ) {
for entry in entries.iter() {
list.push_front(*entry);
}
@@ -308,6 +374,11 @@ mod tests {
}
#[test]
+ fn const_new() {
+ const _: LinkedList<&Entry, <&Entry as Link>::Target> = LinkedList::new();
+ }
+
+ #[test]
fn push_and_drain() {
let a = entry(5);
let b = entry(7);
@@ -332,7 +403,7 @@ mod tests {
let a = entry(5);
let b = entry(7);
- let mut list = LinkedList::<&Entry>::new();
+ let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new();
list.push_front(a.as_ref());
@@ -489,7 +560,7 @@ mod tests {
unsafe {
// Remove missing
- let mut list = LinkedList::<&Entry>::new();
+ let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new();
list.push_front(b.as_ref());
list.push_front(a.as_ref());
@@ -503,7 +574,7 @@ mod tests {
let a = entry(5);
let b = entry(7);
- let mut list = LinkedList::<&Entry>::new();
+ let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new();
assert_eq!(0, list.iter().count());
@@ -543,7 +614,7 @@ mod tests {
})
.collect::<Vec<_>>();
- let mut ll = LinkedList::<&Entry>::new();
+ let mut ll = LinkedList::<&Entry, <&Entry as Link>::Target>::new();
let mut reference = VecDeque::new();
let entries: Vec<_> = (0..ops.len()).map(|i| entry(i as i32)).collect();
diff --git a/src/util/mod.rs b/src/util/mod.rs
index 6dda08c..b2043dd 100644
--- a/src/util/mod.rs
+++ b/src/util/mod.rs
@@ -3,16 +3,26 @@ cfg_io_driver! {
pub(crate) mod slab;
}
-#[cfg(any(feature = "sync", feature = "rt-core"))]
+#[cfg(any(
+ feature = "fs",
+ feature = "net",
+ feature = "process",
+ feature = "rt",
+ feature = "sync",
+ feature = "signal",
+))]
pub(crate) mod linked_list;
-#[cfg(any(feature = "rt-threaded", feature = "macros", feature = "stream"))]
+#[cfg(any(feature = "rt-multi-thread", feature = "macros", feature = "stream"))]
mod rand;
-mod wake;
-pub(crate) use wake::{waker_ref, Wake};
+cfg_rt! {
+ mod wake;
+ pub(crate) use wake::WakerRef;
+ pub(crate) use wake::{waker_ref, Wake};
+}
-cfg_rt_threaded! {
+cfg_rt_multi_thread! {
pub(crate) use rand::FastRand;
mod try_lock;
@@ -24,5 +34,3 @@ pub(crate) mod trace;
#[cfg(any(feature = "macros", feature = "stream"))]
#[cfg_attr(not(feature = "macros"), allow(unreachable_pub))]
pub use rand::thread_rng_n;
-
-pub(crate) mod intrusive_double_linked_list;
diff --git a/src/util/slab.rs b/src/util/slab.rs
new file mode 100644
index 0000000..efc72e1
--- /dev/null
+++ b/src/util/slab.rs
@@ -0,0 +1,841 @@
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
+use crate::loom::cell::UnsafeCell;
+use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
+use crate::loom::sync::{Arc, Mutex};
+use crate::util::bit;
+use std::fmt;
+use std::mem;
+use std::ops;
+use std::ptr;
+use std::sync::atomic::Ordering::Relaxed;
+
+/// Amortized allocation for homogeneous data types.
+///
+/// The slab pre-allocates chunks of memory to store values. It uses a similar
+/// growing strategy as `Vec`. When new capacity is needed, the slab grows by
+/// 2x.
+///
+/// # Pages
+///
+/// Unlike `Vec`, growing does not require moving existing elements. Instead of
+/// being a continuous chunk of memory for all elements, `Slab` is an array of
+/// arrays. The top-level array is an array of pages. Each page is 2x bigger
+/// than the previous one. When the slab grows, a new page is allocated.
+///
+/// Pages are lazily initialized.
+///
+/// # Allocating
+///
+/// When allocating an object, first previously used slots are reused. If no
+/// previously used slot is available, a new slot is initialized in an existing
+/// page. If all pages are full, then a new page is allocated.
+///
+/// When an allocated object is released, it is pushed into it's page's free
+/// list. Allocating scans all pages for a free slot.
+///
+/// # Indexing
+///
+/// The slab is able to index values using an address. Even when the indexed
+/// object has been released, it is still safe to index. This is a key ability
+/// for using the slab with the I/O driver. Addresses are registered with the
+/// OS's selector and I/O resources can be released without synchronizing with
+/// the OS.
+///
+/// # Compaction
+///
+/// `Slab::compact` will release pages that have been allocated but are no
+/// longer used. This is done by scanning the pages and finding pages with no
+/// allocated objects. These pages are then freed.
+///
+/// # Synchronization
+///
+/// The `Slab` structure is able to provide (mostly) unsynchronized reads to
+/// values stored in the slab. Insertions and removals are synchronized. Reading
+/// objects via `Ref` is fully unsynchronized. Indexing objects uses amortized
+/// synchronization.
+///
+pub(crate) struct Slab<T> {
+ /// Array of pages. Each page is synchronized.
+ pages: [Arc<Page<T>>; NUM_PAGES],
+
+ /// Caches the array pointer & number of initialized slots.
+ cached: [CachedPage<T>; NUM_PAGES],
+}
+
+/// Allocate values in the associated slab.
+pub(crate) struct Allocator<T> {
+ /// Pages in the slab. The first page has a capacity of 16 elements. Each
+ /// following page has double the capacity of the previous page.
+ ///
+ /// Each returned `Ref` holds a reference count to this `Arc`.
+ pages: [Arc<Page<T>>; NUM_PAGES],
+}
+
+/// References a slot in the slab. Indexing a slot using an `Address` is memory
+/// safe even if the slot has been released or the page has been deallocated.
+/// However, it is not guaranteed that the slot has not been reused and is now
+/// represents a different value.
+///
+/// The I/O driver uses a counter to track the slot's generation. Once accessing
+/// the slot, the generations are compared. If they match, the value matches the
+/// address.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub(crate) struct Address(usize);
+
+/// An entry in the slab.
+pub(crate) trait Entry: Default {
+ /// Reset the entry's value and track the generation.
+ fn reset(&self);
+}
+
+/// A reference to a value stored in the slab
+pub(crate) struct Ref<T> {
+ value: *const Value<T>,
+}
+
+/// Maximum number of pages a slab can contain.
+const NUM_PAGES: usize = 19;
+
+/// Minimum number of slots a page can contain.
+const PAGE_INITIAL_SIZE: usize = 32;
+const PAGE_INDEX_SHIFT: u32 = PAGE_INITIAL_SIZE.trailing_zeros() + 1;
+
+/// A page in the slab
+struct Page<T> {
+ /// Slots
+ slots: Mutex<Slots<T>>,
+
+ // Number of slots currently being used. This is not guaranteed to be up to
+ // date and should only be used as a hint.
+ used: AtomicUsize,
+
+ // Set to `true` when the page has been allocated.
+ allocated: AtomicBool,
+
+ // The number of slots the page can hold.
+ len: usize,
+
+ // Length of all previous pages combined
+ prev_len: usize,
+}
+
+struct CachedPage<T> {
+ /// Pointer to the page's slots.
+ slots: *const Slot<T>,
+
+ /// Number of initialized slots.
+ init: usize,
+}
+
+/// Page state
+struct Slots<T> {
+ /// Slots
+ slots: Vec<Slot<T>>,
+
+ head: usize,
+
+ /// Number of slots currently in use.
+ used: usize,
+}
+
+unsafe impl<T: Sync> Sync for Page<T> {}
+unsafe impl<T: Sync> Send for Page<T> {}
+unsafe impl<T: Sync> Sync for CachedPage<T> {}
+unsafe impl<T: Sync> Send for CachedPage<T> {}
+unsafe impl<T: Sync> Sync for Ref<T> {}
+unsafe impl<T: Sync> Send for Ref<T> {}
+
+/// A slot in the slab. Contains slot-specific metadata.
+///
+/// `#[repr(C)]` guarantees that the struct starts w/ `value`. We use pointer
+/// math to map a value pointer to an index in the page.
+#[repr(C)]
+struct Slot<T> {
+ /// Pointed to by `Ref`.
+ value: UnsafeCell<Value<T>>,
+
+ /// Next entry in the free list.
+ next: u32,
+}
+
+/// Value paired with a reference to the page
+struct Value<T> {
+ /// Value stored in the value
+ value: T,
+
+ /// Pointer to the page containing the slot.
+ ///
+ /// A raw pointer is used as this creates a ref cycle.
+ page: *const Page<T>,
+}
+
+impl<T> Slab<T> {
+ /// Create a new, empty, slab
+ pub(crate) fn new() -> Slab<T> {
+ // Initializing arrays is a bit annoying. Instead of manually writing
+ // out an array and every single entry, `Default::default()` is used to
+ // initialize the array, then the array is iterated and each value is
+ // initialized.
+ let mut slab = Slab {
+ pages: Default::default(),
+ cached: Default::default(),
+ };
+
+ let mut len = PAGE_INITIAL_SIZE;
+ let mut prev_len: usize = 0;
+
+ for page in &mut slab.pages {
+ let page = Arc::get_mut(page).unwrap();
+ page.len = len;
+ page.prev_len = prev_len;
+ len *= 2;
+ prev_len += page.len;
+
+ // Ensure we don't exceed the max address space.
+ debug_assert!(
+ page.len - 1 + page.prev_len < (1 << 24),
+ "max = {:b}",
+ page.len - 1 + page.prev_len
+ );
+ }
+
+ slab
+ }
+
+ /// Returns a new `Allocator`.
+ ///
+ /// The `Allocator` supports concurrent allocation of objects.
+ pub(crate) fn allocator(&self) -> Allocator<T> {
+ Allocator {
+ pages: self.pages.clone(),
+ }
+ }
+
+ /// Returns a reference to the value stored at the given address.
+ ///
+ /// `&mut self` is used as the call may update internal cached state.
+ pub(crate) fn get(&mut self, addr: Address) -> Option<&T> {
+ let page_idx = addr.page();
+ let slot_idx = self.pages[page_idx].slot(addr);
+
+ // If the address references a slot that was last seen as uninitialized,
+ // the `CachedPage` is updated. This requires acquiring the page lock
+ // and updating the slot pointer and initialized offset.
+ if self.cached[page_idx].init <= slot_idx {
+ self.cached[page_idx].refresh(&self.pages[page_idx]);
+ }
+
+ // If the address **still** references an uninitialized slot, then the
+ // address is invalid and `None` is returned.
+ if self.cached[page_idx].init <= slot_idx {
+ return None;
+ }
+
+ // Get a reference to the value. The lifetime of the returned reference
+ // is bound to `&self`. The only way to invalidate the underlying memory
+ // is to call `compact()`. The lifetimes prevent calling `compact()`
+ // while references to values are outstanding.
+ //
+ // The referenced data is never mutated. Only `&self` references are
+ // used and the data is `Sync`.
+ Some(self.cached[page_idx].get(slot_idx))
+ }
+
+ /// Calls the given function with a reference to each slot in the slab. The
+ /// slot may not be in-use.
+ ///
+ /// This is used by the I/O driver during the shutdown process to notify
+ /// each pending task.
+ pub(crate) fn for_each(&mut self, mut f: impl FnMut(&T)) {
+ for page_idx in 0..self.pages.len() {
+ // It is required to avoid holding the lock when calling the
+ // provided function. The function may attempt to acquire the lock
+ // itself. If we hold the lock here while calling `f`, a deadlock
+ // situation is possible.
+ //
+ // Instead of iterating the slots directly in `page`, which would
+ // require holding the lock, the cache is updated and the slots are
+ // iterated from the cache.
+ self.cached[page_idx].refresh(&self.pages[page_idx]);
+
+ for slot_idx in 0..self.cached[page_idx].init {
+ f(self.cached[page_idx].get(slot_idx));
+ }
+ }
+ }
+
+ // Release memory back to the allocator.
+ //
+ // If pages are empty, the underlying memory is released back to the
+ // allocator.
+ pub(crate) fn compact(&mut self) {
+ // Iterate each page except the very first one. The very first page is
+ // never freed.
+ for (idx, page) in self.pages.iter().enumerate().skip(1) {
+ if page.used.load(Relaxed) != 0 || !page.allocated.load(Relaxed) {
+ // If the page has slots in use or the memory has not been
+ // allocated then it cannot be compacted.
+ continue;
+ }
+
+ let mut slots = match page.slots.try_lock() {
+ Some(slots) => slots,
+ // If the lock cannot be acquired due to being held by another
+ // thread, don't try to compact the page.
+ _ => continue,
+ };
+
+ if slots.used > 0 || slots.slots.capacity() == 0 {
+ // The page is in use or it has not yet been allocated. Either
+ // way, there is no more work to do.
+ continue;
+ }
+
+ page.allocated.store(false, Relaxed);
+
+ // Remove the slots vector from the page. This is done so that the
+ // freeing process is done outside of the lock's critical section.
+ let vec = mem::replace(&mut slots.slots, vec![]);
+ slots.head = 0;
+
+ // Drop the lock so we can drop the vector outside the lock below.
+ drop(slots);
+
+ debug_assert!(
+ self.cached[idx].slots.is_null() || self.cached[idx].slots == vec.as_ptr(),
+ "cached = {:?}; actual = {:?}",
+ self.cached[idx].slots,
+ vec.as_ptr(),
+ );
+
+ // Clear cache
+ self.cached[idx].slots = ptr::null();
+ self.cached[idx].init = 0;
+
+ drop(vec);
+ }
+ }
+}
+
+impl<T> fmt::Debug for Slab<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ debug(fmt, "Slab", &self.pages[..])
+ }
+}
+
+impl<T: Entry> Allocator<T> {
+ /// Allocate a new entry and return a handle to the entry.
+ ///
+ /// Scans pages from smallest to biggest, stopping when a slot is found.
+ /// Pages are allocated if necessary.
+ ///
+ /// Returns `None` if the slab is full.
+ pub(crate) fn allocate(&self) -> Option<(Address, Ref<T>)> {
+ // Find the first available slot.
+ for page in &self.pages[..] {
+ if let Some((addr, val)) = Page::allocate(page) {
+ return Some((addr, val));
+ }
+ }
+
+ None
+ }
+}
+
+impl<T> fmt::Debug for Allocator<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ debug(fmt, "slab::Allocator", &self.pages[..])
+ }
+}
+
+impl<T> ops::Deref for Ref<T> {
+ type Target = T;
+
+ fn deref(&self) -> &T {
+ // Safety: `&mut` is never handed out to the underlying value. The page
+ // is not freed until all `Ref` values are dropped.
+ unsafe { &(*self.value).value }
+ }
+}
+
+impl<T> Drop for Ref<T> {
+ fn drop(&mut self) {
+ // Safety: `&mut` is never handed out to the underlying value. The page
+ // is not freed until all `Ref` values are dropped.
+ let _ = unsafe { (*self.value).release() };
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for Ref<T> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ (**self).fmt(fmt)
+ }
+}
+
+impl<T: Entry> Page<T> {
+ // Allocates an object, returns the ref and address.
+ //
+ // `self: &Arc<Page<T>>` is avoided here as this would not work with the
+ // loom `Arc`.
+ fn allocate(me: &Arc<Page<T>>) -> Option<(Address, Ref<T>)> {
+ // Before acquiring the lock, use the `used` hint.
+ if me.used.load(Relaxed) == me.len {
+ return None;
+ }
+
+ // Allocating objects requires synchronization
+ let mut locked = me.slots.lock();
+
+ if locked.head < locked.slots.len() {
+ // Re-use an already initialized slot.
+ //
+ // Help out the borrow checker
+ let locked = &mut *locked;
+
+ // Get the index of the slot at the head of the free stack. This is
+ // the slot that will be reused.
+ let idx = locked.head;
+ let slot = &locked.slots[idx];
+
+ // Update the free stack head to point to the next slot.
+ locked.head = slot.next as usize;
+
+ // Increment the number of used slots
+ locked.used += 1;
+ me.used.store(locked.used, Relaxed);
+
+ // Reset the slot
+ slot.value.with(|ptr| unsafe { (*ptr).value.reset() });
+
+ // Return a reference to the slot
+ Some((me.addr(idx), slot.gen_ref(me)))
+ } else if me.len == locked.slots.len() {
+ // The page is full
+ None
+ } else {
+ // No initialized slots are available, but the page has more
+ // capacity. Initialize a new slot.
+ let idx = locked.slots.len();
+
+ if idx == 0 {
+ // The page has not yet been allocated. Allocate the storage for
+ // all page slots.
+ locked.slots.reserve_exact(me.len);
+ }
+
+ // Initialize a new slot
+ locked.slots.push(Slot {
+ value: UnsafeCell::new(Value {
+ value: Default::default(),
+ page: &**me as *const _,
+ }),
+ next: 0,
+ });
+
+ // Increment the head to indicate the free stack is empty
+ locked.head += 1;
+
+ // Increment the number of used slots
+ locked.used += 1;
+ me.used.store(locked.used, Relaxed);
+ me.allocated.store(true, Relaxed);
+
+ debug_assert_eq!(locked.slots.len(), locked.head);
+
+ Some((me.addr(idx), locked.slots[idx].gen_ref(me)))
+ }
+ }
+}
+
+impl<T> Page<T> {
+ /// Returns the slot index within the current page referenced by the given
+ /// address.
+ fn slot(&self, addr: Address) -> usize {
+ addr.0 - self.prev_len
+ }
+
+ /// Returns the address for the given slot
+ fn addr(&self, slot: usize) -> Address {
+ Address(slot + self.prev_len)
+ }
+}
+
+impl<T> Default for Page<T> {
+ fn default() -> Page<T> {
+ Page {
+ used: AtomicUsize::new(0),
+ allocated: AtomicBool::new(false),
+ slots: Mutex::new(Slots {
+ slots: Vec::new(),
+ head: 0,
+ used: 0,
+ }),
+ len: 0,
+ prev_len: 0,
+ }
+ }
+}
+
+impl<T> Page<T> {
+ /// Release a slot into the page's free list
+ fn release(&self, value: *const Value<T>) {
+ let mut locked = self.slots.lock();
+
+ let idx = locked.index_for(value);
+ locked.slots[idx].next = locked.head as u32;
+ locked.head = idx;
+ locked.used -= 1;
+
+ self.used.store(locked.used, Relaxed);
+ }
+}
+
+impl<T> CachedPage<T> {
+ /// Refresh the cache
+ fn refresh(&mut self, page: &Page<T>) {
+ let slots = page.slots.lock();
+
+ if !slots.slots.is_empty() {
+ self.slots = slots.slots.as_ptr();
+ self.init = slots.slots.len();
+ }
+ }
+
+ // Get a value by index
+ fn get(&self, idx: usize) -> &T {
+ assert!(idx < self.init);
+
+ // Safety: Pages are allocated concurrently, but are only ever
+ // **deallocated** by `Slab`. `Slab` will always have a more
+ // conservative view on the state of the slot array. Once `CachedPage`
+ // sees a slot pointer and initialized offset, it will remain valid
+ // until `compact()` is called. The `compact()` function also updates
+ // `CachedPage`.
+ unsafe {
+ let slot = self.slots.add(idx);
+ let value = slot as *const Value<T>;
+
+ &(*value).value
+ }
+ }
+}
+
+impl<T> Default for CachedPage<T> {
+ fn default() -> CachedPage<T> {
+ CachedPage {
+ slots: ptr::null(),
+ init: 0,
+ }
+ }
+}
+
+impl<T> Slots<T> {
+ /// Maps a slot pointer to an offset within the current page.
+ ///
+ /// The pointer math removes the `usize` index from the `Ref` struct,
+ /// shrinking the struct to a single pointer size. The contents of the
+ /// function is safe, the resulting `usize` is bounds checked before being
+ /// used.
+ ///
+ /// # Panics
+ ///
+ /// panics if the provided slot pointer is not contained by the page.
+ fn index_for(&self, slot: *const Value<T>) -> usize {
+ use std::mem;
+
+ let base = &self.slots[0] as *const _ as usize;
+
+ assert!(base != 0, "page is unallocated");
+
+ let slot = slot as usize;
+ let width = mem::size_of::<Slot<T>>();
+
+ assert!(slot >= base, "unexpected pointer");
+
+ let idx = (slot - base) / width;
+ assert!(idx < self.slots.len() as usize);
+
+ idx
+ }
+}
+
+impl<T: Entry> Slot<T> {
+ /// Generates a `Ref` for the slot. This involves bumping the page's ref count.
+ fn gen_ref(&self, page: &Arc<Page<T>>) -> Ref<T> {
+ // The ref holds a ref on the page. The `Arc` is forgotten here and is
+ // resurrected in `release` when the `Ref` is dropped. By avoiding to
+ // hold on to an explicit `Arc` value, the struct size of `Ref` is
+ // reduced.
+ mem::forget(page.clone());
+ let slot = self as *const Slot<T>;
+ let value = slot as *const Value<T>;
+
+ Ref { value }
+ }
+}
+
+impl<T> Value<T> {
+ // Release the slot, returning the `Arc<Page<T>>` logically owned by the ref.
+ fn release(&self) -> Arc<Page<T>> {
+ // Safety: called by `Ref`, which owns an `Arc<Page<T>>` instance.
+ let page = unsafe { Arc::from_raw(self.page) };
+ page.release(self as *const _);
+ page
+ }
+}
+
+impl Address {
+ fn page(self) -> usize {
+ // Since every page is twice as large as the previous page, and all page
+ // sizes are powers of two, we can determine the page index that
+ // contains a given address by shifting the address down by the smallest
+ // page size and looking at how many twos places necessary to represent
+ // that number, telling us what power of two page size it fits inside
+ // of. We can determine the number of twos places by counting the number
+ // of leading zeros (unused twos places) in the number's binary
+ // representation, and subtracting that count from the total number of
+ // bits in a word.
+ let slot_shifted = (self.0 + PAGE_INITIAL_SIZE) >> PAGE_INDEX_SHIFT;
+ (bit::pointer_width() - slot_shifted.leading_zeros()) as usize
+ }
+
+ pub(crate) const fn as_usize(self) -> usize {
+ self.0
+ }
+
+ pub(crate) fn from_usize(src: usize) -> Address {
+ Address(src)
+ }
+}
+
+fn debug<T>(fmt: &mut fmt::Formatter<'_>, name: &str, pages: &[Arc<Page<T>>]) -> fmt::Result {
+ let mut capacity = 0;
+ let mut len = 0;
+
+ for page in pages {
+ if page.allocated.load(Relaxed) {
+ capacity += page.len;
+ len += page.used.load(Relaxed);
+ }
+ }
+
+ fmt.debug_struct(name)
+ .field("len", &len)
+ .field("capacity", &capacity)
+ .finish()
+}
+
+#[cfg(all(test, not(loom)))]
+mod test {
+ use super::*;
+ use std::sync::atomic::AtomicUsize;
+ use std::sync::atomic::Ordering::SeqCst;
+
+ struct Foo {
+ cnt: AtomicUsize,
+ id: AtomicUsize,
+ }
+
+ impl Default for Foo {
+ fn default() -> Foo {
+ Foo {
+ cnt: AtomicUsize::new(0),
+ id: AtomicUsize::new(0),
+ }
+ }
+ }
+
+ impl Entry for Foo {
+ fn reset(&self) {
+ self.cnt.fetch_add(1, SeqCst);
+ }
+ }
+
+ #[test]
+ fn insert_remove() {
+ let mut slab = Slab::<Foo>::new();
+ let alloc = slab.allocator();
+
+ let (addr1, foo1) = alloc.allocate().unwrap();
+ foo1.id.store(1, SeqCst);
+ assert_eq!(0, foo1.cnt.load(SeqCst));
+
+ let (addr2, foo2) = alloc.allocate().unwrap();
+ foo2.id.store(2, SeqCst);
+ assert_eq!(0, foo2.cnt.load(SeqCst));
+
+ assert_eq!(1, slab.get(addr1).unwrap().id.load(SeqCst));
+ assert_eq!(2, slab.get(addr2).unwrap().id.load(SeqCst));
+
+ drop(foo1);
+
+ assert_eq!(1, slab.get(addr1).unwrap().id.load(SeqCst));
+
+ let (addr3, foo3) = alloc.allocate().unwrap();
+ assert_eq!(addr3, addr1);
+ assert_eq!(1, foo3.cnt.load(SeqCst));
+ foo3.id.store(3, SeqCst);
+ assert_eq!(3, slab.get(addr3).unwrap().id.load(SeqCst));
+
+ drop(foo2);
+ drop(foo3);
+
+ slab.compact();
+
+ // The first page is never released
+ assert!(slab.get(addr1).is_some());
+ assert!(slab.get(addr2).is_some());
+ assert!(slab.get(addr3).is_some());
+ }
+
+ #[test]
+ fn insert_many() {
+ let mut slab = Slab::<Foo>::new();
+ let alloc = slab.allocator();
+ let mut entries = vec![];
+
+ for i in 0..10_000 {
+ let (addr, val) = alloc.allocate().unwrap();
+ val.id.store(i, SeqCst);
+ entries.push((addr, val));
+ }
+
+ for (i, (addr, v)) in entries.iter().enumerate() {
+ assert_eq!(i, v.id.load(SeqCst));
+ assert_eq!(i, slab.get(*addr).unwrap().id.load(SeqCst));
+ }
+
+ entries.clear();
+
+ for i in 0..10_000 {
+ let (addr, val) = alloc.allocate().unwrap();
+ val.id.store(10_000 - i, SeqCst);
+ entries.push((addr, val));
+ }
+
+ for (i, (addr, v)) in entries.iter().enumerate() {
+ assert_eq!(10_000 - i, v.id.load(SeqCst));
+ assert_eq!(10_000 - i, slab.get(*addr).unwrap().id.load(SeqCst));
+ }
+ }
+
+ #[test]
+ fn insert_drop_reverse() {
+ let mut slab = Slab::<Foo>::new();
+ let alloc = slab.allocator();
+ let mut entries = vec![];
+
+ for i in 0..10_000 {
+ let (addr, val) = alloc.allocate().unwrap();
+ val.id.store(i, SeqCst);
+ entries.push((addr, val));
+ }
+
+ for _ in 0..10 {
+ // Drop 1000 in reverse
+ for _ in 0..1_000 {
+ entries.pop();
+ }
+
+ // Check remaining
+ for (i, (addr, v)) in entries.iter().enumerate() {
+ assert_eq!(i, v.id.load(SeqCst));
+ assert_eq!(i, slab.get(*addr).unwrap().id.load(SeqCst));
+ }
+ }
+ }
+
+ #[test]
+ fn no_compaction_if_page_still_in_use() {
+ let mut slab = Slab::<Foo>::new();
+ let alloc = slab.allocator();
+ let mut entries1 = vec![];
+ let mut entries2 = vec![];
+
+ for i in 0..10_000 {
+ let (addr, val) = alloc.allocate().unwrap();
+ val.id.store(i, SeqCst);
+
+ if i % 2 == 0 {
+ entries1.push((addr, val, i));
+ } else {
+ entries2.push(val);
+ }
+ }
+
+ drop(entries2);
+
+ for (addr, _, i) in &entries1 {
+ assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst));
+ }
+ }
+
+ #[test]
+ fn compact_all() {
+ let mut slab = Slab::<Foo>::new();
+ let alloc = slab.allocator();
+ let mut entries = vec![];
+
+ for _ in 0..2 {
+ entries.clear();
+
+ for i in 0..10_000 {
+ let (addr, val) = alloc.allocate().unwrap();
+ val.id.store(i, SeqCst);
+
+ entries.push((addr, val));
+ }
+
+ let mut addrs = vec![];
+
+ for (addr, _) in entries.drain(..) {
+ addrs.push(addr);
+ }
+
+ slab.compact();
+
+ // The first page is never freed
+ for addr in &addrs[PAGE_INITIAL_SIZE..] {
+ assert!(slab.get(*addr).is_none());
+ }
+ }
+ }
+
+ #[test]
+ fn issue_3014() {
+ let mut slab = Slab::<Foo>::new();
+ let alloc = slab.allocator();
+ let mut entries = vec![];
+
+ for _ in 0..5 {
+ entries.clear();
+
+ // Allocate a few pages + 1
+ for i in 0..(32 + 64 + 128 + 1) {
+ let (addr, val) = alloc.allocate().unwrap();
+ val.id.store(i, SeqCst);
+
+ entries.push((addr, val, i));
+ }
+
+ for (addr, val, i) in &entries {
+ assert_eq!(*i, val.id.load(SeqCst));
+ assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst));
+ }
+
+ // Release the last entry
+ entries.pop();
+
+ // Compact
+ slab.compact();
+
+ // Check all the addresses
+
+ for (addr, val, i) in &entries {
+ assert_eq!(*i, val.id.load(SeqCst));
+ assert_eq!(*i, slab.get(*addr).unwrap().id.load(SeqCst));
+ }
+ }
+ }
+}
diff --git a/src/util/slab/addr.rs b/src/util/slab/addr.rs
deleted file mode 100644
index c14e32e..0000000
--- a/src/util/slab/addr.rs
+++ /dev/null
@@ -1,154 +0,0 @@
-//! Tracks the location of an entry in a slab.
-//!
-//! # Index packing
-//!
-//! A slab index consists of multiple indices packed into a single `usize` value
-//! that correspond to different parts of the slab.
-//!
-//! The least significant `MAX_PAGES + INITIAL_PAGE_SIZE.trailing_zeros() + 1`
-//! bits store the address within a shard, starting at 0 for the first slot on
-//! the first page. To index a slot within a shard, we first find the index of
-//! the page that the address falls on, and then the offset of the slot within
-//! that page.
-//!
-//! Since every page is twice as large as the previous page, and all page sizes
-//! are powers of two, we can determine the page index that contains a given
-//! address by shifting the address down by the smallest page size and looking
-//! at how many twos places necessary to represent that number, telling us what
-//! power of two page size it fits inside of. We can determine the number of
-//! twos places by counting the number of leading zeros (unused twos places) in
-//! the number's binary representation, and subtracting that count from the
-//! total number of bits in a word.
-//!
-//! Once we know what page contains an address, we can subtract the size of all
-//! previous pages from the address to determine the offset within the page.
-//!
-//! After the page address, the next `MAX_THREADS.trailing_zeros() + 1` least
-//! significant bits are the thread ID. These are used to index the array of
-//! shards to find which shard a slot belongs to. If an entry is being removed
-//! and the thread ID of its index matches that of the current thread, we can
-//! use the `remove_local` fast path; otherwise, we have to use the synchronized
-//! `remove_remote` path.
-//!
-//! Finally, a generation value is packed into the index. The `RESERVED_BITS`
-//! most significant bits are left unused, and the remaining bits between the
-//! last bit of the thread ID and the first reserved bit are used to store the
-//! generation. The generation is used as part of an atomic read-modify-write
-//! loop every time a `ScheduledIo`'s readiness is modified, or when the
-//! resource is removed, to guard against the ABA problem.
-//!
-//! Visualized:
-//!
-//! ```text
-//! ┌──────────┬───────────────┬──────────────────┬──────────────────────────┐
-//! │ reserved │ generation │ thread ID │ address │
-//! └▲─────────┴▲──────────────┴▲─────────────────┴▲────────────────────────▲┘
-//! │ │ │ │ │
-//! bits(usize) │ bits(MAX_THREADS) │ 0
-//! │ │
-//! bits(usize) - RESERVED MAX_PAGES + bits(INITIAL_PAGE_SIZE)
-//! ```
-
-use crate::util::bit;
-use crate::util::slab::{Generation, INITIAL_PAGE_SIZE, MAX_PAGES, MAX_THREADS};
-
-use std::usize;
-
-/// References the location at which an entry is stored in a slab.
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-pub(crate) struct Address(usize);
-
-const PAGE_INDEX_SHIFT: u32 = INITIAL_PAGE_SIZE.trailing_zeros() + 1;
-
-/// Address in the shard
-const SLOT: bit::Pack = bit::Pack::least_significant(MAX_PAGES as u32 + PAGE_INDEX_SHIFT);
-
-/// Masks the thread identifier
-const THREAD: bit::Pack = SLOT.then(MAX_THREADS.trailing_zeros() + 1);
-
-/// Masks the generation
-const GENERATION: bit::Pack = THREAD
- .then(bit::pointer_width().wrapping_sub(RESERVED.width() + THREAD.width() + SLOT.width()));
-
-// Chosen arbitrarily
-const RESERVED: bit::Pack = bit::Pack::most_significant(5);
-
-impl Address {
- /// Represents no entry, picked to avoid collision with Mio's internals.
- /// This value should not be passed to mio.
- pub(crate) const NULL: usize = usize::MAX >> 1;
-
- /// Re-exported by `Generation`.
- pub(super) const GENERATION_WIDTH: u32 = GENERATION.width();
-
- pub(super) fn new(shard_index: usize, generation: Generation) -> Address {
- let mut repr = 0;
-
- repr = SLOT.pack(shard_index, repr);
- repr = GENERATION.pack(generation.to_usize(), repr);
-
- Address(repr)
- }
-
- /// Convert from a `usize` representation.
- pub(crate) fn from_usize(src: usize) -> Address {
- assert_ne!(src, Self::NULL);
-
- Address(src)
- }
-
- /// Convert to a `usize` representation
- pub(crate) fn to_usize(self) -> usize {
- self.0
- }
-
- pub(crate) fn generation(self) -> Generation {
- Generation::new(GENERATION.unpack(self.0))
- }
-
- /// Returns the page index
- pub(super) fn page(self) -> usize {
- // Since every page is twice as large as the previous page, and all page
- // sizes are powers of two, we can determine the page index that
- // contains a given address by shifting the address down by the smallest
- // page size and looking at how many twos places necessary to represent
- // that number, telling us what power of two page size it fits inside
- // of. We can determine the number of twos places by counting the number
- // of leading zeros (unused twos places) in the number's binary
- // representation, and subtracting that count from the total number of
- // bits in a word.
- let slot_shifted = (self.slot() + INITIAL_PAGE_SIZE) >> PAGE_INDEX_SHIFT;
- (bit::pointer_width() - slot_shifted.leading_zeros()) as usize
- }
-
- /// Returns the slot index
- pub(super) fn slot(self) -> usize {
- SLOT.unpack(self.0)
- }
-}
-
-#[cfg(test)]
-cfg_not_loom! {
- use proptest::proptest;
-
- #[test]
- fn test_pack_format() {
- assert_eq!(5, RESERVED.width());
- assert_eq!(0b11111, RESERVED.max_value());
- }
-
- proptest! {
- #[test]
- fn address_roundtrips(
- slot in 0usize..SLOT.max_value(),
- generation in 0usize..Generation::MAX,
- ) {
- let address = Address::new(slot, Generation::new(generation));
- // Round trip
- let address = Address::from_usize(address.to_usize());
-
- assert_eq!(address.slot(), slot);
- assert_eq!(address.generation().to_usize(), generation);
- }
- }
-}
diff --git a/src/util/slab/entry.rs b/src/util/slab/entry.rs
deleted file mode 100644
index 2e0b10b..0000000
--- a/src/util/slab/entry.rs
+++ /dev/null
@@ -1,7 +0,0 @@
-use crate::util::slab::Generation;
-
-pub(crate) trait Entry: Default {
- fn generation(&self) -> Generation;
-
- fn reset(&self, generation: Generation) -> bool;
-}
diff --git a/src/util/slab/generation.rs b/src/util/slab/generation.rs
deleted file mode 100644
index 4b16b2c..0000000
--- a/src/util/slab/generation.rs
+++ /dev/null
@@ -1,32 +0,0 @@
-use crate::util::bit;
-use crate::util::slab::Address;
-
-/// An mutation identifier for a slot in the slab. The generation helps prevent
-/// accessing an entry with an outdated token.
-#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)]
-pub(crate) struct Generation(usize);
-
-impl Generation {
- pub(crate) const WIDTH: u32 = Address::GENERATION_WIDTH;
-
- pub(super) const MAX: usize = bit::mask_for(Address::GENERATION_WIDTH);
-
- /// Create a new generation
- ///
- /// # Panics
- ///
- /// Panics if `value` is greater than max generation.
- pub(crate) fn new(value: usize) -> Generation {
- assert!(value <= Self::MAX);
- Generation(value)
- }
-
- /// Returns the next generation value
- pub(crate) fn next(self) -> Generation {
- Generation((self.0 + 1) & Self::MAX)
- }
-
- pub(crate) fn to_usize(self) -> usize {
- self.0
- }
-}
diff --git a/src/util/slab/mod.rs b/src/util/slab/mod.rs
deleted file mode 100644
index 5082970..0000000
--- a/src/util/slab/mod.rs
+++ /dev/null
@@ -1,107 +0,0 @@
-//! A lock-free concurrent slab.
-
-mod addr;
-pub(crate) use addr::Address;
-
-mod entry;
-pub(crate) use entry::Entry;
-
-mod generation;
-pub(crate) use generation::Generation;
-
-mod page;
-
-mod shard;
-use shard::Shard;
-
-mod slot;
-use slot::Slot;
-
-mod stack;
-use stack::TransferStack;
-
-#[cfg(all(loom, test))]
-mod tests;
-
-use crate::loom::sync::Mutex;
-use crate::util::bit;
-
-use std::fmt;
-
-#[cfg(target_pointer_width = "64")]
-const MAX_THREADS: usize = 4096;
-
-#[cfg(target_pointer_width = "32")]
-const MAX_THREADS: usize = 2048;
-
-/// Max number of pages per slab
-const MAX_PAGES: usize = bit::pointer_width() as usize / 4;
-
-cfg_not_loom! {
- /// Size of first page
- const INITIAL_PAGE_SIZE: usize = 32;
-}
-
-cfg_loom! {
- const INITIAL_PAGE_SIZE: usize = 2;
-}
-
-/// A sharded slab.
-pub(crate) struct Slab<T> {
- // Signal shard for now. Eventually there will be more.
- shard: Shard<T>,
- local: Mutex<()>,
-}
-
-unsafe impl<T: Send> Send for Slab<T> {}
-unsafe impl<T: Sync> Sync for Slab<T> {}
-
-impl<T: Entry> Slab<T> {
- /// Returns a new slab with the default configuration parameters.
- pub(crate) fn new() -> Slab<T> {
- Slab {
- shard: Shard::new(),
- local: Mutex::new(()),
- }
- }
-
- /// allocs a value into the slab, returning a key that can be used to
- /// access it.
- ///
- /// If this function returns `None`, then the shard for the current thread
- /// is full and no items can be added until some are removed, or the maximum
- /// number of shards has been reached.
- pub(crate) fn alloc(&self) -> Option<Address> {
- // we must lock the slab to alloc an item.
- let _local = self.local.lock().unwrap();
- self.shard.alloc()
- }
-
- /// Removes the value associated with the given key from the slab.
- pub(crate) fn remove(&self, idx: Address) {
- // try to lock the slab so that we can use `remove_local`.
- let lock = self.local.try_lock();
-
- // if we were able to lock the slab, we are "local" and can use the fast
- // path; otherwise, we will use `remove_remote`.
- if lock.is_ok() {
- self.shard.remove_local(idx)
- } else {
- self.shard.remove_remote(idx)
- }
- }
-
- /// Return a reference to the value associated with the given key.
- ///
- /// If the slab does not contain a value for the given key, `None` is
- /// returned instead.
- pub(crate) fn get(&self, token: Address) -> Option<&T> {
- self.shard.get(token)
- }
-}
-
-impl<T> fmt::Debug for Slab<T> {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("Slab").field("shard", &self.shard).finish()
- }
-}
diff --git a/src/util/slab/page.rs b/src/util/slab/page.rs
deleted file mode 100644
index 0000e93..0000000
--- a/src/util/slab/page.rs
+++ /dev/null
@@ -1,187 +0,0 @@
-use crate::loom::cell::UnsafeCell;
-use crate::util::slab::{Address, Entry, Slot, TransferStack, INITIAL_PAGE_SIZE};
-
-use std::fmt;
-
-/// Data accessed only by the thread that owns the shard.
-pub(crate) struct Local {
- head: UnsafeCell<usize>,
-}
-
-/// Data accessed by any thread.
-pub(crate) struct Shared<T> {
- remote: TransferStack,
- size: usize,
- prev_sz: usize,
- slab: UnsafeCell<Option<Box<[Slot<T>]>>>,
-}
-
-/// Returns the size of the page at index `n`
-pub(super) fn size(n: usize) -> usize {
- INITIAL_PAGE_SIZE << n
-}
-
-impl Local {
- pub(crate) fn new() -> Self {
- Self {
- head: UnsafeCell::new(0),
- }
- }
-
- fn head(&self) -> usize {
- self.head.with(|head| unsafe { *head })
- }
-
- fn set_head(&self, new_head: usize) {
- self.head.with_mut(|head| unsafe {
- *head = new_head;
- })
- }
-}
-
-impl<T: Entry> Shared<T> {
- pub(crate) fn new(size: usize, prev_sz: usize) -> Shared<T> {
- Self {
- prev_sz,
- size,
- remote: TransferStack::new(),
- slab: UnsafeCell::new(None),
- }
- }
-
- /// Allocates storage for this page if it does not allready exist.
- ///
- /// This requires unique access to the page (e.g. it is called from the
- /// thread that owns the page, or, in the case of `SingleShard`, while the
- /// lock is held). In order to indicate this, a reference to the page's
- /// `Local` data is taken by this function; the `Local` argument is not
- /// actually used, but requiring it ensures that this is only called when
- /// local access is held.
- #[cold]
- fn alloc_page(&self, _: &Local) {
- debug_assert!(self.slab.with(|s| unsafe { (*s).is_none() }));
-
- let mut slab = Vec::with_capacity(self.size);
- slab.extend((1..self.size).map(Slot::new));
- slab.push(Slot::new(Address::NULL));
-
- self.slab.with_mut(|s| {
- // this mut access is safe — it only occurs to initially
- // allocate the page, which only happens on this thread; if the
- // page has not yet been allocated, other threads will not try
- // to access it yet.
- unsafe {
- *s = Some(slab.into_boxed_slice());
- }
- });
- }
-
- pub(crate) fn alloc(&self, local: &Local) -> Option<Address> {
- let head = local.head();
-
- // are there any items on the local free list? (fast path)
- let head = if head < self.size {
- head
- } else {
- // if the local free list is empty, pop all the items on the remote
- // free list onto the local free list.
- self.remote.pop_all()?
- };
-
- // if the head is still null, both the local and remote free lists are
- // empty --- we can't fit any more items on this page.
- if head == Address::NULL {
- return None;
- }
-
- // do we need to allocate storage for this page?
- let page_needs_alloc = self.slab.with(|s| unsafe { (*s).is_none() });
- if page_needs_alloc {
- self.alloc_page(local);
- }
-
- let gen = self.slab.with(|slab| {
- let slab = unsafe { &*(slab) }
- .as_ref()
- .expect("page must have been allocated to alloc!");
-
- let slot = &slab[head];
-
- local.set_head(slot.next());
- slot.generation()
- });
-
- let index = head + self.prev_sz;
-
- Some(Address::new(index, gen))
- }
-
- pub(crate) fn get(&self, addr: Address) -> Option<&T> {
- let page_offset = addr.slot() - self.prev_sz;
-
- self.slab
- .with(|slab| unsafe { &*slab }.as_ref()?.get(page_offset))
- .map(|slot| slot.get())
- }
-
- pub(crate) fn remove_local(&self, local: &Local, addr: Address) {
- let offset = addr.slot() - self.prev_sz;
-
- self.slab.with(|slab| {
- let slab = unsafe { &*slab }.as_ref();
-
- let slot = if let Some(slot) = slab.and_then(|slab| slab.get(offset)) {
- slot
- } else {
- return;
- };
-
- if slot.reset(addr.generation()) {
- slot.set_next(local.head());
- local.set_head(offset);
- }
- })
- }
-
- pub(crate) fn remove_remote(&self, addr: Address) {
- let offset = addr.slot() - self.prev_sz;
-
- self.slab.with(|slab| {
- let slab = unsafe { &*slab }.as_ref();
-
- let slot = if let Some(slot) = slab.and_then(|slab| slab.get(offset)) {
- slot
- } else {
- return;
- };
-
- if !slot.reset(addr.generation()) {
- return;
- }
-
- self.remote.push(offset, |next| slot.set_next(next));
- })
- }
-}
-
-impl fmt::Debug for Local {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.head.with(|head| {
- let head = unsafe { *head };
- f.debug_struct("Local")
- .field("head", &format_args!("{:#0x}", head))
- .finish()
- })
- }
-}
-
-impl<T> fmt::Debug for Shared<T> {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("Shared")
- .field("remote", &self.remote)
- .field("prev_sz", &self.prev_sz)
- .field("size", &self.size)
- // .field("slab", &self.slab)
- .finish()
- }
-}
diff --git a/src/util/slab/shard.rs b/src/util/slab/shard.rs
deleted file mode 100644
index eaca6f6..0000000
--- a/src/util/slab/shard.rs
+++ /dev/null
@@ -1,105 +0,0 @@
-use crate::util::slab::{page, Address, Entry, MAX_PAGES};
-
-use std::fmt;
-
-// ┌─────────────┐ ┌────────┐
-// │ page 1 │ │ │
-// ├─────────────┤ ┌───▶│ next──┼─┐
-// │ page 2 │ │ ├────────┤ │
-// │ │ │ │XXXXXXXX│ │
-// │ local_free──┼─┘ ├────────┤ │
-// │ global_free─┼─┐ │ │◀┘
-// ├─────────────┤ └───▶│ next──┼─┐
-// │ page 3 │ ├────────┤ │
-// └─────────────┘ │XXXXXXXX│ │
-// ... ├────────┤ │
-// ┌─────────────┐ │XXXXXXXX│ │
-// │ page n │ ├────────┤ │
-// └─────────────┘ │ │◀┘
-// │ next──┼───▶
-// ├────────┤
-// │XXXXXXXX│
-// └────────┘
-// ...
-pub(super) struct Shard<T> {
- /// The local free list for each page.
- ///
- /// These are only ever accessed from this shard's thread, so they are
- /// stored separately from the shared state for the page that can be
- /// accessed concurrently, to minimize false sharing.
- local: Box<[page::Local]>,
- /// The shared state for each page in this shard.
- ///
- /// This consists of the page's metadata (size, previous size), remote free
- /// list, and a pointer to the actual array backing that page.
- shared: Box<[page::Shared<T>]>,
-}
-
-impl<T: Entry> Shard<T> {
- pub(super) fn new() -> Shard<T> {
- let mut total_sz = 0;
- let shared = (0..MAX_PAGES)
- .map(|page_num| {
- let sz = page::size(page_num);
- let prev_sz = total_sz;
- total_sz += sz;
- page::Shared::new(sz, prev_sz)
- })
- .collect();
-
- let local = (0..MAX_PAGES).map(|_| page::Local::new()).collect();
-
- Shard { local, shared }
- }
-
- pub(super) fn alloc(&self) -> Option<Address> {
- // Can we fit the value into an existing page?
- for (page_idx, page) in self.shared.iter().enumerate() {
- let local = self.local(page_idx);
-
- if let Some(page_offset) = page.alloc(local) {
- return Some(page_offset);
- }
- }
-
- None
- }
-
- pub(super) fn get(&self, addr: Address) -> Option<&T> {
- let page_idx = addr.page();
-
- if page_idx > self.shared.len() {
- return None;
- }
-
- self.shared[page_idx].get(addr)
- }
-
- /// Remove an item on the shard's local thread.
- pub(super) fn remove_local(&self, addr: Address) {
- let page_idx = addr.page();
-
- if let Some(page) = self.shared.get(page_idx) {
- page.remove_local(self.local(page_idx), addr);
- }
- }
-
- /// Remove an item, while on a different thread from the shard's local thread.
- pub(super) fn remove_remote(&self, addr: Address) {
- if let Some(page) = self.shared.get(addr.page()) {
- page.remove_remote(addr);
- }
- }
-
- fn local(&self, i: usize) -> &page::Local {
- &self.local[i]
- }
-}
-
-impl<T> fmt::Debug for Shard<T> {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("Shard")
- .field("shared", &self.shared)
- .finish()
- }
-}
diff --git a/src/util/slab/slot.rs b/src/util/slab/slot.rs
deleted file mode 100644
index 0608b26..0000000
--- a/src/util/slab/slot.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use crate::loom::cell::UnsafeCell;
-use crate::util::slab::{Entry, Generation};
-
-/// Stores an entry in the slab.
-pub(super) struct Slot<T> {
- next: UnsafeCell<usize>,
- entry: T,
-}
-
-impl<T: Entry> Slot<T> {
- /// Initialize a new `Slot` linked to `next`.
- ///
- /// The entry is initialized to a default value.
- pub(super) fn new(next: usize) -> Slot<T> {
- Slot {
- next: UnsafeCell::new(next),
- entry: T::default(),
- }
- }
-
- pub(super) fn get(&self) -> &T {
- &self.entry
- }
-
- pub(super) fn generation(&self) -> Generation {
- self.entry.generation()
- }
-
- pub(super) fn reset(&self, generation: Generation) -> bool {
- self.entry.reset(generation)
- }
-
- pub(super) fn next(&self) -> usize {
- self.next.with(|next| unsafe { *next })
- }
-
- pub(super) fn set_next(&self, next: usize) {
- self.next.with_mut(|n| unsafe {
- (*n) = next;
- })
- }
-}
diff --git a/src/util/slab/stack.rs b/src/util/slab/stack.rs
deleted file mode 100644
index 0ae0d71..0000000
--- a/src/util/slab/stack.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-use crate::loom::sync::atomic::AtomicUsize;
-use crate::util::slab::Address;
-
-use std::fmt;
-use std::sync::atomic::Ordering;
-use std::usize;
-
-pub(super) struct TransferStack {
- head: AtomicUsize,
-}
-
-impl TransferStack {
- pub(super) fn new() -> Self {
- Self {
- head: AtomicUsize::new(Address::NULL),
- }
- }
-
- pub(super) fn pop_all(&self) -> Option<usize> {
- let val = self.head.swap(Address::NULL, Ordering::Acquire);
-
- if val == Address::NULL {
- None
- } else {
- Some(val)
- }
- }
-
- pub(super) fn push(&self, value: usize, before: impl Fn(usize)) {
- let mut next = self.head.load(Ordering::Relaxed);
-
- loop {
- before(next);
-
- match self
- .head
- .compare_exchange(next, value, Ordering::AcqRel, Ordering::Acquire)
- {
- // lost the race!
- Err(actual) => next = actual,
- Ok(_) => return,
- }
- }
- }
-}
-
-impl fmt::Debug for TransferStack {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- // Loom likes to dump all its internal state in `fmt::Debug` impls, so
- // we override this to just print the current value in tests.
- f.debug_struct("TransferStack")
- .field(
- "head",
- &format_args!("{:#x}", self.head.load(Ordering::Relaxed)),
- )
- .finish()
- }
-}
diff --git a/src/util/slab/tests/loom_slab.rs b/src/util/slab/tests/loom_slab.rs
deleted file mode 100644
index 48e94f0..0000000
--- a/src/util/slab/tests/loom_slab.rs
+++ /dev/null
@@ -1,327 +0,0 @@
-use crate::io::driver::ScheduledIo;
-use crate::util::slab::{Address, Slab};
-
-use loom::sync::{Arc, Condvar, Mutex};
-use loom::thread;
-
-#[test]
-fn local_remove() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
-
- let s = slab.clone();
- let t1 = thread::spawn(move || {
- let idx = store_val(&s, 1);
- assert_eq!(get_val(&s, idx), Some(1));
- s.remove(idx);
- assert_eq!(get_val(&s, idx), None);
- let idx = store_val(&s, 2);
- assert_eq!(get_val(&s, idx), Some(2));
- s.remove(idx);
- assert_eq!(get_val(&s, idx), None);
- });
-
- let s = slab.clone();
- let t2 = thread::spawn(move || {
- let idx = store_val(&s, 3);
- assert_eq!(get_val(&s, idx), Some(3));
- s.remove(idx);
- assert_eq!(get_val(&s, idx), None);
- let idx = store_val(&s, 4);
- s.remove(idx);
- assert_eq!(get_val(&s, idx), None);
- });
-
- let s = slab;
- let idx1 = store_val(&s, 5);
- assert_eq!(get_val(&s, idx1), Some(5));
- let idx2 = store_val(&s, 6);
- assert_eq!(get_val(&s, idx2), Some(6));
- s.remove(idx1);
- assert_eq!(get_val(&s, idx1), None);
- assert_eq!(get_val(&s, idx2), Some(6));
- s.remove(idx2);
- assert_eq!(get_val(&s, idx2), None);
-
- t1.join().expect("thread 1 should not panic");
- t2.join().expect("thread 2 should not panic");
- });
-}
-
-#[test]
-fn remove_remote() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
-
- let idx1 = store_val(&slab, 1);
- assert_eq!(get_val(&slab, idx1), Some(1));
-
- let idx2 = store_val(&slab, 2);
- assert_eq!(get_val(&slab, idx2), Some(2));
-
- let idx3 = store_val(&slab, 3);
- assert_eq!(get_val(&slab, idx3), Some(3));
-
- let s = slab.clone();
- let t1 = thread::spawn(move || {
- assert_eq!(get_val(&s, idx2), Some(2));
- s.remove(idx2);
- assert_eq!(get_val(&s, idx2), None);
- });
-
- let s = slab.clone();
- let t2 = thread::spawn(move || {
- assert_eq!(get_val(&s, idx3), Some(3));
- s.remove(idx3);
- assert_eq!(get_val(&s, idx3), None);
- });
-
- t1.join().expect("thread 1 should not panic");
- t2.join().expect("thread 2 should not panic");
-
- assert_eq!(get_val(&slab, idx1), Some(1));
- assert_eq!(get_val(&slab, idx2), None);
- assert_eq!(get_val(&slab, idx3), None);
- });
-}
-
-#[test]
-fn remove_remote_and_reuse() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
-
- let idx1 = store_val(&slab, 1);
- let idx2 = store_val(&slab, 2);
-
- assert_eq!(get_val(&slab, idx1), Some(1));
- assert_eq!(get_val(&slab, idx2), Some(2));
-
- let s = slab.clone();
- let t1 = thread::spawn(move || {
- s.remove(idx1);
- let value = get_val(&s, idx1);
-
- // We may or may not see the new value yet, depending on when
- // this occurs, but we must either see the new value or `None`;
- // the old value has been removed!
- assert!(value == None || value == Some(3));
- });
-
- let idx3 = store_when_free(&slab, 3);
- t1.join().expect("thread 1 should not panic");
-
- assert_eq!(get_val(&slab, idx3), Some(3));
- assert_eq!(get_val(&slab, idx2), Some(2));
- });
-}
-
-#[test]
-fn concurrent_alloc_remove() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
- let pair = Arc::new((Mutex::new(None), Condvar::new()));
-
- let slab2 = slab.clone();
- let pair2 = pair.clone();
- let remover = thread::spawn(move || {
- let (lock, cvar) = &*pair2;
- for _ in 0..2 {
- let mut next = lock.lock().unwrap();
- while next.is_none() {
- next = cvar.wait(next).unwrap();
- }
- let key = next.take().unwrap();
- slab2.remove(key);
- assert_eq!(get_val(&slab2, key), None);
- cvar.notify_one();
- }
- });
-
- let (lock, cvar) = &*pair;
- for i in 0..2 {
- let key = store_val(&slab, i);
-
- let mut next = lock.lock().unwrap();
- *next = Some(key);
- cvar.notify_one();
-
- // Wait for the item to be removed.
- while next.is_some() {
- next = cvar.wait(next).unwrap();
- }
-
- assert_eq!(get_val(&slab, key), None);
- }
-
- remover.join().unwrap();
- })
-}
-
-#[test]
-fn concurrent_remove_remote_and_reuse() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
-
- let idx1 = store_val(&slab, 1);
- let idx2 = store_val(&slab, 2);
-
- assert_eq!(get_val(&slab, idx1), Some(1));
- assert_eq!(get_val(&slab, idx2), Some(2));
-
- let s = slab.clone();
- let s2 = slab.clone();
- let t1 = thread::spawn(move || {
- s.remove(idx1);
- });
-
- let t2 = thread::spawn(move || {
- s2.remove(idx2);
- });
-
- let idx3 = store_when_free(&slab, 3);
- t1.join().expect("thread 1 should not panic");
- t2.join().expect("thread 1 should not panic");
-
- assert!(get_val(&slab, idx1).is_none());
- assert!(get_val(&slab, idx2).is_none());
- assert_eq!(get_val(&slab, idx3), Some(3));
- });
-}
-
-#[test]
-fn alloc_remove_get() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
- let pair = Arc::new((Mutex::new(None), Condvar::new()));
-
- let slab2 = slab.clone();
- let pair2 = pair.clone();
- let t1 = thread::spawn(move || {
- let slab = slab2;
- let (lock, cvar) = &*pair2;
- // allocate one entry just so that we have to use the final one for
- // all future allocations.
- let _key0 = store_val(&slab, 0);
- let key = store_val(&slab, 1);
-
- let mut next = lock.lock().unwrap();
- *next = Some(key);
- cvar.notify_one();
- // remove the second entry
- slab.remove(key);
- // store a new readiness at the same location (since the slab
- // already has an entry in slot 0)
- store_val(&slab, 2);
- });
-
- let (lock, cvar) = &*pair;
- // wait for the second entry to be stored...
- let mut next = lock.lock().unwrap();
- while next.is_none() {
- next = cvar.wait(next).unwrap();
- }
- let key = next.unwrap();
-
- // our generation will be stale when the second store occurs at that
- // index, we must not see the value of that store.
- let val = get_val(&slab, key);
- assert_ne!(val, Some(2), "generation must have advanced!");
-
- t1.join().unwrap();
- })
-}
-
-#[test]
-fn alloc_remove_set() {
- loom::model(|| {
- let slab = Arc::new(Slab::new());
- let pair = Arc::new((Mutex::new(None), Condvar::new()));
-
- let slab2 = slab.clone();
- let pair2 = pair.clone();
- let t1 = thread::spawn(move || {
- let slab = slab2;
- let (lock, cvar) = &*pair2;
- // allocate one entry just so that we have to use the final one for
- // all future allocations.
- let _key0 = store_val(&slab, 0);
- let key = store_val(&slab, 1);
-
- let mut next = lock.lock().unwrap();
- *next = Some(key);
- cvar.notify_one();
-
- slab.remove(key);
- // remove the old entry and insert a new one, with a new generation.
- let key2 = slab.alloc().expect("store key 2");
- // after the remove, we must not see the value written with the
- // stale index.
- assert_eq!(
- get_val(&slab, key),
- None,
- "stale set must no longer be visible"
- );
- assert_eq!(get_val(&slab, key2), Some(0));
- key2
- });
-
- let (lock, cvar) = &*pair;
-
- // wait for the second entry to be stored. the index we get from the
- // other thread may become stale after a write.
- let mut next = lock.lock().unwrap();
- while next.is_none() {
- next = cvar.wait(next).unwrap();
- }
- let key = next.unwrap();
-
- // try to write to the index with our generation
- slab.get(key).map(|val| val.set_readiness(key, |_| 2));
-
- let key2 = t1.join().unwrap();
- // after the remove, we must not see the value written with the
- // stale index either.
- assert_eq!(
- get_val(&slab, key),
- None,
- "stale set must no longer be visible"
- );
- assert_eq!(get_val(&slab, key2), Some(0));
- });
-}
-
-fn get_val(slab: &Arc<Slab<ScheduledIo>>, address: Address) -> Option<usize> {
- slab.get(address).and_then(|s| s.get_readiness(address))
-}
-
-fn store_val(slab: &Arc<Slab<ScheduledIo>>, readiness: usize) -> Address {
- let key = slab.alloc().expect("allocate slot");
-
- if let Some(slot) = slab.get(key) {
- slot.set_readiness(key, |_| readiness)
- .expect("generation should still be valid!");
- } else {
- panic!("slab did not contain a value for {:?}", key);
- }
-
- key
-}
-
-fn store_when_free(slab: &Arc<Slab<ScheduledIo>>, readiness: usize) -> Address {
- let key = loop {
- if let Some(key) = slab.alloc() {
- break key;
- }
-
- thread::yield_now();
- };
-
- if let Some(slot) = slab.get(key) {
- slot.set_readiness(key, |_| readiness)
- .expect("generation should still be valid!");
- } else {
- panic!("slab did not contain a value for {:?}", key);
- }
-
- key
-}
diff --git a/src/util/slab/tests/loom_stack.rs b/src/util/slab/tests/loom_stack.rs
deleted file mode 100644
index 47ad46d..0000000
--- a/src/util/slab/tests/loom_stack.rs
+++ /dev/null
@@ -1,88 +0,0 @@
-use crate::util::slab::TransferStack;
-
-use loom::cell::UnsafeCell;
-use loom::sync::Arc;
-use loom::thread;
-
-#[test]
-fn transfer_stack() {
- loom::model(|| {
- let causalities = [UnsafeCell::new(None), UnsafeCell::new(None)];
- let shared = Arc::new((causalities, TransferStack::new()));
- let shared1 = shared.clone();
- let shared2 = shared.clone();
-
- // Spawn two threads that both try to push to the stack.
- let t1 = thread::spawn(move || {
- let (causalities, stack) = &*shared1;
- stack.push(0, |prev| {
- causalities[0].with_mut(|c| unsafe {
- *c = Some(prev);
- });
- });
- });
-
- let t2 = thread::spawn(move || {
- let (causalities, stack) = &*shared2;
- stack.push(1, |prev| {
- causalities[1].with_mut(|c| unsafe {
- *c = Some(prev);
- });
- });
- });
-
- let (causalities, stack) = &*shared;
-
- // Try to pop from the stack...
- let mut idx = stack.pop_all();
- while idx == None {
- idx = stack.pop_all();
- thread::yield_now();
- }
- let idx = idx.unwrap();
-
- let saw_both = causalities[idx].with(|val| {
- let val = unsafe { *val };
- assert!(
- val.is_some(),
- "UnsafeCell write must happen-before index is pushed to the stack!",
- );
- // were there two entries in the stack? if so, check that
- // both saw a write.
- if let Some(c) = causalities.get(val.unwrap()) {
- c.with(|val| {
- let val = unsafe { *val };
- assert!(
- val.is_some(),
- "UnsafeCell write must happen-before index is pushed to the stack!",
- );
- });
- true
- } else {
- false
- }
- });
-
- // We only saw one push. Ensure that the other push happens too.
- if !saw_both {
- // Try to pop from the stack...
- let mut idx = stack.pop_all();
- while idx == None {
- idx = stack.pop_all();
- thread::yield_now();
- }
- let idx = idx.unwrap();
-
- causalities[idx].with(|val| {
- let val = unsafe { *val };
- assert!(
- val.is_some(),
- "UnsafeCell write must happen-before index is pushed to the stack!",
- );
- });
- }
-
- t1.join().unwrap();
- t2.join().unwrap();
- });
-}
diff --git a/src/util/slab/tests/mod.rs b/src/util/slab/tests/mod.rs
deleted file mode 100644
index 7f79354..0000000
--- a/src/util/slab/tests/mod.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-mod loom_slab;
-mod loom_stack;
diff --git a/src/util/trace.rs b/src/util/trace.rs
index d8c6120..18956a3 100644
--- a/src/util/trace.rs
+++ b/src/util/trace.rs
@@ -1,5 +1,5 @@
cfg_trace! {
- cfg_rt_core! {
+ cfg_rt! {
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -47,7 +47,7 @@ cfg_trace! {
}
cfg_not_trace! {
- cfg_rt_core! {
+ cfg_rt! {
#[inline]
pub(crate) fn task<F>(task: F, _: &'static str) -> F {
// nop