aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoel Galenson <jgalenson@google.com>2021-06-22 09:28:11 -0700
committerJoel Galenson <jgalenson@google.com>2021-06-22 09:28:30 -0700
commitb53dd06ad19d902c2155b2f616725b14b423a776 (patch)
treeb92904fad1f985434ad7622c3960e6e670b1c918 /src
parent8c2e0e8165f4f0132f3a5c78337fbba15b102768 (diff)
downloadtokio-b53dd06ad19d902c2155b2f616725b14b423a776.tar.gz
Upgrade rust/crates/tokio to 1.7.1
Test: make Change-Id: I7ebd839df13023db6f2057e09d8b73967436b856
Diffstat (limited to 'src')
-rw-r--r--src/doc/mod.rs23
-rw-r--r--src/doc/os.rs26
-rw-r--r--src/doc/winapi.rs66
-rw-r--r--src/fs/file.rs191
-rw-r--r--src/future/mod.rs11
-rw-r--r--src/future/trace.rs11
-rw-r--r--src/io/async_write.rs6
-rw-r--r--src/io/util/async_read_ext.rs17
-rw-r--r--src/io/util/async_write_ext.rs12
-rw-r--r--src/io/util/buf_reader.rs2
-rw-r--r--src/io/util/buf_stream.rs32
-rw-r--r--src/lib.rs31
-rw-r--r--src/macros/cfg.rs11
-rw-r--r--src/macros/select.rs2
-rw-r--r--src/net/mod.rs4
-rw-r--r--src/net/tcp/socket.rs42
-rw-r--r--src/net/unix/ucred.rs2
-rw-r--r--src/net/windows/mod.rs3
-rw-r--r--src/net/windows/named_pipe.rs1199
-rw-r--r--src/runtime/basic_scheduler.rs73
-rw-r--r--src/runtime/blocking/pool.rs13
-rw-r--r--src/runtime/handle.rs12
-rw-r--r--src/runtime/queue.rs36
-rw-r--r--src/runtime/spawner.rs3
-rw-r--r--src/runtime/task/core.rs10
-rw-r--r--src/runtime/task/harness.rs7
-rw-r--r--src/runtime/task/mod.rs2
-rw-r--r--src/runtime/task/raw.rs2
-rw-r--r--src/runtime/task/state.rs11
-rw-r--r--src/runtime/task/waker.rs31
-rw-r--r--src/runtime/tests/loom_shutdown_join.rs28
-rw-r--r--src/runtime/tests/mod.rs1
-rw-r--r--src/runtime/tests/task.rs2
-rw-r--r--src/runtime/thread_pool/mod.rs10
-rw-r--r--src/runtime/thread_pool/worker.rs23
-rw-r--r--src/sync/mod.rs5
-rw-r--r--src/sync/mpsc/bounded.rs6
-rw-r--r--src/sync/mpsc/error.rs4
-rw-r--r--src/sync/notify.rs2
-rw-r--r--src/sync/semaphore.rs219
-rw-r--r--src/sync/task/atomic_waker.rs2
-rw-r--r--src/sync/watch.rs22
-rw-r--r--src/task/mod.rs2
-rw-r--r--src/time/clock.rs13
-rw-r--r--src/time/driver/entry.rs12
-rw-r--r--src/time/driver/mod.rs93
-rw-r--r--src/time/driver/wheel/mod.rs2
-rw-r--r--src/util/linked_list.rs1
-rw-r--r--src/util/wake.rs6
49 files changed, 2017 insertions, 327 deletions
diff --git a/src/doc/mod.rs b/src/doc/mod.rs
new file mode 100644
index 0000000..12c2247
--- /dev/null
+++ b/src/doc/mod.rs
@@ -0,0 +1,23 @@
+//! Types which are documented locally in the Tokio crate, but does not actually
+//! live here.
+//!
+//! **Note** this module is only visible on docs.rs, you cannot use it directly
+//! in your own code.
+
+/// The name of a type which is not defined here.
+///
+/// This is typically used as an alias for another type, like so:
+///
+/// ```rust,ignore
+/// /// See [some::other::location](https://example.com).
+/// type DEFINED_ELSEWHERE = crate::doc::NotDefinedHere;
+/// ```
+///
+/// This type is uninhabitable like the [`never` type] to ensure that no one
+/// will ever accidentally use it.
+///
+/// [`never` type]: https://doc.rust-lang.org/std/primitive.never.html
+pub enum NotDefinedHere {}
+
+pub mod os;
+pub mod winapi;
diff --git a/src/doc/os.rs b/src/doc/os.rs
new file mode 100644
index 0000000..0ddf869
--- /dev/null
+++ b/src/doc/os.rs
@@ -0,0 +1,26 @@
+//! See [std::os](https://doc.rust-lang.org/std/os/index.html).
+
+/// Platform-specific extensions to `std` for Windows.
+///
+/// See [std::os::windows](https://doc.rust-lang.org/std/os/windows/index.html).
+pub mod windows {
+ /// Windows-specific extensions to general I/O primitives.
+ ///
+ /// See [std::os::windows::io](https://doc.rust-lang.org/std/os/windows/io/index.html).
+ pub mod io {
+ /// See [std::os::windows::io::RawHandle](https://doc.rust-lang.org/std/os/windows/io/type.RawHandle.html)
+ pub type RawHandle = crate::doc::NotDefinedHere;
+
+ /// See [std::os::windows::io::AsRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html)
+ pub trait AsRawHandle {
+ /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html#tymethod.as_raw_handle)
+ fn as_raw_handle(&self) -> RawHandle;
+ }
+
+ /// See [std::os::windows::io::FromRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html)
+ pub trait FromRawHandle {
+ /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html#tymethod.from_raw_handle)
+ unsafe fn from_raw_handle(handle: RawHandle) -> Self;
+ }
+ }
+}
diff --git a/src/doc/winapi.rs b/src/doc/winapi.rs
new file mode 100644
index 0000000..be68749
--- /dev/null
+++ b/src/doc/winapi.rs
@@ -0,0 +1,66 @@
+//! See [winapi].
+//!
+//! [winapi]: https://docs.rs/winapi
+
+/// See [winapi::shared](https://docs.rs/winapi/*/winapi/shared/index.html).
+pub mod shared {
+ /// See [winapi::shared::winerror](https://docs.rs/winapi/*/winapi/shared/winerror/index.html).
+ #[allow(non_camel_case_types)]
+ pub mod winerror {
+ /// See [winapi::shared::winerror::ERROR_ACCESS_DENIED][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_ACCESS_DENIED.html
+ pub type ERROR_ACCESS_DENIED = crate::doc::NotDefinedHere;
+
+ /// See [winapi::shared::winerror::ERROR_PIPE_BUSY][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_PIPE_BUSY.html
+ pub type ERROR_PIPE_BUSY = crate::doc::NotDefinedHere;
+
+ /// See [winapi::shared::winerror::ERROR_MORE_DATA][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_MORE_DATA.html
+ pub type ERROR_MORE_DATA = crate::doc::NotDefinedHere;
+ }
+}
+
+/// See [winapi::um](https://docs.rs/winapi/*/winapi/um/index.html).
+pub mod um {
+ /// See [winapi::um::winbase](https://docs.rs/winapi/*/winapi/um/winbase/index.html).
+ #[allow(non_camel_case_types)]
+ pub mod winbase {
+ /// See [winapi::um::winbase::PIPE_TYPE_MESSAGE][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_MESSAGE.html
+ pub type PIPE_TYPE_MESSAGE = crate::doc::NotDefinedHere;
+
+ /// See [winapi::um::winbase::PIPE_TYPE_BYTE][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_BYTE.html
+ pub type PIPE_TYPE_BYTE = crate::doc::NotDefinedHere;
+
+ /// See [winapi::um::winbase::PIPE_CLIENT_END][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_CLIENT_END.html
+ pub type PIPE_CLIENT_END = crate::doc::NotDefinedHere;
+
+ /// See [winapi::um::winbase::PIPE_SERVER_END][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_SERVER_END.html
+ pub type PIPE_SERVER_END = crate::doc::NotDefinedHere;
+
+ /// See [winapi::um::winbase::SECURITY_IDENTIFICATION][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.SECURITY_IDENTIFICATION.html
+ pub type SECURITY_IDENTIFICATION = crate::doc::NotDefinedHere;
+ }
+
+ /// See [winapi::um::minwinbase](https://docs.rs/winapi/*/winapi/um/minwinbase/index.html).
+ #[allow(non_camel_case_types)]
+ pub mod minwinbase {
+ /// See [winapi::um::minwinbase::SECURITY_ATTRIBUTES][winapi]
+ ///
+ /// [winapi]: https://docs.rs/winapi/*/winapi/um/minwinbase/constant.SECURITY_ATTRIBUTES.html
+ pub type SECURITY_ATTRIBUTES = crate::doc::NotDefinedHere;
+ }
+}
diff --git a/src/fs/file.rs b/src/fs/file.rs
index abd6e8c..5c06e73 100644
--- a/src/fs/file.rs
+++ b/src/fs/file.rs
@@ -491,18 +491,14 @@ impl AsyncRead for File {
loop {
match inner.state {
Idle(ref mut buf_cell) => {
- let buf = buf_cell.as_mut().unwrap();
+ let mut buf = buf_cell.take().unwrap();
if !buf.is_empty() {
buf.copy_to(dst);
+ *buf_cell = Some(buf);
return Ready(Ok(()));
}
- if let Some(x) = try_nonblocking_read(me.std.as_ref(), dst) {
- return Ready(x);
- }
-
- let mut buf = buf_cell.take().unwrap();
buf.ensure_capacity_for(dst);
let std = me.std.clone();
@@ -760,186 +756,3 @@ impl Inner {
}
}
}
-
-#[cfg(all(target_os = "linux", not(test)))]
-pub(crate) fn try_nonblocking_read(
- file: &crate::fs::sys::File,
- dst: &mut ReadBuf<'_>,
-) -> Option<std::io::Result<()>> {
- use std::sync::atomic::{AtomicBool, Ordering};
-
- static NONBLOCKING_READ_SUPPORTED: AtomicBool = AtomicBool::new(true);
- if !NONBLOCKING_READ_SUPPORTED.load(Ordering::Relaxed) {
- return None;
- }
- let out = preadv2::preadv2_safe(file, dst, -1, preadv2::RWF_NOWAIT);
- if let Err(err) = &out {
- match err.raw_os_error() {
- Some(libc::ENOSYS) => {
- NONBLOCKING_READ_SUPPORTED.store(false, Ordering::Relaxed);
- return None;
- }
- Some(libc::ENOTSUP) | Some(libc::EAGAIN) => return None,
- _ => {}
- }
- }
- Some(out)
-}
-
-#[cfg(any(not(target_os = "linux"), test))]
-pub(crate) fn try_nonblocking_read(
- _file: &crate::fs::sys::File,
- _dst: &mut ReadBuf<'_>,
-) -> Option<std::io::Result<()>> {
- None
-}
-
-#[cfg(target_os = "linux")]
-mod preadv2 {
- use libc::{c_int, c_long, c_void, iovec, off_t, ssize_t};
- use std::os::unix::prelude::AsRawFd;
-
- use crate::io::ReadBuf;
-
- pub(crate) fn preadv2_safe(
- file: &std::fs::File,
- dst: &mut ReadBuf<'_>,
- offset: off_t,
- flags: c_int,
- ) -> std::io::Result<()> {
- unsafe {
- /* We have to defend against buffer overflows manually here. The slice API makes
- * this fairly straightforward. */
- let unfilled = dst.unfilled_mut();
- let mut iov = iovec {
- iov_base: unfilled.as_mut_ptr() as *mut c_void,
- iov_len: unfilled.len(),
- };
- /* We take a File object rather than an fd as reading from a sensitive fd may confuse
- * other unsafe code that assumes that only they have access to that fd. */
- let bytes_read = preadv2(
- file.as_raw_fd(),
- &mut iov as *mut iovec as *const iovec,
- 1,
- offset,
- flags,
- );
- if bytes_read < 0 {
- Err(std::io::Error::last_os_error())
- } else {
- /* preadv2 returns the number of bytes read, e.g. the number of bytes that have
- * written into `unfilled`. So it's safe to assume that the data is now
- * initialised */
- dst.assume_init(bytes_read as usize);
- dst.advance(bytes_read as usize);
- Ok(())
- }
- }
- }
-
- #[cfg(test)]
- mod test {
- use super::*;
-
- #[test]
- fn test_preadv2_safe() {
- use std::io::{Seek, Write};
- use std::mem::MaybeUninit;
- use tempfile::tempdir;
-
- let tmp = tempdir().unwrap();
- let filename = tmp.path().join("file");
- const MESSAGE: &[u8] = b"Hello this is a test";
- {
- let mut f = std::fs::File::create(&filename).unwrap();
- f.write_all(MESSAGE).unwrap();
- }
- let f = std::fs::File::open(&filename).unwrap();
-
- let mut buf = [MaybeUninit::<u8>::new(0); 50];
- let mut br = ReadBuf::uninit(&mut buf);
-
- // Basic use:
- preadv2_safe(&f, &mut br, 0, 0).unwrap();
- assert_eq!(br.initialized().len(), MESSAGE.len());
- assert_eq!(br.filled(), MESSAGE);
-
- // Here we check that offset works, but also that appending to a non-empty buffer
- // behaves correctly WRT initialisation.
- preadv2_safe(&f, &mut br, 5, 0).unwrap();
- assert_eq!(br.initialized().len(), MESSAGE.len() * 2 - 5);
- assert_eq!(br.filled(), b"Hello this is a test this is a test".as_ref());
-
- // offset of -1 means use the current cursor. This has not been advanced by the
- // previous reads because we specified an offset there.
- preadv2_safe(&f, &mut br, -1, 0).unwrap();
- assert_eq!(br.remaining(), 0);
- assert_eq!(
- br.filled(),
- b"Hello this is a test this is a testHello this is a".as_ref()
- );
-
- // but the offset should have been advanced by that read
- br.clear();
- preadv2_safe(&f, &mut br, -1, 0).unwrap();
- assert_eq!(br.filled(), b" test");
-
- // This should be in cache, so RWF_NOWAIT should work, but it not being in cache
- // (EAGAIN) or not supported by the underlying filesystem (ENOTSUP) is fine too.
- br.clear();
- match preadv2_safe(&f, &mut br, 0, RWF_NOWAIT) {
- Ok(()) => assert_eq!(br.filled(), MESSAGE),
- Err(e) => assert!(matches!(
- e.raw_os_error(),
- Some(libc::ENOTSUP) | Some(libc::EAGAIN)
- )),
- }
-
- // Test handling large offsets
- {
- // I hope the underlying filesystem supports sparse files
- let mut w = std::fs::OpenOptions::new()
- .write(true)
- .open(&filename)
- .unwrap();
- w.set_len(0x1_0000_0000).unwrap();
- w.seek(std::io::SeekFrom::Start(0x1_0000_0000)).unwrap();
- w.write_all(b"This is a Large File").unwrap();
- }
-
- br.clear();
- preadv2_safe(&f, &mut br, 0x1_0000_0008, 0).unwrap();
- assert_eq!(br.filled(), b"a Large File");
- }
- }
-
- fn pos_to_lohi(offset: off_t) -> (c_long, c_long) {
- // 64-bit offset is split over high and low 32-bits on 32-bit architectures.
- // 64-bit architectures still have high and low arguments, but only the low
- // one is inspected. See pos_from_hilo in linux/fs/read_write.c.
- const HALF_LONG_BITS: usize = core::mem::size_of::<c_long>() * 8 / 2;
- (
- offset as c_long,
- // We want to shift this off_t value by size_of::<c_long>(). We can't do
- // it in one shift because if they're both 64-bits we'd be doing u64 >> 64
- // which is implementation defined. Instead do it in two halves:
- ((offset >> HALF_LONG_BITS) >> HALF_LONG_BITS) as c_long,
- )
- }
-
- pub(crate) const RWF_NOWAIT: c_int = 0x00000008;
- unsafe fn preadv2(
- fd: c_int,
- iov: *const iovec,
- iovcnt: c_int,
- offset: off_t,
- flags: c_int,
- ) -> ssize_t {
- // Call via libc::syscall rather than libc::preadv2. preadv2 is only supported by glibc
- // and only since v2.26. By using syscall we don't need to worry about compatiblity with
- // old glibc versions and it will work on Android and musl too. The downside is that you
- // can't use `LD_PRELOAD` tricks any more to intercept these calls.
- let (lo, hi) = pos_to_lohi(offset);
- libc::syscall(libc::SYS_preadv2, fd, iov, iovcnt, lo, hi, flags) as ssize_t
- }
-}
diff --git a/src/future/mod.rs b/src/future/mod.rs
index f7d93c9..96483ac 100644
--- a/src/future/mod.rs
+++ b/src/future/mod.rs
@@ -22,3 +22,14 @@ cfg_sync! {
mod block_on;
pub(crate) use block_on::block_on;
}
+
+cfg_trace! {
+ mod trace;
+ pub(crate) use trace::InstrumentedFuture as Future;
+}
+
+cfg_not_trace! {
+ cfg_rt! {
+ pub(crate) use std::future::Future;
+ }
+}
diff --git a/src/future/trace.rs b/src/future/trace.rs
new file mode 100644
index 0000000..28789a6
--- /dev/null
+++ b/src/future/trace.rs
@@ -0,0 +1,11 @@
+use std::future::Future;
+
+pub(crate) trait InstrumentedFuture: Future {
+ fn id(&self) -> Option<tracing::Id>;
+}
+
+impl<F: Future> InstrumentedFuture for tracing::instrument::Instrumented<F> {
+ fn id(&self) -> Option<tracing::Id> {
+ self.span().id()
+ }
+}
diff --git a/src/io/async_write.rs b/src/io/async_write.rs
index 569fb9c..7ec1a30 100644
--- a/src/io/async_write.rs
+++ b/src/io/async_write.rs
@@ -45,7 +45,11 @@ use std::task::{Context, Poll};
pub trait AsyncWrite {
/// Attempt to write bytes from `buf` into the object.
///
- /// On success, returns `Poll::Ready(Ok(num_bytes_written))`.
+ /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. If successful,
+ /// then it must be guaranteed that `n <= buf.len()`. A return value of `0`
+ /// typically means that the underlying object is no longer able to accept
+ /// bytes and will likely not be able to in the future as well, or that the
+ /// buffer provided is empty.
///
/// If the object is not ready for writing, the method returns
/// `Poll::Pending` and arranges for the current task (via
diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs
index e715f9d..878676f 100644
--- a/src/io/util/async_read_ext.rs
+++ b/src/io/util/async_read_ext.rs
@@ -108,6 +108,8 @@ cfg_io_util! {
/// This function does not provide any guarantees about whether it
/// completes immediately or asynchronously
///
+ /// # Return
+ ///
/// If the return value of this method is `Ok(n)`, then it must be
/// guaranteed that `0 <= n <= buf.len()`. A nonzero `n` value indicates
/// that the buffer `buf` has been filled in with `n` bytes of data from
@@ -180,9 +182,14 @@ cfg_io_util! {
///
/// # Return
///
- /// On a successful read, the number of read bytes is returned. If the
- /// supplied buffer is not empty and the function returns `Ok(0)` then
- /// the source has reached an "end-of-file" event.
+ /// A nonzero `n` value indicates that the buffer `buf` has been filled
+ /// in with `n` bytes of data from this source. If `n` is `0`, then it
+ /// can indicate one of two scenarios:
+ ///
+ /// 1. This reader has reached its "end of file" and will likely no longer
+ /// be able to produce bytes. Note that this does not mean that the
+ /// reader will *always* no longer be able to produce bytes.
+ /// 2. The buffer specified had a remaining capacity of zero.
///
/// # Errors
///
@@ -579,7 +586,7 @@ cfg_io_util! {
/// async fn main() -> io::Result<()> {
/// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]);
///
- /// assert_eq!(i64::min_value(), reader.read_i64().await?);
+ /// assert_eq!(i64::MIN, reader.read_i64().await?);
/// Ok(())
/// }
/// ```
@@ -659,7 +666,7 @@ cfg_io_util! {
/// 0, 0, 0, 0, 0, 0, 0, 0
/// ]);
///
- /// assert_eq!(i128::min_value(), reader.read_i128().await?);
+ /// assert_eq!(i128::MIN, reader.read_i128().await?);
/// Ok(())
/// }
/// ```
diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs
index 2510ccd..4651a99 100644
--- a/src/io/util/async_write_ext.rs
+++ b/src/io/util/async_write_ext.rs
@@ -621,8 +621,8 @@ cfg_io_util! {
/// async fn main() -> io::Result<()> {
/// let mut writer = Vec::new();
///
- /// writer.write_i64(i64::min_value()).await?;
- /// writer.write_i64(i64::max_value()).await?;
+ /// writer.write_i64(i64::MIN).await?;
+ /// writer.write_i64(i64::MAX).await?;
///
/// assert_eq!(writer, b"\x80\x00\x00\x00\x00\x00\x00\x00\x7f\xff\xff\xff\xff\xff\xff\xff");
/// Ok(())
@@ -699,7 +699,7 @@ cfg_io_util! {
/// async fn main() -> io::Result<()> {
/// let mut writer = Vec::new();
///
- /// writer.write_i128(i128::min_value()).await?;
+ /// writer.write_i128(i128::MIN).await?;
///
/// assert_eq!(writer, vec![
/// 0x80, 0, 0, 0, 0, 0, 0, 0,
@@ -930,8 +930,8 @@ cfg_io_util! {
/// async fn main() -> io::Result<()> {
/// let mut writer = Vec::new();
///
- /// writer.write_i64_le(i64::min_value()).await?;
- /// writer.write_i64_le(i64::max_value()).await?;
+ /// writer.write_i64_le(i64::MIN).await?;
+ /// writer.write_i64_le(i64::MAX).await?;
///
/// assert_eq!(writer, b"\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f");
/// Ok(())
@@ -1008,7 +1008,7 @@ cfg_io_util! {
/// async fn main() -> io::Result<()> {
/// let mut writer = Vec::new();
///
- /// writer.write_i128_le(i128::min_value()).await?;
+ /// writer.write_i128_le(i128::MIN).await?;
///
/// assert_eq!(writer, vec![
/// 0, 0, 0, 0, 0, 0, 0,
diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs
index cc65ef2..c4d6842 100644
--- a/src/io/util/buf_reader.rs
+++ b/src/io/util/buf_reader.rs
@@ -198,7 +198,7 @@ impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> {
// it should be safe to assume that remainder fits within an i64 as the alternative
// means we managed to allocate 8 exbibytes and that's absurd.
// But it's not out of the realm of possibility for some weird underlying reader to
- // support seeking by i64::min_value() so we need to handle underflow when subtracting
+ // support seeking by i64::MIN so we need to handle underflow when subtracting
// remainder.
if let Some(offset) = n.checked_sub(remainder) {
self.as_mut()
diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs
index 9238665..ff3d9db 100644
--- a/src/io/util/buf_stream.rs
+++ b/src/io/util/buf_stream.rs
@@ -1,8 +1,8 @@
use crate::io::util::{BufReader, BufWriter};
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use pin_project_lite::pin_project;
-use std::io;
+use std::io::{self, SeekFrom};
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -146,6 +146,34 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> {
}
}
+/// Seek to an offset, in bytes, in the underlying stream.
+///
+/// The position used for seeking with `SeekFrom::Current(_)` is the
+/// position the underlying stream would be at if the `BufStream` had no
+/// internal buffer.
+///
+/// Seeking always discards the internal buffer, even if the seek position
+/// would otherwise fall within it. This guarantees that calling
+/// `.into_inner()` immediately after a seek yields the underlying reader
+/// at the same position.
+///
+/// See [`AsyncSeek`] for more details.
+///
+/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)`
+/// where `n` minus the internal buffer length overflows an `i64`, two
+/// seeks will be performed instead of one. If the second seek returns
+/// `Err`, the underlying reader will be left at the same position it would
+/// have if you called `seek` with `SeekFrom::Current(0)`.
+impl<RW: AsyncRead + AsyncWrite + AsyncSeek> AsyncSeek for BufStream<RW> {
+ fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
+ self.project().inner.start_seek(position)
+ }
+
+ fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
+ self.project().inner.poll_complete(cx)
+ }
+}
+
impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
self.project().inner.poll_fill_buf(cx)
diff --git a/src/lib.rs b/src/lib.rs
index 15aeced..c74b964 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,6 +9,7 @@
rust_2018_idioms,
unreachable_pub
)]
+#![deny(unused_must_use)]
#![cfg_attr(docsrs, deny(broken_intra_doc_links))]
#![doc(test(
no_crate_inject,
@@ -442,6 +443,28 @@ mod util;
/// ```
pub mod stream {}
+// local re-exports of platform specific things, allowing for decent
+// documentation to be shimmed in on docs.rs
+
+#[cfg(docsrs)]
+pub mod doc;
+
+#[cfg(docsrs)]
+#[allow(unused)]
+pub(crate) use self::doc::os;
+
+#[cfg(not(docsrs))]
+#[allow(unused)]
+pub(crate) use std::os;
+
+#[cfg(docsrs)]
+#[allow(unused)]
+pub(crate) use self::doc::winapi;
+
+#[cfg(all(not(docsrs), windows, feature = "net"))]
+#[allow(unused)]
+pub(crate) use ::winapi;
+
cfg_macros! {
/// Implementation detail of the `select!` macro. This macro is **not**
/// intended to be used as part of the public API and is permitted to
@@ -453,15 +476,20 @@ cfg_macros! {
#[cfg(feature = "rt-multi-thread")]
#[cfg(not(test))] // Work around for rust-lang/rust#62127
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+ #[doc(inline)]
pub use tokio_macros::main;
#[cfg(feature = "rt-multi-thread")]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+ #[doc(inline)]
pub use tokio_macros::test;
cfg_not_rt_multi_thread! {
#[cfg(not(test))] // Work around for rust-lang/rust#62127
+ #[doc(inline)]
pub use tokio_macros::main_rt as main;
+
+ #[doc(inline)]
pub use tokio_macros::test_rt as test;
}
}
@@ -469,7 +497,10 @@ cfg_macros! {
// Always fail if rt is not enabled.
cfg_not_rt! {
#[cfg(not(test))]
+ #[doc(inline)]
pub use tokio_macros::main_fail as main;
+
+ #[doc(inline)]
pub use tokio_macros::test_fail as test;
}
}
diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs
index 3442612..1e77556 100644
--- a/src/macros/cfg.rs
+++ b/src/macros/cfg.rs
@@ -157,7 +157,6 @@ macro_rules! cfg_macros {
$(
#[cfg(feature = "macros")]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
- #[doc(inline)]
$item
)*
}
@@ -183,6 +182,16 @@ macro_rules! cfg_net_unix {
}
}
+macro_rules! cfg_net_windows {
+ ($($item:item)*) => {
+ $(
+ #[cfg(all(any(docsrs, windows), feature = "net"))]
+ #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "net"))))]
+ $item
+ )*
+ }
+}
+
macro_rules! cfg_process {
($($item:item)*) => {
$(
diff --git a/src/macros/select.rs b/src/macros/select.rs
index f98ebff..371a3de 100644
--- a/src/macros/select.rs
+++ b/src/macros/select.rs
@@ -398,7 +398,7 @@ macro_rules! select {
// set the appropriate bit in `disabled`.
$(
if !$c {
- let mask = 1 << $crate::count!( $($skip)* );
+ let mask: util::Mask = 1 << $crate::count!( $($skip)* );
disabled |= mask;
}
)*
diff --git a/src/net/mod.rs b/src/net/mod.rs
index 2f17f9e..0b8c1ec 100644
--- a/src/net/mod.rs
+++ b/src/net/mod.rs
@@ -46,3 +46,7 @@ cfg_net_unix! {
pub use unix::listener::UnixListener;
pub use unix::stream::UnixStream;
}
+
+cfg_net_windows! {
+ pub mod windows;
+}
diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs
index 4bcbe3f..02cb637 100644
--- a/src/net/tcp/socket.rs
+++ b/src/net/tcp/socket.rs
@@ -482,6 +482,48 @@ impl TcpSocket {
let mio = self.inner.listen(backlog)?;
TcpListener::new(mio)
}
+
+ /// Converts a [`std::net::TcpStream`] into a `TcpSocket`. The provided
+ /// socket must not have been connected prior to calling this function. This
+ /// function is typically used together with crates such as [`socket2`] to
+ /// configure socket options that are not available on `TcpSocket`.
+ ///
+ /// [`std::net::TcpStream`]: struct@std::net::TcpStream
+ /// [`socket2`]: https://docs.rs/socket2/
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::net::TcpSocket;
+ /// use socket2::{Domain, Socket, Type};
+ ///
+ /// #[tokio::main]
+ /// async fn main() -> std::io::Result<()> {
+ ///
+ /// let socket2_socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
+ ///
+ /// let socket = TcpSocket::from_std_stream(socket2_socket.into());
+ ///
+ /// Ok(())
+ /// }
+ /// ```
+ pub fn from_std_stream(std_stream: std::net::TcpStream) -> TcpSocket {
+ #[cfg(unix)]
+ {
+ use std::os::unix::io::{FromRawFd, IntoRawFd};
+
+ let raw_fd = std_stream.into_raw_fd();
+ unsafe { TcpSocket::from_raw_fd(raw_fd) }
+ }
+
+ #[cfg(windows)]
+ {
+ use std::os::windows::io::{FromRawSocket, IntoRawSocket};
+
+ let raw_socket = std_stream.into_raw_socket();
+ unsafe { TcpSocket::from_raw_socket(raw_socket) }
+ }
+ }
}
impl fmt::Debug for TcpSocket {
diff --git a/src/net/unix/ucred.rs b/src/net/unix/ucred.rs
index 5c7c198..b95a8f6 100644
--- a/src/net/unix/ucred.rs
+++ b/src/net/unix/ucred.rs
@@ -73,7 +73,7 @@ pub(crate) mod impl_linux {
// These paranoid checks should be optimized-out
assert!(mem::size_of::<u32>() <= mem::size_of::<usize>());
- assert!(ucred_size <= u32::max_value() as usize);
+ assert!(ucred_size <= u32::MAX as usize);
let mut ucred_size = ucred_size as socklen_t;
diff --git a/src/net/windows/mod.rs b/src/net/windows/mod.rs
new file mode 100644
index 0000000..060b68e
--- /dev/null
+++ b/src/net/windows/mod.rs
@@ -0,0 +1,3 @@
+//! Windows specific network types.
+
+pub mod named_pipe;
diff --git a/src/net/windows/named_pipe.rs b/src/net/windows/named_pipe.rs
new file mode 100644
index 0000000..8013d6f
--- /dev/null
+++ b/src/net/windows/named_pipe.rs
@@ -0,0 +1,1199 @@
+//! Tokio support for [Windows named pipes].
+//!
+//! [Windows named pipes]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
+
+use std::ffi::c_void;
+use std::ffi::OsStr;
+use std::io;
+use std::pin::Pin;
+use std::ptr;
+use std::task::{Context, Poll};
+
+use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf};
+use crate::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle};
+
+// Hide imports which are not used when generating documentation.
+#[cfg(not(docsrs))]
+mod doc {
+ pub(super) use crate::os::windows::ffi::OsStrExt;
+ pub(super) use crate::winapi::shared::minwindef::{DWORD, FALSE};
+ pub(super) use crate::winapi::um::fileapi;
+ pub(super) use crate::winapi::um::handleapi;
+ pub(super) use crate::winapi::um::namedpipeapi;
+ pub(super) use crate::winapi::um::winbase;
+ pub(super) use crate::winapi::um::winnt;
+
+ pub(super) use mio::windows as mio_windows;
+}
+
+// NB: none of these shows up in public API, so don't document them.
+#[cfg(docsrs)]
+mod doc {
+ pub type DWORD = crate::doc::NotDefinedHere;
+
+ pub(super) mod mio_windows {
+ pub type NamedPipe = crate::doc::NotDefinedHere;
+ }
+}
+
+use self::doc::*;
+
+/// A [Windows named pipe] server.
+///
+/// Accepting client connections involves creating a server with
+/// [`ServerOptions::create`] and waiting for clients to connect using
+/// [`NamedPipeServer::connect`].
+///
+/// To avoid having clients sporadically fail with
+/// [`std::io::ErrorKind::NotFound`] when they connect to a server, we must
+/// ensure that at least one server instance is available at all times. This
+/// means that the typical listen loop for a server is a bit involved, because
+/// we have to ensure that we never drop a server accidentally while a client
+/// might connect.
+///
+/// So a correctly implemented server looks like this:
+///
+/// ```no_run
+/// use std::io;
+/// use tokio::net::windows::named_pipe::ServerOptions;
+///
+/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-server";
+///
+/// # #[tokio::main] async fn main() -> std::io::Result<()> {
+/// // The first server needs to be constructed early so that clients can
+/// // be correctly connected. Otherwise calling .wait will cause the client to
+/// // error.
+/// //
+/// // Here we also make use of `first_pipe_instance`, which will ensure that
+/// // there are no other servers up and running already.
+/// let mut server = ServerOptions::new()
+/// .first_pipe_instance(true)
+/// .create(PIPE_NAME)?;
+///
+/// // Spawn the server loop.
+/// let server = tokio::spawn(async move {
+/// loop {
+/// // Wait for a client to connect.
+/// let connected = server.connect().await?;
+///
+/// // Construct the next server to be connected before sending the one
+/// // we already have of onto a task. This ensures that the server
+/// // isn't closed (after it's done in the task) before a new one is
+/// // available. Otherwise the client might error with
+/// // `io::ErrorKind::NotFound`.
+/// server = ServerOptions::new().create(PIPE_NAME)?;
+///
+/// let client = tokio::spawn(async move {
+/// /* use the connected client */
+/// # Ok::<_, std::io::Error>(())
+/// });
+/// # if true { break } // needed for type inference to work
+/// }
+///
+/// Ok::<_, io::Error>(())
+/// });
+///
+/// /* do something else not server related here */
+/// # Ok(()) }
+/// ```
+///
+/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY
+/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
+#[derive(Debug)]
+pub struct NamedPipeServer {
+ io: PollEvented<mio_windows::NamedPipe>,
+}
+
+impl NamedPipeServer {
+ /// Construct a new named pipe server from the specified raw handle.
+ ///
+ /// This function will consume ownership of the handle given, passing
+ /// responsibility for closing the handle to the returned object.
+ ///
+ /// This function is also unsafe as the primitives currently returned have
+ /// the contract that they are the sole owner of the file descriptor they
+ /// are wrapping. Usage of this function could accidentally allow violating
+ /// this contract which can cause memory unsafety in code that relies on it
+ /// being true.
+ ///
+ /// # Errors
+ ///
+ /// This errors if called outside of a [Tokio Runtime], or in a runtime that
+ /// has not [enabled I/O], or if any OS-specific I/O errors occur.
+ ///
+ /// [Tokio Runtime]: crate::runtime::Runtime
+ /// [enabled I/O]: crate::runtime::Builder::enable_io
+ pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> {
+ let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle);
+
+ Ok(Self {
+ io: PollEvented::new(named_pipe)?,
+ })
+ }
+
+ /// Retrieves information about the named pipe the server is associated
+ /// with.
+ ///
+ /// ```no_run
+ /// use tokio::net::windows::named_pipe::{PipeEnd, PipeMode, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-info";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let server = ServerOptions::new()
+ /// .pipe_mode(PipeMode::Message)
+ /// .max_instances(5)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let server_info = server.info()?;
+ ///
+ /// assert_eq!(server_info.end, PipeEnd::Server);
+ /// assert_eq!(server_info.mode, PipeMode::Message);
+ /// assert_eq!(server_info.max_instances, 5);
+ /// # Ok(()) }
+ /// ```
+ pub fn info(&self) -> io::Result<PipeInfo> {
+ // Safety: we're ensuring the lifetime of the named pipe.
+ unsafe { named_pipe_info(self.io.as_raw_handle()) }
+ }
+
+ /// Enables a named pipe server process to wait for a client process to
+ /// connect to an instance of a named pipe. A client process connects by
+ /// creating a named pipe with the same name.
+ ///
+ /// This corresponds to the [`ConnectNamedPipe`] system call.
+ ///
+ /// [`ConnectNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-connectnamedpipe
+ ///
+ /// ```no_run
+ /// use tokio::net::windows::named_pipe::ServerOptions;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let pipe = ServerOptions::new().create(PIPE_NAME)?;
+ ///
+ /// // Wait for a client to connect.
+ /// pipe.connect().await?;
+ ///
+ /// // Use the connected client...
+ /// # Ok(()) }
+ /// ```
+ pub async fn connect(&self) -> io::Result<()> {
+ loop {
+ match self.io.connect() {
+ Ok(()) => break,
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.io.registration().readiness(Interest::WRITABLE).await?;
+ }
+ Err(e) => return Err(e),
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Disconnects the server end of a named pipe instance from a client
+ /// process.
+ ///
+ /// ```
+ /// use tokio::io::AsyncWriteExt;
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ /// use winapi::shared::winerror;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-disconnect";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let server = ServerOptions::new()
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let mut client = ClientOptions::new()
+ /// .open(PIPE_NAME)?;
+ ///
+ /// // Wait for a client to become connected.
+ /// server.connect().await?;
+ ///
+ /// // Forcibly disconnect the client.
+ /// server.disconnect()?;
+ ///
+ /// // Write fails with an OS-specific error after client has been
+ /// // disconnected.
+ /// let e = client.write(b"ping").await.unwrap_err();
+ /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_NOT_CONNECTED as i32));
+ /// # Ok(()) }
+ /// ```
+ pub fn disconnect(&self) -> io::Result<()> {
+ self.io.disconnect()
+ }
+}
+
+impl AsyncRead for NamedPipeServer {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ unsafe { self.io.poll_read(cx, buf) }
+ }
+}
+
+impl AsyncWrite for NamedPipeServer {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self.io.poll_write(cx, buf)
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[io::IoSlice<'_>],
+ ) -> Poll<io::Result<usize>> {
+ self.io.poll_write_vectored(cx, bufs)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ self.poll_flush(cx)
+ }
+}
+
+impl AsRawHandle for NamedPipeServer {
+ fn as_raw_handle(&self) -> RawHandle {
+ self.io.as_raw_handle()
+ }
+}
+
+/// A [Windows named pipe] client.
+///
+/// Constructed using [`ClientOptions::open`].
+///
+/// Connecting a client correctly involves a few steps. When connecting through
+/// [`ClientOptions::open`], it might error indicating one of two things:
+///
+/// * [`std::io::ErrorKind::NotFound`] - There is no server available.
+/// * [`ERROR_PIPE_BUSY`] - There is a server available, but it is busy. Sleep
+/// for a while and try again.
+///
+/// So a correctly implemented client looks like this:
+///
+/// ```no_run
+/// use std::time::Duration;
+/// use tokio::net::windows::named_pipe::ClientOptions;
+/// use tokio::time;
+/// use winapi::shared::winerror;
+///
+/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-client";
+///
+/// # #[tokio::main] async fn main() -> std::io::Result<()> {
+/// let client = loop {
+/// match ClientOptions::new().open(PIPE_NAME) {
+/// Ok(client) => break client,
+/// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (),
+/// Err(e) => return Err(e),
+/// }
+///
+/// time::sleep(Duration::from_millis(50)).await;
+/// };
+///
+/// /* use the connected client */
+/// # Ok(()) }
+/// ```
+///
+/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY
+/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
+#[derive(Debug)]
+pub struct NamedPipeClient {
+ io: PollEvented<mio_windows::NamedPipe>,
+}
+
+impl NamedPipeClient {
+ /// Construct a new named pipe client from the specified raw handle.
+ ///
+ /// This function will consume ownership of the handle given, passing
+ /// responsibility for closing the handle to the returned object.
+ ///
+ /// This function is also unsafe as the primitives currently returned have
+ /// the contract that they are the sole owner of the file descriptor they
+ /// are wrapping. Usage of this function could accidentally allow violating
+ /// this contract which can cause memory unsafety in code that relies on it
+ /// being true.
+ ///
+ /// # Errors
+ ///
+ /// This errors if called outside of a [Tokio Runtime], or in a runtime that
+ /// has not [enabled I/O], or if any OS-specific I/O errors occur.
+ ///
+ /// [Tokio Runtime]: crate::runtime::Runtime
+ /// [enabled I/O]: crate::runtime::Builder::enable_io
+ pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> {
+ let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle);
+
+ Ok(Self {
+ io: PollEvented::new(named_pipe)?,
+ })
+ }
+
+ /// Retrieves information about the named pipe the client is associated
+ /// with.
+ ///
+ /// ```no_run
+ /// use tokio::net::windows::named_pipe::{ClientOptions, PipeEnd, PipeMode};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-info";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let client = ClientOptions::new()
+ /// .open(PIPE_NAME)?;
+ ///
+ /// let client_info = client.info()?;
+ ///
+ /// assert_eq!(client_info.end, PipeEnd::Client);
+ /// assert_eq!(client_info.mode, PipeMode::Message);
+ /// assert_eq!(client_info.max_instances, 5);
+ /// # Ok(()) }
+ /// ```
+ pub fn info(&self) -> io::Result<PipeInfo> {
+ // Safety: we're ensuring the lifetime of the named pipe.
+ unsafe { named_pipe_info(self.io.as_raw_handle()) }
+ }
+}
+
+impl AsyncRead for NamedPipeClient {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ unsafe { self.io.poll_read(cx, buf) }
+ }
+}
+
+impl AsyncWrite for NamedPipeClient {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ self.io.poll_write(cx, buf)
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[io::IoSlice<'_>],
+ ) -> Poll<io::Result<usize>> {
+ self.io.poll_write_vectored(cx, bufs)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ self.poll_flush(cx)
+ }
+}
+
+impl AsRawHandle for NamedPipeClient {
+ fn as_raw_handle(&self) -> RawHandle {
+ self.io.as_raw_handle()
+ }
+}
+
+// Helper to set a boolean flag as a bitfield.
+macro_rules! bool_flag {
+ ($f:expr, $t:expr, $flag:expr) => {{
+ let current = $f;
+
+ if $t {
+ $f = current | $flag;
+ } else {
+ $f = current & !$flag;
+ };
+ }};
+}
+
+/// A builder structure for construct a named pipe with named pipe-specific
+/// options. This is required to use for named pipe servers who wants to modify
+/// pipe-related options.
+///
+/// See [`ServerOptions::create`].
+#[derive(Debug, Clone)]
+pub struct ServerOptions {
+ open_mode: DWORD,
+ pipe_mode: DWORD,
+ max_instances: DWORD,
+ out_buffer_size: DWORD,
+ in_buffer_size: DWORD,
+ default_timeout: DWORD,
+}
+
+impl ServerOptions {
+ /// Creates a new named pipe builder with the default settings.
+ ///
+ /// ```
+ /// use tokio::net::windows::named_pipe::ServerOptions;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-new";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let server = ServerOptions::new().create(PIPE_NAME)?;
+ /// # Ok(()) }
+ /// ```
+ pub fn new() -> ServerOptions {
+ ServerOptions {
+ open_mode: winbase::PIPE_ACCESS_DUPLEX | winbase::FILE_FLAG_OVERLAPPED,
+ pipe_mode: winbase::PIPE_TYPE_BYTE | winbase::PIPE_REJECT_REMOTE_CLIENTS,
+ max_instances: winbase::PIPE_UNLIMITED_INSTANCES,
+ out_buffer_size: 65536,
+ in_buffer_size: 65536,
+ default_timeout: 0,
+ }
+ }
+
+ /// The pipe mode.
+ ///
+ /// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for
+ /// documentation of what each mode means.
+ ///
+ /// This corresponding to specifying [`dwPipeMode`].
+ ///
+ /// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea
+ pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self {
+ self.pipe_mode = match pipe_mode {
+ PipeMode::Byte => winbase::PIPE_TYPE_BYTE,
+ PipeMode::Message => winbase::PIPE_TYPE_MESSAGE,
+ };
+
+ self
+ }
+
+ /// The flow of data in the pipe goes from client to server only.
+ ///
+ /// This corresponds to setting [`PIPE_ACCESS_INBOUND`].
+ ///
+ /// [`PIPE_ACCESS_INBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_inbound
+ ///
+ /// # Errors
+ ///
+ /// Server side prevents connecting by denying inbound access, client errors
+ /// with [`std::io::ErrorKind::PermissionDenied`] when attempting to create
+ /// the connection.
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err1";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let _server = ServerOptions::new()
+ /// .access_inbound(false)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let e = ClientOptions::new()
+ /// .open(PIPE_NAME)
+ /// .unwrap_err();
+ ///
+ /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// Disabling writing allows a client to connect, but errors with
+ /// [`std::io::ErrorKind::PermissionDenied`] if a write is attempted.
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::io::AsyncWriteExt;
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err2";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let server = ServerOptions::new()
+ /// .access_inbound(false)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let mut client = ClientOptions::new()
+ /// .write(false)
+ /// .open(PIPE_NAME)?;
+ ///
+ /// server.connect().await?;
+ ///
+ /// let e = client.write(b"ping").await.unwrap_err();
+ /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// # Examples
+ ///
+ /// A unidirectional named pipe that only supports server-to-client
+ /// communication.
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::io::{AsyncReadExt, AsyncWriteExt};
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let mut server = ServerOptions::new()
+ /// .access_inbound(false)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let mut client = ClientOptions::new()
+ /// .write(false)
+ /// .open(PIPE_NAME)?;
+ ///
+ /// server.connect().await?;
+ ///
+ /// let write = server.write_all(b"ping");
+ ///
+ /// let mut buf = [0u8; 4];
+ /// let read = client.read_exact(&mut buf);
+ ///
+ /// let ((), read) = tokio::try_join!(write, read)?;
+ ///
+ /// assert_eq!(read, 4);
+ /// assert_eq!(&buf[..], b"ping");
+ /// # Ok(()) }
+ /// ```
+ pub fn access_inbound(&mut self, allowed: bool) -> &mut Self {
+ bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_INBOUND);
+ self
+ }
+
+ /// The flow of data in the pipe goes from server to client only.
+ ///
+ /// This corresponds to setting [`PIPE_ACCESS_OUTBOUND`].
+ ///
+ /// [`PIPE_ACCESS_OUTBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_outbound
+ ///
+ /// # Errors
+ ///
+ /// Server side prevents connecting by denying outbound access, client
+ /// errors with [`std::io::ErrorKind::PermissionDenied`] when attempting to
+ /// create the connection.
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err1";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let server = ServerOptions::new()
+ /// .access_outbound(false)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let e = ClientOptions::new()
+ /// .open(PIPE_NAME)
+ /// .unwrap_err();
+ ///
+ /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// Disabling reading allows a client to connect, but attempting to read
+ /// will error with [`std::io::ErrorKind::PermissionDenied`].
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::io::AsyncReadExt;
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err2";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let server = ServerOptions::new()
+ /// .access_outbound(false)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let mut client = ClientOptions::new()
+ /// .read(false)
+ /// .open(PIPE_NAME)?;
+ ///
+ /// server.connect().await?;
+ ///
+ /// let mut buf = [0u8; 4];
+ /// let e = client.read(&mut buf).await.unwrap_err();
+ /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// # Examples
+ ///
+ /// A unidirectional named pipe that only supports client-to-server
+ /// communication.
+ ///
+ /// ```
+ /// use tokio::io::{AsyncReadExt, AsyncWriteExt};
+ /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let mut server = ServerOptions::new()
+ /// .access_outbound(false)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// let mut client = ClientOptions::new()
+ /// .read(false)
+ /// .open(PIPE_NAME)?;
+ ///
+ /// server.connect().await?;
+ ///
+ /// let write = client.write_all(b"ping");
+ ///
+ /// let mut buf = [0u8; 4];
+ /// let read = server.read_exact(&mut buf);
+ ///
+ /// let ((), read) = tokio::try_join!(write, read)?;
+ ///
+ /// println!("done reading and writing");
+ ///
+ /// assert_eq!(read, 4);
+ /// assert_eq!(&buf[..], b"ping");
+ /// # Ok(()) }
+ /// ```
+ pub fn access_outbound(&mut self, allowed: bool) -> &mut Self {
+ bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_OUTBOUND);
+ self
+ }
+
+ /// If you attempt to create multiple instances of a pipe with this flag
+ /// set, creation of the first server instance succeeds, but creation of any
+ /// subsequent instances will fail with
+ /// [`std::io::ErrorKind::PermissionDenied`].
+ ///
+ /// This option is intended to be used with servers that want to ensure that
+ /// they are the only process listening for clients on a given named pipe.
+ /// This is accomplished by enabling it for the first server instance
+ /// created in a process.
+ ///
+ /// This corresponds to setting [`FILE_FLAG_FIRST_PIPE_INSTANCE`].
+ ///
+ /// # Errors
+ ///
+ /// If this option is set and more than one instance of the server for a
+ /// given named pipe exists, calling [`create`] will fail with
+ /// [`std::io::ErrorKind::PermissionDenied`].
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::net::windows::named_pipe::ServerOptions;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance-error";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let server1 = ServerOptions::new()
+ /// .first_pipe_instance(true)
+ /// .create(PIPE_NAME)?;
+ ///
+ /// // Second server errs, since it's not the first instance.
+ /// let e = ServerOptions::new()
+ /// .first_pipe_instance(true)
+ /// .create(PIPE_NAME)
+ /// .unwrap_err();
+ ///
+ /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::net::windows::named_pipe::ServerOptions;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let mut builder = ServerOptions::new();
+ /// builder.first_pipe_instance(true);
+ ///
+ /// let server = builder.create(PIPE_NAME)?;
+ /// let e = builder.create(PIPE_NAME).unwrap_err();
+ /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+ /// drop(server);
+ ///
+ /// // OK: since, we've closed the other instance.
+ /// let _server2 = builder.create(PIPE_NAME)?;
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// [`create`]: ServerOptions::create
+ /// [`FILE_FLAG_FIRST_PIPE_INSTANCE`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_first_pipe_instance
+ pub fn first_pipe_instance(&mut self, first: bool) -> &mut Self {
+ bool_flag!(
+ self.open_mode,
+ first,
+ winbase::FILE_FLAG_FIRST_PIPE_INSTANCE
+ );
+ self
+ }
+
+ /// Indicates whether this server can accept remote clients or not. Remote
+ /// clients are disabled by default.
+ ///
+ /// This corresponds to setting [`PIPE_REJECT_REMOTE_CLIENTS`].
+ ///
+ /// [`PIPE_REJECT_REMOTE_CLIENTS`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_reject_remote_clients
+ pub fn reject_remote_clients(&mut self, reject: bool) -> &mut Self {
+ bool_flag!(self.pipe_mode, reject, winbase::PIPE_REJECT_REMOTE_CLIENTS);
+ self
+ }
+
+ /// The maximum number of instances that can be created for this pipe. The
+ /// first instance of the pipe can specify this value; the same number must
+ /// be specified for other instances of the pipe. Acceptable values are in
+ /// the range 1 through 254. The default value is unlimited.
+ ///
+ /// This corresponds to specifying [`nMaxInstances`].
+ ///
+ /// [`nMaxInstances`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea
+ ///
+ /// # Errors
+ ///
+ /// The same numbers of `max_instances` have to be used by all servers. Any
+ /// additional servers trying to be built which uses a mismatching value
+ /// might error.
+ ///
+ /// ```
+ /// use std::io;
+ /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions};
+ /// use winapi::shared::winerror;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-max-instances";
+ ///
+ /// # #[tokio::main] async fn main() -> io::Result<()> {
+ /// let mut server = ServerOptions::new();
+ /// server.max_instances(2);
+ ///
+ /// let s1 = server.create(PIPE_NAME)?;
+ /// let c1 = ClientOptions::new().open(PIPE_NAME);
+ ///
+ /// let s2 = server.create(PIPE_NAME)?;
+ /// let c2 = ClientOptions::new().open(PIPE_NAME);
+ ///
+ /// // Too many servers!
+ /// let e = server.create(PIPE_NAME).unwrap_err();
+ /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32));
+ ///
+ /// // Still too many servers even if we specify a higher value!
+ /// let e = server.max_instances(100).create(PIPE_NAME).unwrap_err();
+ /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32));
+ /// # Ok(()) }
+ /// ```
+ ///
+ /// # Panics
+ ///
+ /// This function will panic if more than 254 instances are specified. If
+ /// you do not wish to set an instance limit, leave it unspecified.
+ ///
+ /// ```should_panic
+ /// use tokio::net::windows::named_pipe::ServerOptions;
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let builder = ServerOptions::new().max_instances(255);
+ /// # Ok(()) }
+ /// ```
+ pub fn max_instances(&mut self, instances: usize) -> &mut Self {
+ assert!(instances < 255, "cannot specify more than 254 instances");
+ self.max_instances = instances as DWORD;
+ self
+ }
+
+ /// The number of bytes to reserve for the output buffer.
+ ///
+ /// This corresponds to specifying [`nOutBufferSize`].
+ ///
+ /// [`nOutBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea
+ pub fn out_buffer_size(&mut self, buffer: u32) -> &mut Self {
+ self.out_buffer_size = buffer as DWORD;
+ self
+ }
+
+ /// The number of bytes to reserve for the input buffer.
+ ///
+ /// This corresponds to specifying [`nInBufferSize`].
+ ///
+ /// [`nInBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea
+ pub fn in_buffer_size(&mut self, buffer: u32) -> &mut Self {
+ self.in_buffer_size = buffer as DWORD;
+ self
+ }
+
+ /// Create the named pipe identified by `addr` for use as a server.
+ ///
+ /// This uses the [`CreateNamedPipe`] function.
+ ///
+ /// [`CreateNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea
+ ///
+ /// # Errors
+ ///
+ /// This errors if called outside of a [Tokio Runtime], or in a runtime that
+ /// has not [enabled I/O], or if any OS-specific I/O errors occur.
+ ///
+ /// [Tokio Runtime]: crate::runtime::Runtime
+ /// [enabled I/O]: crate::runtime::Builder::enable_io
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::net::windows::named_pipe::ServerOptions;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-create";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let server = ServerOptions::new().create(PIPE_NAME)?;
+ /// # Ok(()) }
+ /// ```
+ pub fn create(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeServer> {
+ // Safety: We're calling create_with_security_attributes_raw w/ a null
+ // pointer which disables it.
+ unsafe { self.create_with_security_attributes_raw(addr, ptr::null_mut()) }
+ }
+
+ /// Create the named pipe identified by `addr` for use as a server.
+ ///
+ /// This is the same as [`create`] except that it supports providing the raw
+ /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed
+ /// as the `lpSecurityAttributes` argument to [`CreateFile`].
+ ///
+ /// # Errors
+ ///
+ /// This errors if called outside of a [Tokio Runtime], or in a runtime that
+ /// has not [enabled I/O], or if any OS-specific I/O errors occur.
+ ///
+ /// [Tokio Runtime]: crate::runtime::Runtime
+ /// [enabled I/O]: crate::runtime::Builder::enable_io
+ ///
+ /// # Safety
+ ///
+ /// The `attrs` argument must either be null or point at a valid instance of
+ /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the
+ /// behavior is identical to calling the [`create`] method.
+ ///
+ /// [`create`]: ServerOptions::create
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
+ /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES
+ pub unsafe fn create_with_security_attributes_raw(
+ &self,
+ addr: impl AsRef<OsStr>,
+ attrs: *mut c_void,
+ ) -> io::Result<NamedPipeServer> {
+ let addr = encode_addr(addr);
+
+ let h = namedpipeapi::CreateNamedPipeW(
+ addr.as_ptr(),
+ self.open_mode,
+ self.pipe_mode,
+ self.max_instances,
+ self.out_buffer_size,
+ self.in_buffer_size,
+ self.default_timeout,
+ attrs as *mut _,
+ );
+
+ if h == handleapi::INVALID_HANDLE_VALUE {
+ return Err(io::Error::last_os_error());
+ }
+
+ NamedPipeServer::from_raw_handle(h)
+ }
+}
+
+/// A builder suitable for building and interacting with named pipes from the
+/// client side.
+///
+/// See [`ClientOptions::open`].
+#[derive(Debug, Clone)]
+pub struct ClientOptions {
+ desired_access: DWORD,
+ security_qos_flags: DWORD,
+}
+
+impl ClientOptions {
+ /// Creates a new named pipe builder with the default settings.
+ ///
+ /// ```
+ /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions};
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-new";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// // Server must be created in order for the client creation to succeed.
+ /// let server = ServerOptions::new().create(PIPE_NAME)?;
+ /// let client = ClientOptions::new().open(PIPE_NAME)?;
+ /// # Ok(()) }
+ /// ```
+ pub fn new() -> Self {
+ Self {
+ desired_access: winnt::GENERIC_READ | winnt::GENERIC_WRITE,
+ security_qos_flags: winbase::SECURITY_IDENTIFICATION | winbase::SECURITY_SQOS_PRESENT,
+ }
+ }
+
+ /// If the client supports reading data. This is enabled by default.
+ ///
+ /// This corresponds to setting [`GENERIC_READ`] in the call to [`CreateFile`].
+ ///
+ /// [`GENERIC_READ`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
+ pub fn read(&mut self, allowed: bool) -> &mut Self {
+ bool_flag!(self.desired_access, allowed, winnt::GENERIC_READ);
+ self
+ }
+
+ /// If the created pipe supports writing data. This is enabled by default.
+ ///
+ /// This corresponds to setting [`GENERIC_WRITE`] in the call to [`CreateFile`].
+ ///
+ /// [`GENERIC_WRITE`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
+ pub fn write(&mut self, allowed: bool) -> &mut Self {
+ bool_flag!(self.desired_access, allowed, winnt::GENERIC_WRITE);
+ self
+ }
+
+ /// Sets qos flags which are combined with other flags and attributes in the
+ /// call to [`CreateFile`].
+ ///
+ /// By default `security_qos_flags` is set to [`SECURITY_IDENTIFICATION`],
+ /// calling this function would override that value completely with the
+ /// argument specified.
+ ///
+ /// When `security_qos_flags` is not set, a malicious program can gain the
+ /// elevated privileges of a privileged Rust process when it allows opening
+ /// user-specified paths, by tricking it into opening a named pipe. So
+ /// arguably `security_qos_flags` should also be set when opening arbitrary
+ /// paths. However the bits can then conflict with other flags, specifically
+ /// `FILE_FLAG_OPEN_NO_RECALL`.
+ ///
+ /// For information about possible values, see [Impersonation Levels] on the
+ /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set
+ /// automatically when using this method.
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ /// [`SECURITY_IDENTIFICATION`]: crate::winapi::um::winbase::SECURITY_IDENTIFICATION
+ /// [Impersonation Levels]: https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level
+ pub fn security_qos_flags(&mut self, flags: u32) -> &mut Self {
+ // See: https://github.com/rust-lang/rust/pull/58216
+ self.security_qos_flags = flags | winbase::SECURITY_SQOS_PRESENT;
+ self
+ }
+
+ /// Open the named pipe identified by `addr`.
+ ///
+ /// This opens the client using [`CreateFile`] with the
+ /// `dwCreationDisposition` option set to `OPEN_EXISTING`.
+ ///
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea
+ ///
+ /// # Errors
+ ///
+ /// This errors if called outside of a [Tokio Runtime], or in a runtime that
+ /// has not [enabled I/O], or if any OS-specific I/O errors occur.
+ ///
+ /// There are a few errors you need to take into account when creating a
+ /// named pipe on the client side:
+ ///
+ /// * [`std::io::ErrorKind::NotFound`] - This indicates that the named pipe
+ /// does not exist. Presumably the server is not up.
+ /// * [`ERROR_PIPE_BUSY`] - This error is raised when the named pipe exists,
+ /// but the server is not currently waiting for a connection. Please see the
+ /// examples for how to check for this error.
+ ///
+ /// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY
+ /// [`winapi`]: crate::winapi
+ /// [enabled I/O]: crate::runtime::Builder::enable_io
+ /// [Tokio Runtime]: crate::runtime::Runtime
+ ///
+ /// A connect loop that waits until a socket becomes available looks like
+ /// this:
+ ///
+ /// ```no_run
+ /// use std::time::Duration;
+ /// use tokio::net::windows::named_pipe::ClientOptions;
+ /// use tokio::time;
+ /// use winapi::shared::winerror;
+ ///
+ /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe";
+ ///
+ /// # #[tokio::main] async fn main() -> std::io::Result<()> {
+ /// let client = loop {
+ /// match ClientOptions::new().open(PIPE_NAME) {
+ /// Ok(client) => break client,
+ /// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (),
+ /// Err(e) => return Err(e),
+ /// }
+ ///
+ /// time::sleep(Duration::from_millis(50)).await;
+ /// };
+ ///
+ /// // use the connected client.
+ /// # Ok(()) }
+ /// ```
+ pub fn open(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeClient> {
+ // Safety: We're calling open_with_security_attributes_raw w/ a null
+ // pointer which disables it.
+ unsafe { self.open_with_security_attributes_raw(addr, ptr::null_mut()) }
+ }
+
+ /// Open the named pipe identified by `addr`.
+ ///
+ /// This is the same as [`open`] except that it supports providing the raw
+ /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed
+ /// as the `lpSecurityAttributes` argument to [`CreateFile`].
+ ///
+ /// # Safety
+ ///
+ /// The `attrs` argument must either be null or point at a valid instance of
+ /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the
+ /// behavior is identical to calling the [`open`] method.
+ ///
+ /// [`open`]: ClientOptions::open
+ /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
+ /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES
+ pub unsafe fn open_with_security_attributes_raw(
+ &self,
+ addr: impl AsRef<OsStr>,
+ attrs: *mut c_void,
+ ) -> io::Result<NamedPipeClient> {
+ let addr = encode_addr(addr);
+
+ // NB: We could use a platform specialized `OpenOptions` here, but since
+ // we have access to winapi it ultimately doesn't hurt to use
+ // `CreateFile` explicitly since it allows the use of our already
+ // well-structured wide `addr` to pass into CreateFileW.
+ let h = fileapi::CreateFileW(
+ addr.as_ptr(),
+ self.desired_access,
+ 0,
+ attrs as *mut _,
+ fileapi::OPEN_EXISTING,
+ self.get_flags(),
+ ptr::null_mut(),
+ );
+
+ if h == handleapi::INVALID_HANDLE_VALUE {
+ return Err(io::Error::last_os_error());
+ }
+
+ NamedPipeClient::from_raw_handle(h)
+ }
+
+ fn get_flags(&self) -> u32 {
+ self.security_qos_flags | winbase::FILE_FLAG_OVERLAPPED
+ }
+}
+
+/// The pipe mode of a named pipe.
+///
+/// Set through [`ServerOptions::pipe_mode`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[non_exhaustive]
+pub enum PipeMode {
+ /// Data is written to the pipe as a stream of bytes. The pipe does not
+ /// distinguish bytes written during different write operations.
+ ///
+ /// Corresponds to [`PIPE_TYPE_BYTE`][crate::winapi::um::winbase::PIPE_TYPE_BYTE].
+ Byte,
+ /// Data is written to the pipe as a stream of messages. The pipe treats the
+ /// bytes written during each write operation as a message unit. Any reading
+ /// on a named pipe returns [`ERROR_MORE_DATA`] when a message is not read
+ /// completely.
+ ///
+ /// Corresponds to [`PIPE_TYPE_MESSAGE`][crate::winapi::um::winbase::PIPE_TYPE_MESSAGE].
+ ///
+ /// [`ERROR_MORE_DATA`]: crate::winapi::shared::winerror::ERROR_MORE_DATA
+ Message,
+}
+
+/// Indicates the end of a named pipe.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[non_exhaustive]
+pub enum PipeEnd {
+ /// The named pipe refers to the client end of a named pipe instance.
+ ///
+ /// Corresponds to [`PIPE_CLIENT_END`][crate::winapi::um::winbase::PIPE_CLIENT_END].
+ Client,
+ /// The named pipe refers to the server end of a named pipe instance.
+ ///
+ /// Corresponds to [`PIPE_SERVER_END`][crate::winapi::um::winbase::PIPE_SERVER_END].
+ Server,
+}
+
+/// Information about a named pipe.
+///
+/// Constructed through [`NamedPipeServer::info`] or [`NamedPipeClient::info`].
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct PipeInfo {
+ /// Indicates the mode of a named pipe.
+ pub mode: PipeMode,
+ /// Indicates the end of a named pipe.
+ pub end: PipeEnd,
+ /// The maximum number of instances that can be created for this pipe.
+ pub max_instances: u32,
+ /// The number of bytes to reserve for the output buffer.
+ pub out_buffer_size: u32,
+ /// The number of bytes to reserve for the input buffer.
+ pub in_buffer_size: u32,
+}
+
+/// Encode an address so that it is a null-terminated wide string.
+fn encode_addr(addr: impl AsRef<OsStr>) -> Box<[u16]> {
+ let len = addr.as_ref().encode_wide().count();
+ let mut vec = Vec::with_capacity(len + 1);
+ vec.extend(addr.as_ref().encode_wide());
+ vec.push(0);
+ vec.into_boxed_slice()
+}
+
+/// Internal function to get the info out of a raw named pipe.
+unsafe fn named_pipe_info(handle: RawHandle) -> io::Result<PipeInfo> {
+ let mut flags = 0;
+ let mut out_buffer_size = 0;
+ let mut in_buffer_size = 0;
+ let mut max_instances = 0;
+
+ let result = namedpipeapi::GetNamedPipeInfo(
+ handle,
+ &mut flags,
+ &mut out_buffer_size,
+ &mut in_buffer_size,
+ &mut max_instances,
+ );
+
+ if result == FALSE {
+ return Err(io::Error::last_os_error());
+ }
+
+ let mut end = PipeEnd::Client;
+ let mut mode = PipeMode::Byte;
+
+ if flags & winbase::PIPE_SERVER_END != 0 {
+ end = PipeEnd::Server;
+ }
+
+ if flags & winbase::PIPE_TYPE_MESSAGE != 0 {
+ mode = PipeMode::Message;
+ }
+
+ Ok(PipeInfo {
+ end,
+ mode,
+ out_buffer_size,
+ in_buffer_size,
+ max_instances,
+ })
+}
diff --git a/src/runtime/basic_scheduler.rs b/src/runtime/basic_scheduler.rs
index ffe0bca..13dfb69 100644
--- a/src/runtime/basic_scheduler.rs
+++ b/src/runtime/basic_scheduler.rs
@@ -84,13 +84,13 @@ unsafe impl Send for Entry {}
/// Scheduler state shared between threads.
struct Shared {
- /// Remote run queue
- queue: Mutex<VecDeque<Entry>>,
+ /// Remote run queue. None if the `Runtime` has been dropped.
+ queue: Mutex<Option<VecDeque<Entry>>>,
- /// Unpark the blocked thread
+ /// Unpark the blocked thread.
unpark: Box<dyn Unpark>,
- // indicates whether the blocked on thread was woken
+ /// Indicates whether the blocked on thread was woken.
woken: AtomicBool,
}
@@ -124,7 +124,7 @@ impl<P: Park> BasicScheduler<P> {
let spawner = Spawner {
shared: Arc::new(Shared {
- queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
+ queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
unpark: unpark as Box<dyn Unpark>,
woken: AtomicBool::new(false),
}),
@@ -351,18 +351,29 @@ impl<P: Park> Drop for BasicScheduler<P> {
task.shutdown();
}
- // Drain remote queue
- for entry in scheduler.spawner.shared.queue.lock().drain(..) {
- match entry {
- Entry::Schedule(task) => {
- task.shutdown();
- }
- Entry::Release(..) => {
- // Do nothing, each entry in the linked list was *just*
- // dropped by the scheduler above.
+ // Drain remote queue and set it to None
+ let mut remote_queue = scheduler.spawner.shared.queue.lock();
+
+ // Using `Option::take` to replace the shared queue with `None`.
+ if let Some(remote_queue) = remote_queue.take() {
+ for entry in remote_queue {
+ match entry {
+ Entry::Schedule(task) => {
+ task.shutdown();
+ }
+ Entry::Release(..) => {
+ // Do nothing, each entry in the linked list was *just*
+ // dropped by the scheduler above.
+ }
}
}
}
+ // By dropping the mutex lock after the full duration of the above loop,
+ // any thread that sees the queue in the `None` state is guaranteed that
+ // the runtime has fully shut down.
+ //
+ // The assert below is unrelated to this mutex.
+ drop(remote_queue);
assert!(context.tasks.borrow().owned.is_empty());
});
@@ -381,7 +392,7 @@ impl Spawner {
/// Spawns a future onto the thread pool
pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
- F: Future + Send + 'static,
+ F: crate::future::Future + Send + 'static,
F::Output: Send + 'static,
{
let (task, handle) = task::joinable(future);
@@ -390,7 +401,10 @@ impl Spawner {
}
fn pop(&self) -> Option<Entry> {
- self.shared.queue.lock().pop_front()
+ match self.shared.queue.lock().as_mut() {
+ Some(queue) => queue.pop_front(),
+ None => None,
+ }
}
fn waker_ref(&self) -> WakerRef<'_> {
@@ -429,7 +443,19 @@ impl Schedule for Arc<Shared> {
// safety: the task is inserted in the list in `bind`.
unsafe { cx.tasks.borrow_mut().owned.remove(ptr) }
} else {
- self.queue.lock().push_back(Entry::Release(ptr));
+ // By sending an `Entry::Release` to the runtime, we ask the
+ // runtime to remove this task from the linked list in
+ // `Tasks::owned`.
+ //
+ // If the queue is `None`, then the task was already removed
+ // from that list in the destructor of `BasicScheduler`. We do
+ // not do anything in this case for the same reason that
+ // `Entry::Release` messages are ignored in the remote queue
+ // drain loop of `BasicScheduler`'s destructor.
+ if let Some(queue) = self.queue.lock().as_mut() {
+ queue.push_back(Entry::Release(ptr));
+ }
+
self.unpark.unpark();
// Returning `None` here prevents the task plumbing from being
// freed. It is then up to the scheduler through the queue we
@@ -445,8 +471,17 @@ impl Schedule for Arc<Shared> {
cx.tasks.borrow_mut().queue.push_back(task);
}
_ => {
- self.queue.lock().push_back(Entry::Schedule(task));
- self.unpark.unpark();
+ let mut guard = self.queue.lock();
+ if let Some(queue) = guard.as_mut() {
+ queue.push_back(Entry::Schedule(task));
+ drop(guard);
+ self.unpark.unpark();
+ } else {
+ // The runtime has shut down. We drop the new task
+ // immediately.
+ drop(guard);
+ task.shutdown();
+ }
}
});
}
diff --git a/src/runtime/blocking/pool.rs b/src/runtime/blocking/pool.rs
index 5c9b8ed..b7d7251 100644
--- a/src/runtime/blocking/pool.rs
+++ b/src/runtime/blocking/pool.rs
@@ -4,7 +4,6 @@ use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::loom::thread;
use crate::runtime::blocking::schedule::NoopSchedule;
use crate::runtime::blocking::shutdown;
-use crate::runtime::blocking::task::BlockingTask;
use crate::runtime::builder::ThreadNameFn;
use crate::runtime::context;
use crate::runtime::task::{self, JoinHandle};
@@ -86,18 +85,6 @@ where
rt.spawn_blocking(func)
}
-#[allow(dead_code)]
-pub(crate) fn try_spawn_blocking<F, R>(func: F) -> Result<(), ()>
-where
- F: FnOnce() -> R + Send + 'static,
- R: Send + 'static,
-{
- let rt = context::current().expect(CONTEXT_MISSING_ERROR);
-
- let (task, _handle) = task::joinable(BlockingTask::new(func));
- rt.blocking_spawner.spawn(task, &rt)
-}
-
// ===== impl BlockingPool =====
impl BlockingPool {
diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs
index 4f1b4c5..173f0ca 100644
--- a/src/runtime/handle.rs
+++ b/src/runtime/handle.rs
@@ -174,8 +174,11 @@ impl Handle {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
+ let fut = BlockingTask::new(func);
+
#[cfg(all(tokio_unstable, feature = "tracing"))]
- let func = {
+ let fut = {
+ use tracing::Instrument;
#[cfg(tokio_track_caller)]
let location = std::panic::Location::caller();
#[cfg(tokio_track_caller)]
@@ -193,12 +196,9 @@ impl Handle {
kind = %"blocking",
function = %std::any::type_name::<F>(),
);
- move || {
- let _g = span.enter();
- func()
- }
+ fut.instrument(span)
};
- let (task, handle) = task::joinable(BlockingTask::new(func));
+ let (task, handle) = task::joinable(fut);
let _ = self.blocking_spawner.spawn(task, &self);
handle
}
diff --git a/src/runtime/queue.rs b/src/runtime/queue.rs
index 6ea23c9..3df7bba 100644
--- a/src/runtime/queue.rs
+++ b/src/runtime/queue.rs
@@ -109,7 +109,10 @@ impl<T> Local<T> {
}
/// Pushes a task to the back of the local queue, skipping the LIFO slot.
- pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) {
+ pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>)
+ where
+ T: crate::runtime::task::Schedule,
+ {
let tail = loop {
let head = self.inner.head.load(Acquire);
let (steal, real) = unpack(head);
@@ -121,9 +124,14 @@ impl<T> Local<T> {
// There is capacity for the task
break tail;
} else if steal != real {
- // Concurrently stealing, this will free up capacity, so
- // only push the new task onto the inject queue
- inject.push(task);
+ // Concurrently stealing, this will free up capacity, so only
+ // push the new task onto the inject queue
+ //
+ // If the task failes to be pushed on the injection queue, there
+ // is nothing to be done at this point as the task cannot be a
+ // newly spawned task. Shutting down this task is handled by the
+ // worker shutdown process.
+ let _ = inject.push(task);
return;
} else {
// Push the current task and half of the queue into the
@@ -504,16 +512,19 @@ impl<T: 'static> Inject<T> {
}
/// Pushes a value into the queue.
- pub(super) fn push(&self, task: task::Notified<T>) {
+ ///
+ /// Returns `Err(task)` if pushing fails due to the queue being shutdown.
+ /// The caller is expected to call `shutdown()` on the task **if and only
+ /// if** it is a newly spawned task.
+ pub(super) fn push(&self, task: task::Notified<T>) -> Result<(), task::Notified<T>>
+ where
+ T: crate::runtime::task::Schedule,
+ {
// Acquire queue lock
let mut p = self.pointers.lock();
if p.is_closed {
- // Drop the mutex to avoid a potential deadlock when
- // re-entering.
- drop(p);
- drop(task);
- return;
+ return Err(task);
}
// safety: only mutated with the lock held
@@ -532,6 +543,7 @@ impl<T: 'static> Inject<T> {
p.tail = Some(task);
self.len.store(len + 1, Release);
+ Ok(())
}
pub(super) fn push_batch(
@@ -617,7 +629,7 @@ fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) {
/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: u32) -> (u16, u16) {
- let real = n & u16::max_value() as u32;
+ let real = n & u16::MAX as u32;
let steal = n >> 16;
(steal as u16, real as u16)
@@ -630,5 +642,5 @@ fn pack(steal: u16, real: u16) -> u32 {
#[test]
fn test_local_queue_capacity() {
- assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::max_value() as usize);
+ assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}
diff --git a/src/runtime/spawner.rs b/src/runtime/spawner.rs
index a37c667..fbcde2c 100644
--- a/src/runtime/spawner.rs
+++ b/src/runtime/spawner.rs
@@ -1,8 +1,7 @@
cfg_rt! {
+ use crate::future::Future;
use crate::runtime::basic_scheduler;
use crate::task::JoinHandle;
-
- use std::future::Future;
}
cfg_rt_multi_thread! {
diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs
index fb6dafd..026a6dc 100644
--- a/src/runtime/task/core.rs
+++ b/src/runtime/task/core.rs
@@ -9,13 +9,13 @@
//! Make sure to consult the relevant safety section of each function before
//! use.
+use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::{Notified, Schedule, Task};
use crate::util::linked_list;
-use std::future::Future;
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll, Waker};
@@ -71,6 +71,10 @@ pub(crate) struct Header {
/// Table of function pointers for executing actions on the task.
pub(super) vtable: &'static Vtable,
+
+ /// The tracing ID for this instrumented task.
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ pub(super) id: Option<tracing::Id>,
}
unsafe impl Send for Header {}
@@ -93,6 +97,8 @@ impl<T: Future, S: Schedule> Cell<T, S> {
/// Allocates a new task cell, containing the header, trailer, and core
/// structures.
pub(super) fn new(future: T, state: State) -> Box<Cell<T, S>> {
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ let id = future.id();
Box::new(Cell {
header: Header {
state,
@@ -100,6 +106,8 @@ impl<T: Future, S: Schedule> Cell<T, S> {
queue_next: UnsafeCell::new(None),
stack_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ id,
},
core: Core {
scheduler: Scheduler {
diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs
index 7d596e3..47bbcc1 100644
--- a/src/runtime/task/harness.rs
+++ b/src/runtime/task/harness.rs
@@ -1,9 +1,9 @@
+use crate::future::Future;
use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Scheduler, Trailer};
use crate::runtime::task::state::Snapshot;
use crate::runtime::task::waker::waker_ref;
use crate::runtime::task::{JoinError, Notified, Schedule, Task};
-use std::future::Future;
use std::mem;
use std::panic;
use std::ptr::NonNull;
@@ -146,6 +146,11 @@ where
}
}
+ #[cfg(all(tokio_unstable, feature = "tracing"))]
+ pub(super) fn id(&self) -> Option<&tracing::Id> {
+ self.header().id.as_ref()
+ }
+
/// Forcibly shutdown the task
///
/// Attempt to transition to `Running` in order to forcibly shutdown the
diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs
index 7b49e95..58b8c2a 100644
--- a/src/runtime/task/mod.rs
+++ b/src/runtime/task/mod.rs
@@ -26,9 +26,9 @@ cfg_rt_multi_thread! {
pub(crate) use self::stack::TransferStack;
}
+use crate::future::Future;
use crate::util::linked_list;
-use std::future::Future;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::{fmt, mem};
diff --git a/src/runtime/task/raw.rs b/src/runtime/task/raw.rs
index cae56d0..a9cd4e6 100644
--- a/src/runtime/task/raw.rs
+++ b/src/runtime/task/raw.rs
@@ -1,6 +1,6 @@
+use crate::future::Future;
use crate::runtime::task::{Cell, Harness, Header, Schedule, State};
-use std::future::Future;
use std::ptr::NonNull;
use std::task::{Poll, Waker};
diff --git a/src/runtime/task/state.rs b/src/runtime/task/state.rs
index 21e9043..1f08d6d 100644
--- a/src/runtime/task/state.rs
+++ b/src/runtime/task/state.rs
@@ -29,12 +29,15 @@ const LIFECYCLE_MASK: usize = 0b11;
const NOTIFIED: usize = 0b100;
/// The join handle is still around
+#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556
const JOIN_INTEREST: usize = 0b1_000;
/// A join handle waker has been set
+#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556
const JOIN_WAKER: usize = 0b10_000;
/// The task has been forcibly cancelled.
+#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556
const CANCELLED: usize = 0b100_000;
/// All bits
@@ -52,7 +55,7 @@ const REF_ONE: usize = 1 << REF_COUNT_SHIFT;
/// State a task is initialized with
///
/// A task is initialized with two references: one for the scheduler and one for
-/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTERST` is
+/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is
/// set. A new task is immediately pushed into the run queue for execution and
/// starts with the `NOTIFIED` flag set.
const INITIAL_STATE: usize = (REF_ONE * 2) | JOIN_INTEREST | NOTIFIED;
@@ -64,7 +67,7 @@ impl State {
pub(super) fn new() -> State {
// A task is initialized with three references: one for the scheduler,
// one for the `JoinHandle`, one for the task handle made available in
- // release. As the task starts with a `JoinHandle`, `JOIN_INTERST` is
+ // release. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is
// set. A new task is immediately pushed into the run queue for
// execution and starts with the `NOTIFIED` flag set.
State {
@@ -306,7 +309,7 @@ impl State {
let prev = self.val.fetch_add(REF_ONE, Relaxed);
// If the reference count overflowed, abort.
- if prev > isize::max_value() as usize {
+ if prev > isize::MAX as usize {
process::abort();
}
}
@@ -410,7 +413,7 @@ impl Snapshot {
}
fn ref_inc(&mut self) {
- assert!(self.0 <= isize::max_value() as usize);
+ assert!(self.0 <= isize::MAX as usize);
self.0 += REF_ONE;
}
diff --git a/src/runtime/task/waker.rs b/src/runtime/task/waker.rs
index 5c2d478..b7313b4 100644
--- a/src/runtime/task/waker.rs
+++ b/src/runtime/task/waker.rs
@@ -1,7 +1,7 @@
+use crate::future::Future;
use crate::runtime::task::harness::Harness;
use crate::runtime::task::{Header, Schedule};
-use std::future::Future;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops;
@@ -44,12 +44,38 @@ impl<S> ops::Deref for WakerRef<'_, S> {
}
}
+cfg_trace! {
+ macro_rules! trace {
+ ($harness:expr, $op:expr) => {
+ if let Some(id) = $harness.id() {
+ tracing::trace!(
+ target: "tokio::task::waker",
+ op = $op,
+ task.id = id.into_u64(),
+ );
+ }
+ }
+ }
+}
+
+cfg_not_trace! {
+ macro_rules! trace {
+ ($harness:expr, $op:expr) => {
+ // noop
+ let _ = &$harness;
+ }
+ }
+}
+
unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker
where
T: Future,
S: Schedule,
{
let header = ptr as *const Header;
+ let ptr = NonNull::new_unchecked(ptr as *mut Header);
+ let harness = Harness::<T, S>::from_raw(ptr);
+ trace!(harness, "waker.clone");
(*header).state.ref_inc();
raw_waker::<T, S>(header)
}
@@ -61,6 +87,7 @@ where
{
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
+ trace!(harness, "waker.drop");
harness.drop_reference();
}
@@ -71,6 +98,7 @@ where
{
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
+ trace!(harness, "waker.wake");
harness.wake_by_val();
}
@@ -82,6 +110,7 @@ where
{
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
+ trace!(harness, "waker.wake_by_ref");
harness.wake_by_ref();
}
diff --git a/src/runtime/tests/loom_shutdown_join.rs b/src/runtime/tests/loom_shutdown_join.rs
new file mode 100644
index 0000000..6fbc4bf
--- /dev/null
+++ b/src/runtime/tests/loom_shutdown_join.rs
@@ -0,0 +1,28 @@
+use crate::runtime::{Builder, Handle};
+
+#[test]
+fn join_handle_cancel_on_shutdown() {
+ let mut builder = loom::model::Builder::new();
+ builder.preemption_bound = Some(2);
+ builder.check(|| {
+ use futures::future::FutureExt;
+
+ let rt = Builder::new_multi_thread()
+ .worker_threads(2)
+ .build()
+ .unwrap();
+
+ let handle = rt.block_on(async move { Handle::current() });
+
+ let jh1 = handle.spawn(futures::future::pending::<()>());
+
+ drop(rt);
+
+ let jh2 = handle.spawn(futures::future::pending::<()>());
+
+ let err1 = jh1.now_or_never().unwrap().unwrap_err();
+ let err2 = jh2.now_or_never().unwrap().unwrap_err();
+ assert!(err1.is_cancelled());
+ assert!(err2.is_cancelled());
+ });
+}
diff --git a/src/runtime/tests/mod.rs b/src/runtime/tests/mod.rs
index ebb48de..c84ba1b 100644
--- a/src/runtime/tests/mod.rs
+++ b/src/runtime/tests/mod.rs
@@ -4,6 +4,7 @@ cfg_loom! {
mod loom_oneshot;
mod loom_pool;
mod loom_queue;
+ mod loom_shutdown_join;
}
cfg_not_loom! {
diff --git a/src/runtime/tests/task.rs b/src/runtime/tests/task.rs
index a34526f..45a3e99 100644
--- a/src/runtime/tests/task.rs
+++ b/src/runtime/tests/task.rs
@@ -79,7 +79,7 @@ static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None);
impl Runtime {
fn tick(&self) -> usize {
- self.tick_max(usize::max_value())
+ self.tick_max(usize::MAX)
}
fn tick_max(&self, max: usize) -> usize {
diff --git a/src/runtime/thread_pool/mod.rs b/src/runtime/thread_pool/mod.rs
index 47f8ee3..96312d3 100644
--- a/src/runtime/thread_pool/mod.rs
+++ b/src/runtime/thread_pool/mod.rs
@@ -90,11 +90,17 @@ impl Spawner {
/// Spawns a future onto the thread pool
pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
- F: Future + Send + 'static,
+ F: crate::future::Future + Send + 'static,
F::Output: Send + 'static,
{
let (task, handle) = task::joinable(future);
- self.shared.schedule(task, false);
+
+ if let Err(task) = self.shared.schedule(task, false) {
+ // The newly spawned task could not be scheduled because the runtime
+ // is shutting down. The task must be explicitly shutdown at this point.
+ task.shutdown();
+ }
+
handle
}
diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs
index 86d3f91..70cbddb 100644
--- a/src/runtime/thread_pool/worker.rs
+++ b/src/runtime/thread_pool/worker.rs
@@ -709,16 +709,22 @@ impl task::Schedule for Arc<Worker> {
}
fn schedule(&self, task: Notified) {
- self.shared.schedule(task, false);
+ // Because this is not a newly spawned task, if scheduling fails due to
+ // the runtime shutting down, there is no special work that must happen
+ // here.
+ let _ = self.shared.schedule(task, false);
}
fn yield_now(&self, task: Notified) {
- self.shared.schedule(task, true);
+ // Because this is not a newly spawned task, if scheduling fails due to
+ // the runtime shutting down, there is no special work that must happen
+ // here.
+ let _ = self.shared.schedule(task, true);
}
}
impl Shared {
- pub(super) fn schedule(&self, task: Notified, is_yield: bool) {
+ pub(super) fn schedule(&self, task: Notified, is_yield: bool) -> Result<(), Notified> {
CURRENT.with(|maybe_cx| {
if let Some(cx) = maybe_cx {
// Make sure the task is part of the **current** scheduler.
@@ -726,15 +732,16 @@ impl Shared {
// And the current thread still holds a core
if let Some(core) = cx.core.borrow_mut().as_mut() {
self.schedule_local(core, task, is_yield);
- return;
+ return Ok(());
}
}
}
// Otherwise, use the inject queue
- self.inject.push(task);
+ self.inject.push(task)?;
self.notify_parked();
- });
+ Ok(())
+ })
}
fn schedule_local(&self, core: &mut Core, task: Notified, is_yield: bool) {
@@ -823,7 +830,9 @@ impl Shared {
}
// Drain the injection queue
- while self.inject.pop().is_some() {}
+ while let Some(task) = self.inject.pop() {
+ task.shutdown();
+ }
}
fn ptr_eq(&self, other: &Shared) -> bool {
diff --git a/src/sync/mod.rs b/src/sync/mod.rs
index 5f97c1a..457e6ab 100644
--- a/src/sync/mod.rs
+++ b/src/sync/mod.rs
@@ -428,6 +428,11 @@
//! bounding of any kind.
cfg_sync! {
+ /// Named future types.
+ pub mod futures {
+ pub use super::notify::Notified;
+ }
+
mod barrier;
pub use barrier::{Barrier, BarrierWaitResult};
diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs
index ce857d7..cfd8da0 100644
--- a/src/sync/mpsc/bounded.rs
+++ b/src/sync/mpsc/bounded.rs
@@ -65,7 +65,7 @@ pub struct Receiver<T> {
/// with backpressure.
///
/// The channel will buffer up to the provided number of messages. Once the
-/// buffer is full, attempts to `send` new messages will wait until a message is
+/// buffer is full, attempts to send new messages will wait until a message is
/// received from the channel. The provided buffer capacity must be at least 1.
///
/// All data sent on `Sender` will become available on `Receiver` in the same
@@ -76,7 +76,7 @@ pub struct Receiver<T> {
///
/// If the `Receiver` is disconnected while trying to `send`, the `send` method
/// will return a `SendError`. Similarly, if `Sender` is disconnected while
-/// trying to `recv`, the `recv` method will return a `RecvError`.
+/// trying to `recv`, the `recv` method will return `None`.
///
/// # Panics
///
@@ -887,7 +887,7 @@ impl<T> Sender<T> {
/// let permit = tx.reserve().await.unwrap();
/// assert_eq!(tx.capacity(), 4);
///
- /// // Sending and receiving a value increases the caapcity by one.
+ /// // Sending and receiving a value increases the capacity by one.
/// permit.send(());
/// rx.recv().await.unwrap();
/// assert_eq!(tx.capacity(), 5);
diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs
index a2d2824..0d25ad3 100644
--- a/src/sync/mpsc/error.rs
+++ b/src/sync/mpsc/error.rs
@@ -55,14 +55,18 @@ impl<T> From<SendError<T>> for TrySendError<T> {
/// Error returned by `Receiver`.
#[derive(Debug)]
+#[doc(hidden)]
+#[deprecated(note = "This type is unused because recv returns an Option.")]
pub struct RecvError(());
+#[allow(deprecated)]
impl fmt::Display for RecvError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "channel closed")
}
}
+#[allow(deprecated)]
impl Error for RecvError {}
cfg_time! {
diff --git a/src/sync/notify.rs b/src/sync/notify.rs
index 5d2132f..07be759 100644
--- a/src/sync/notify.rs
+++ b/src/sync/notify.rs
@@ -140,7 +140,7 @@ struct Waiter {
_p: PhantomPinned,
}
-/// Future returned from `notified()`
+/// Future returned from [`Notify::notified()`]
#[derive(Debug)]
pub struct Notified<'a> {
/// The `Notify` being received on.
diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs
index af75042..5d42d1c 100644
--- a/src/sync/semaphore.rs
+++ b/src/sync/semaphore.rs
@@ -24,7 +24,55 @@ use std::sync::Arc;
/// To use the `Semaphore` in a poll function, you can use the [`PollSemaphore`]
/// utility.
///
+/// # Examples
+///
+/// Basic usage:
+///
+/// ```
+/// use tokio::sync::{Semaphore, TryAcquireError};
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let semaphore = Semaphore::new(3);
+///
+/// let a_permit = semaphore.acquire().await.unwrap();
+/// let two_permits = semaphore.acquire_many(2).await.unwrap();
+///
+/// assert_eq!(semaphore.available_permits(), 0);
+///
+/// let permit_attempt = semaphore.try_acquire();
+/// assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits));
+/// }
+/// ```
+///
+/// Use [`Semaphore::acquire_owned`] to move permits across tasks:
+///
+/// ```
+/// use std::sync::Arc;
+/// use tokio::sync::Semaphore;
+///
+/// #[tokio::main]
+/// async fn main() {
+/// let semaphore = Arc::new(Semaphore::new(3));
+/// let mut join_handles = Vec::new();
+///
+/// for _ in 0..5 {
+/// let permit = semaphore.clone().acquire_owned().await.unwrap();
+/// join_handles.push(tokio::spawn(async move {
+/// // perform task...
+/// // explicitly own `permit` in the task
+/// drop(permit);
+/// }));
+/// }
+///
+/// for handle in join_handles {
+/// handle.await.unwrap();
+/// }
+/// }
+/// ```
+///
/// [`PollSemaphore`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSemaphore.html
+/// [`Semaphore::acquire_owned`]: crate::sync::Semaphore::acquire_owned
#[derive(Debug)]
pub struct Semaphore {
/// The low level semaphore
@@ -79,6 +127,15 @@ impl Semaphore {
}
/// Creates a new semaphore with the initial number of permits.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Semaphore;
+ ///
+ /// static SEM: Semaphore = Semaphore::const_new(10);
+ /// ```
+ ///
#[cfg(all(feature = "parking_lot", not(all(loom, test))))]
#[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))]
pub const fn const_new(permits: usize) -> Self {
@@ -105,6 +162,26 @@ impl Semaphore {
/// Otherwise, this returns a [`SemaphorePermit`] representing the
/// acquired permit.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Semaphore;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let semaphore = Semaphore::new(2);
+ ///
+ /// let permit_1 = semaphore.acquire().await.unwrap();
+ /// assert_eq!(semaphore.available_permits(), 1);
+ ///
+ /// let permit_2 = semaphore.acquire().await.unwrap();
+ /// assert_eq!(semaphore.available_permits(), 0);
+ ///
+ /// drop(permit_1);
+ /// assert_eq!(semaphore.available_permits(), 1);
+ /// }
+ /// ```
+ ///
/// [`AcquireError`]: crate::sync::AcquireError
/// [`SemaphorePermit`]: crate::sync::SemaphorePermit
pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> {
@@ -121,6 +198,20 @@ impl Semaphore {
/// Otherwise, this returns a [`SemaphorePermit`] representing the
/// acquired permits.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::Semaphore;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let semaphore = Semaphore::new(5);
+ ///
+ /// let permit = semaphore.acquire_many(3).await.unwrap();
+ /// assert_eq!(semaphore.available_permits(), 2);
+ /// }
+ /// ```
+ ///
/// [`AcquireError`]: crate::sync::AcquireError
/// [`SemaphorePermit`]: crate::sync::SemaphorePermit
pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
@@ -137,6 +228,25 @@ impl Semaphore {
/// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise,
/// this returns a [`SemaphorePermit`] representing the acquired permits.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{Semaphore, TryAcquireError};
+ ///
+ /// # fn main() {
+ /// let semaphore = Semaphore::new(2);
+ ///
+ /// let permit_1 = semaphore.try_acquire().unwrap();
+ /// assert_eq!(semaphore.available_permits(), 1);
+ ///
+ /// let permit_2 = semaphore.try_acquire().unwrap();
+ /// assert_eq!(semaphore.available_permits(), 0);
+ ///
+ /// let permit_3 = semaphore.try_acquire();
+ /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits));
+ /// # }
+ /// ```
+ ///
/// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed
/// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits
/// [`SemaphorePermit`]: crate::sync::SemaphorePermit
@@ -153,8 +263,24 @@ impl Semaphore {
/// Tries to acquire `n` permits from the semaphore.
///
/// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`]
- /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise,
- /// this returns a [`SemaphorePermit`] representing the acquired permits.
+ /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left.
+ /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::{Semaphore, TryAcquireError};
+ ///
+ /// # fn main() {
+ /// let semaphore = Semaphore::new(4);
+ ///
+ /// let permit_1 = semaphore.try_acquire_many(3).unwrap();
+ /// assert_eq!(semaphore.available_permits(), 1);
+ ///
+ /// let permit_2 = semaphore.try_acquire_many(2);
+ /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits));
+ /// # }
+ /// ```
///
/// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed
/// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits
@@ -176,6 +302,32 @@ impl Semaphore {
/// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
/// acquired permit.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::sync::Arc;
+ /// use tokio::sync::Semaphore;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let semaphore = Arc::new(Semaphore::new(3));
+ /// let mut join_handles = Vec::new();
+ ///
+ /// for _ in 0..5 {
+ /// let permit = semaphore.clone().acquire_owned().await.unwrap();
+ /// join_handles.push(tokio::spawn(async move {
+ /// // perform task...
+ /// // explicitly own `permit` in the task
+ /// drop(permit);
+ /// }));
+ /// }
+ ///
+ /// for handle in join_handles {
+ /// handle.await.unwrap();
+ /// }
+ /// }
+ /// ```
+ ///
/// [`Arc`]: std::sync::Arc
/// [`AcquireError`]: crate::sync::AcquireError
/// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit
@@ -194,6 +346,32 @@ impl Semaphore {
/// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
/// acquired permit.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::sync::Arc;
+ /// use tokio::sync::Semaphore;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let semaphore = Arc::new(Semaphore::new(10));
+ /// let mut join_handles = Vec::new();
+ ///
+ /// for _ in 0..5 {
+ /// let permit = semaphore.clone().acquire_many_owned(2).await.unwrap();
+ /// join_handles.push(tokio::spawn(async move {
+ /// // perform task...
+ /// // explicitly own `permit` in the task
+ /// drop(permit);
+ /// }));
+ /// }
+ ///
+ /// for handle in join_handles {
+ /// handle.await.unwrap();
+ /// }
+ /// }
+ /// ```
+ ///
/// [`Arc`]: std::sync::Arc
/// [`AcquireError`]: crate::sync::AcquireError
/// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit
@@ -216,6 +394,26 @@ impl Semaphore {
/// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
/// acquired permit.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::sync::Arc;
+ /// use tokio::sync::{Semaphore, TryAcquireError};
+ ///
+ /// # fn main() {
+ /// let semaphore = Arc::new(Semaphore::new(2));
+ ///
+ /// let permit_1 = Arc::clone(&semaphore).try_acquire_owned().unwrap();
+ /// assert_eq!(semaphore.available_permits(), 1);
+ ///
+ /// let permit_2 = Arc::clone(&semaphore).try_acquire_owned().unwrap();
+ /// assert_eq!(semaphore.available_permits(), 0);
+ ///
+ /// let permit_3 = semaphore.try_acquire_owned();
+ /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits));
+ /// # }
+ /// ```
+ ///
/// [`Arc`]: std::sync::Arc
/// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed
/// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits
@@ -238,6 +436,23 @@ impl Semaphore {
/// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the
/// acquired permit.
///
+ /// # Examples
+ ///
+ /// ```
+ /// use std::sync::Arc;
+ /// use tokio::sync::{Semaphore, TryAcquireError};
+ ///
+ /// # fn main() {
+ /// let semaphore = Arc::new(Semaphore::new(4));
+ ///
+ /// let permit_1 = Arc::clone(&semaphore).try_acquire_many_owned(3).unwrap();
+ /// assert_eq!(semaphore.available_permits(), 1);
+ ///
+ /// let permit_2 = semaphore.try_acquire_many_owned(2);
+ /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits));
+ /// # }
+ /// ```
+ ///
/// [`Arc`]: std::sync::Arc
/// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed
/// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits
diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs
index 5917204..8616007 100644
--- a/src/sync/task/atomic_waker.rs
+++ b/src/sync/task/atomic_waker.rs
@@ -29,7 +29,7 @@ pub(crate) struct AtomicWaker {
// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell
// stores a `Waker` value produced by calls to `register` and many threads can
-// race to take the waker by calling `wake.
+// race to take the waker by calling `wake`.
//
// If a new `Waker` instance is produced by calling `register` before an existing
// one is consumed, then the existing one is overwritten.
diff --git a/src/sync/watch.rs b/src/sync/watch.rs
index db65e5a..42d417a 100644
--- a/src/sync/watch.rs
+++ b/src/sync/watch.rs
@@ -417,6 +417,28 @@ impl<T> Sender<T> {
Receiver::from_shared(version, shared)
}
}
+
+ /// Returns the number of receivers that currently exist
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::sync::watch;
+ ///
+ /// #[tokio::main]
+ /// async fn main() {
+ /// let (tx, rx1) = watch::channel("hello");
+ ///
+ /// assert_eq!(1, tx.receiver_count());
+ ///
+ /// let mut _rx2 = rx1.clone();
+ ///
+ /// assert_eq!(2, tx.receiver_count());
+ /// }
+ /// ```
+ pub fn receiver_count(&self) -> usize {
+ self.shared.ref_count_rx.load(Relaxed)
+ }
}
impl<T> Drop for Sender<T> {
diff --git a/src/task/mod.rs b/src/task/mod.rs
index 7255535..25dab0c 100644
--- a/src/task/mod.rs
+++ b/src/task/mod.rs
@@ -86,7 +86,7 @@
//! ```
//!
//! Again, like `std::thread`'s [`JoinHandle` type][thread_join], if the spawned
-//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]`. For
+//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]. For
//! example:
//!
//! ```
diff --git a/src/time/clock.rs b/src/time/clock.rs
index c5ef86b..a0ff621 100644
--- a/src/time/clock.rs
+++ b/src/time/clock.rs
@@ -7,7 +7,7 @@
//! configurable.
cfg_not_test_util! {
- use crate::time::{Duration, Instant};
+ use crate::time::{Instant};
#[derive(Debug, Clone)]
pub(crate) struct Clock {}
@@ -24,14 +24,6 @@ cfg_not_test_util! {
pub(crate) fn now(&self) -> Instant {
now()
}
-
- pub(crate) fn is_paused(&self) -> bool {
- false
- }
-
- pub(crate) fn advance(&self, _dur: Duration) {
- unreachable!();
- }
}
}
@@ -121,10 +113,9 @@ cfg_test_util! {
/// runtime.
pub async fn advance(duration: Duration) {
let clock = clock().expect("time cannot be frozen from outside the Tokio runtime");
- let until = clock.now() + duration;
clock.advance(duration);
- crate::time::sleep_until(until).await;
+ crate::task::yield_now().await;
}
/// Return the current instant, factoring in frozen time.
diff --git a/src/time/driver/entry.rs b/src/time/driver/entry.rs
index 08edab3..168e0b9 100644
--- a/src/time/driver/entry.rs
+++ b/src/time/driver/entry.rs
@@ -68,7 +68,7 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};
type TimerResult = Result<(), crate::time::error::Error>;
-const STATE_DEREGISTERED: u64 = u64::max_value();
+const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
@@ -85,10 +85,10 @@ const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
/// requires only the driver lock.
pub(super) struct StateCell {
/// Holds either the scheduled expiration time for this timer, or (if the
- /// timer has been fired and is unregistered), `u64::max_value()`.
+ /// timer has been fired and is unregistered), `u64::MAX`.
state: AtomicU64,
/// If the timer is fired (an Acquire order read on state shows
- /// `u64::max_value()`), holds the result that should be returned from
+ /// `u64::MAX`), holds the result that should be returned from
/// polling the timer. Otherwise, the contents are unspecified and reading
/// without holding the driver lock is undefined behavior.
result: UnsafeCell<TimerResult>,
@@ -125,7 +125,7 @@ impl StateCell {
fn when(&self) -> Option<u64> {
let cur_state = self.state.load(Ordering::Relaxed);
- if cur_state == u64::max_value() {
+ if cur_state == u64::MAX {
None
} else {
Some(cur_state)
@@ -271,7 +271,7 @@ impl StateCell {
/// ordering, but is conservative - if it returns false, the timer is
/// definitely _not_ registered.
pub(super) fn might_be_registered(&self) -> bool {
- self.state.load(Ordering::Relaxed) != u64::max_value()
+ self.state.load(Ordering::Relaxed) != u64::MAX
}
}
@@ -591,7 +591,7 @@ impl TimerHandle {
match self.inner.as_ref().state.mark_pending(not_after) {
Ok(()) => {
// mark this as being on the pending queue in cached_when
- self.inner.as_ref().set_cached_when(u64::max_value());
+ self.inner.as_ref().set_cached_when(u64::MAX);
Ok(())
}
Err(tick) => {
diff --git a/src/time/driver/mod.rs b/src/time/driver/mod.rs
index 3eb1004..37d2231 100644
--- a/src/time/driver/mod.rs
+++ b/src/time/driver/mod.rs
@@ -91,6 +91,15 @@ pub(crate) struct Driver<P: Park + 'static> {
/// Parker to delegate to
park: P,
+
+ // When `true`, a call to `park_timeout` should immediately return and time
+ // should not advance. One reason for this to be `true` is if the task
+ // passed to `Runtime::block_on` called `task::yield_now()`.
+ //
+ // While it may look racy, it only has any effect when the clock is paused
+ // and pausing the clock is restricted to a single-threaded runtime.
+ #[cfg(feature = "test-util")]
+ did_wake: Arc<AtomicBool>,
}
/// A structure which handles conversion from Instants to u64 timestamps.
@@ -178,6 +187,8 @@ where
time_source,
handle: Handle::new(Arc::new(inner)),
park,
+ #[cfg(feature = "test-util")]
+ did_wake: Arc::new(AtomicBool::new(false)),
}
}
@@ -192,8 +203,6 @@ where
}
fn park_internal(&mut self, limit: Option<Duration>) -> Result<(), P::Error> {
- let clock = &self.time_source.clock;
-
let mut lock = self.handle.get().state.lock();
assert!(!self.handle.is_shutdown());
@@ -217,26 +226,14 @@ where
duration = std::cmp::min(limit, duration);
}
- if clock.is_paused() {
- self.park.park_timeout(Duration::from_secs(0))?;
-
- // Simulate advancing time
- clock.advance(duration);
- } else {
- self.park.park_timeout(duration)?;
- }
+ self.park_timeout(duration)?;
} else {
self.park.park_timeout(Duration::from_secs(0))?;
}
}
None => {
if let Some(duration) = limit {
- if clock.is_paused() {
- self.park.park_timeout(Duration::from_secs(0))?;
- clock.advance(duration);
- } else {
- self.park.park_timeout(duration)?;
- }
+ self.park_timeout(duration)?;
} else {
self.park.park()?;
}
@@ -248,6 +245,39 @@ where
Ok(())
}
+
+ cfg_test_util! {
+ fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> {
+ let clock = &self.time_source.clock;
+
+ if clock.is_paused() {
+ self.park.park_timeout(Duration::from_secs(0))?;
+
+ // If the time driver was woken, then the park completed
+ // before the "duration" elapsed (usually caused by a
+ // yield in `Runtime::block_on`). In this case, we don't
+ // advance the clock.
+ if !self.did_wake() {
+ // Simulate advancing time
+ clock.advance(duration);
+ }
+ } else {
+ self.park.park_timeout(duration)?;
+ }
+
+ Ok(())
+ }
+
+ fn did_wake(&self) -> bool {
+ self.did_wake.swap(false, Ordering::SeqCst)
+ }
+ }
+
+ cfg_not_test_util! {
+ fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> {
+ self.park.park_timeout(duration)
+ }
+ }
}
impl Handle {
@@ -387,11 +417,11 @@ impl<P> Park for Driver<P>
where
P: Park + 'static,
{
- type Unpark = P::Unpark;
+ type Unpark = TimerUnpark<P>;
type Error = P::Error;
fn unpark(&self) -> Self::Unpark {
- self.park.unpark()
+ TimerUnpark::new(self)
}
fn park(&mut self) -> Result<(), Self::Error> {
@@ -426,6 +456,33 @@ where
}
}
+pub(crate) struct TimerUnpark<P: Park + 'static> {
+ inner: P::Unpark,
+
+ #[cfg(feature = "test-util")]
+ did_wake: Arc<AtomicBool>,
+}
+
+impl<P: Park + 'static> TimerUnpark<P> {
+ fn new(driver: &Driver<P>) -> TimerUnpark<P> {
+ TimerUnpark {
+ inner: driver.park.unpark(),
+
+ #[cfg(feature = "test-util")]
+ did_wake: driver.did_wake.clone(),
+ }
+ }
+}
+
+impl<P: Park + 'static> Unpark for TimerUnpark<P> {
+ fn unpark(&self) {
+ #[cfg(feature = "test-util")]
+ self.did_wake.store(true, Ordering::SeqCst);
+
+ self.inner.unpark();
+ }
+}
+
// ===== impl Inner =====
impl Inner {
diff --git a/src/time/driver/wheel/mod.rs b/src/time/driver/wheel/mod.rs
index 24bf517..5a40f6d 100644
--- a/src/time/driver/wheel/mod.rs
+++ b/src/time/driver/wheel/mod.rs
@@ -119,7 +119,7 @@ impl Wheel {
pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) {
unsafe {
let when = item.as_ref().cached_when();
- if when == u64::max_value() {
+ if when == u64::MAX {
self.pending.remove(item);
} else {
debug_assert!(
diff --git a/src/util/linked_list.rs b/src/util/linked_list.rs
index 480ea09..a74f562 100644
--- a/src/util/linked_list.rs
+++ b/src/util/linked_list.rs
@@ -50,6 +50,7 @@ pub(crate) unsafe trait Link {
type Target;
/// Convert the handle to a raw pointer without consuming the handle
+ #[allow(clippy::wrong_self_convention)]
fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target>;
/// Convert the raw pointer to a handle
diff --git a/src/util/wake.rs b/src/util/wake.rs
index 001577d..5773937 100644
--- a/src/util/wake.rs
+++ b/src/util/wake.rs
@@ -54,11 +54,7 @@ unsafe fn inc_ref_count<T: Wake>(data: *const ()) {
let arc = ManuallyDrop::new(Arc::<T>::from_raw(data as *const T));
// Now increase refcount, but don't drop new refcount either
- let arc_clone: ManuallyDrop<_> = arc.clone();
-
- // Drop explicitly to avoid clippy warnings
- drop(arc);
- drop(arc_clone);
+ let _arc_clone: ManuallyDrop<_> = arc.clone();
}
unsafe fn clone_arc_raw<T: Wake>(data: *const ()) -> RawWaker {