diff options
author | Joel Galenson <jgalenson@google.com> | 2021-06-22 09:28:11 -0700 |
---|---|---|
committer | Joel Galenson <jgalenson@google.com> | 2021-06-22 09:28:30 -0700 |
commit | b53dd06ad19d902c2155b2f616725b14b423a776 (patch) | |
tree | b92904fad1f985434ad7622c3960e6e670b1c918 /src | |
parent | 8c2e0e8165f4f0132f3a5c78337fbba15b102768 (diff) | |
download | tokio-b53dd06ad19d902c2155b2f616725b14b423a776.tar.gz |
Upgrade rust/crates/tokio to 1.7.1
Test: make
Change-Id: I7ebd839df13023db6f2057e09d8b73967436b856
Diffstat (limited to 'src')
49 files changed, 2017 insertions, 327 deletions
diff --git a/src/doc/mod.rs b/src/doc/mod.rs new file mode 100644 index 0000000..12c2247 --- /dev/null +++ b/src/doc/mod.rs @@ -0,0 +1,23 @@ +//! Types which are documented locally in the Tokio crate, but does not actually +//! live here. +//! +//! **Note** this module is only visible on docs.rs, you cannot use it directly +//! in your own code. + +/// The name of a type which is not defined here. +/// +/// This is typically used as an alias for another type, like so: +/// +/// ```rust,ignore +/// /// See [some::other::location](https://example.com). +/// type DEFINED_ELSEWHERE = crate::doc::NotDefinedHere; +/// ``` +/// +/// This type is uninhabitable like the [`never` type] to ensure that no one +/// will ever accidentally use it. +/// +/// [`never` type]: https://doc.rust-lang.org/std/primitive.never.html +pub enum NotDefinedHere {} + +pub mod os; +pub mod winapi; diff --git a/src/doc/os.rs b/src/doc/os.rs new file mode 100644 index 0000000..0ddf869 --- /dev/null +++ b/src/doc/os.rs @@ -0,0 +1,26 @@ +//! See [std::os](https://doc.rust-lang.org/std/os/index.html). + +/// Platform-specific extensions to `std` for Windows. +/// +/// See [std::os::windows](https://doc.rust-lang.org/std/os/windows/index.html). +pub mod windows { + /// Windows-specific extensions to general I/O primitives. + /// + /// See [std::os::windows::io](https://doc.rust-lang.org/std/os/windows/io/index.html). + pub mod io { + /// See [std::os::windows::io::RawHandle](https://doc.rust-lang.org/std/os/windows/io/type.RawHandle.html) + pub type RawHandle = crate::doc::NotDefinedHere; + + /// See [std::os::windows::io::AsRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html) + pub trait AsRawHandle { + /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html#tymethod.as_raw_handle) + fn as_raw_handle(&self) -> RawHandle; + } + + /// See [std::os::windows::io::FromRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html) + pub trait FromRawHandle { + /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html#tymethod.from_raw_handle) + unsafe fn from_raw_handle(handle: RawHandle) -> Self; + } + } +} diff --git a/src/doc/winapi.rs b/src/doc/winapi.rs new file mode 100644 index 0000000..be68749 --- /dev/null +++ b/src/doc/winapi.rs @@ -0,0 +1,66 @@ +//! See [winapi]. +//! +//! [winapi]: https://docs.rs/winapi + +/// See [winapi::shared](https://docs.rs/winapi/*/winapi/shared/index.html). +pub mod shared { + /// See [winapi::shared::winerror](https://docs.rs/winapi/*/winapi/shared/winerror/index.html). + #[allow(non_camel_case_types)] + pub mod winerror { + /// See [winapi::shared::winerror::ERROR_ACCESS_DENIED][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_ACCESS_DENIED.html + pub type ERROR_ACCESS_DENIED = crate::doc::NotDefinedHere; + + /// See [winapi::shared::winerror::ERROR_PIPE_BUSY][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_PIPE_BUSY.html + pub type ERROR_PIPE_BUSY = crate::doc::NotDefinedHere; + + /// See [winapi::shared::winerror::ERROR_MORE_DATA][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_MORE_DATA.html + pub type ERROR_MORE_DATA = crate::doc::NotDefinedHere; + } +} + +/// See [winapi::um](https://docs.rs/winapi/*/winapi/um/index.html). +pub mod um { + /// See [winapi::um::winbase](https://docs.rs/winapi/*/winapi/um/winbase/index.html). + #[allow(non_camel_case_types)] + pub mod winbase { + /// See [winapi::um::winbase::PIPE_TYPE_MESSAGE][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_MESSAGE.html + pub type PIPE_TYPE_MESSAGE = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_TYPE_BYTE][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_BYTE.html + pub type PIPE_TYPE_BYTE = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_CLIENT_END][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_CLIENT_END.html + pub type PIPE_CLIENT_END = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_SERVER_END][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_SERVER_END.html + pub type PIPE_SERVER_END = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::SECURITY_IDENTIFICATION][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.SECURITY_IDENTIFICATION.html + pub type SECURITY_IDENTIFICATION = crate::doc::NotDefinedHere; + } + + /// See [winapi::um::minwinbase](https://docs.rs/winapi/*/winapi/um/minwinbase/index.html). + #[allow(non_camel_case_types)] + pub mod minwinbase { + /// See [winapi::um::minwinbase::SECURITY_ATTRIBUTES][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/minwinbase/constant.SECURITY_ATTRIBUTES.html + pub type SECURITY_ATTRIBUTES = crate::doc::NotDefinedHere; + } +} diff --git a/src/fs/file.rs b/src/fs/file.rs index abd6e8c..5c06e73 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -491,18 +491,14 @@ impl AsyncRead for File { loop { match inner.state { Idle(ref mut buf_cell) => { - let buf = buf_cell.as_mut().unwrap(); + let mut buf = buf_cell.take().unwrap(); if !buf.is_empty() { buf.copy_to(dst); + *buf_cell = Some(buf); return Ready(Ok(())); } - if let Some(x) = try_nonblocking_read(me.std.as_ref(), dst) { - return Ready(x); - } - - let mut buf = buf_cell.take().unwrap(); buf.ensure_capacity_for(dst); let std = me.std.clone(); @@ -760,186 +756,3 @@ impl Inner { } } } - -#[cfg(all(target_os = "linux", not(test)))] -pub(crate) fn try_nonblocking_read( - file: &crate::fs::sys::File, - dst: &mut ReadBuf<'_>, -) -> Option<std::io::Result<()>> { - use std::sync::atomic::{AtomicBool, Ordering}; - - static NONBLOCKING_READ_SUPPORTED: AtomicBool = AtomicBool::new(true); - if !NONBLOCKING_READ_SUPPORTED.load(Ordering::Relaxed) { - return None; - } - let out = preadv2::preadv2_safe(file, dst, -1, preadv2::RWF_NOWAIT); - if let Err(err) = &out { - match err.raw_os_error() { - Some(libc::ENOSYS) => { - NONBLOCKING_READ_SUPPORTED.store(false, Ordering::Relaxed); - return None; - } - Some(libc::ENOTSUP) | Some(libc::EAGAIN) => return None, - _ => {} - } - } - Some(out) -} - -#[cfg(any(not(target_os = "linux"), test))] -pub(crate) fn try_nonblocking_read( - _file: &crate::fs::sys::File, - _dst: &mut ReadBuf<'_>, -) -> Option<std::io::Result<()>> { - None -} - -#[cfg(target_os = "linux")] -mod preadv2 { - use libc::{c_int, c_long, c_void, iovec, off_t, ssize_t}; - use std::os::unix::prelude::AsRawFd; - - use crate::io::ReadBuf; - - pub(crate) fn preadv2_safe( - file: &std::fs::File, - dst: &mut ReadBuf<'_>, - offset: off_t, - flags: c_int, - ) -> std::io::Result<()> { - unsafe { - /* We have to defend against buffer overflows manually here. The slice API makes - * this fairly straightforward. */ - let unfilled = dst.unfilled_mut(); - let mut iov = iovec { - iov_base: unfilled.as_mut_ptr() as *mut c_void, - iov_len: unfilled.len(), - }; - /* We take a File object rather than an fd as reading from a sensitive fd may confuse - * other unsafe code that assumes that only they have access to that fd. */ - let bytes_read = preadv2( - file.as_raw_fd(), - &mut iov as *mut iovec as *const iovec, - 1, - offset, - flags, - ); - if bytes_read < 0 { - Err(std::io::Error::last_os_error()) - } else { - /* preadv2 returns the number of bytes read, e.g. the number of bytes that have - * written into `unfilled`. So it's safe to assume that the data is now - * initialised */ - dst.assume_init(bytes_read as usize); - dst.advance(bytes_read as usize); - Ok(()) - } - } - } - - #[cfg(test)] - mod test { - use super::*; - - #[test] - fn test_preadv2_safe() { - use std::io::{Seek, Write}; - use std::mem::MaybeUninit; - use tempfile::tempdir; - - let tmp = tempdir().unwrap(); - let filename = tmp.path().join("file"); - const MESSAGE: &[u8] = b"Hello this is a test"; - { - let mut f = std::fs::File::create(&filename).unwrap(); - f.write_all(MESSAGE).unwrap(); - } - let f = std::fs::File::open(&filename).unwrap(); - - let mut buf = [MaybeUninit::<u8>::new(0); 50]; - let mut br = ReadBuf::uninit(&mut buf); - - // Basic use: - preadv2_safe(&f, &mut br, 0, 0).unwrap(); - assert_eq!(br.initialized().len(), MESSAGE.len()); - assert_eq!(br.filled(), MESSAGE); - - // Here we check that offset works, but also that appending to a non-empty buffer - // behaves correctly WRT initialisation. - preadv2_safe(&f, &mut br, 5, 0).unwrap(); - assert_eq!(br.initialized().len(), MESSAGE.len() * 2 - 5); - assert_eq!(br.filled(), b"Hello this is a test this is a test".as_ref()); - - // offset of -1 means use the current cursor. This has not been advanced by the - // previous reads because we specified an offset there. - preadv2_safe(&f, &mut br, -1, 0).unwrap(); - assert_eq!(br.remaining(), 0); - assert_eq!( - br.filled(), - b"Hello this is a test this is a testHello this is a".as_ref() - ); - - // but the offset should have been advanced by that read - br.clear(); - preadv2_safe(&f, &mut br, -1, 0).unwrap(); - assert_eq!(br.filled(), b" test"); - - // This should be in cache, so RWF_NOWAIT should work, but it not being in cache - // (EAGAIN) or not supported by the underlying filesystem (ENOTSUP) is fine too. - br.clear(); - match preadv2_safe(&f, &mut br, 0, RWF_NOWAIT) { - Ok(()) => assert_eq!(br.filled(), MESSAGE), - Err(e) => assert!(matches!( - e.raw_os_error(), - Some(libc::ENOTSUP) | Some(libc::EAGAIN) - )), - } - - // Test handling large offsets - { - // I hope the underlying filesystem supports sparse files - let mut w = std::fs::OpenOptions::new() - .write(true) - .open(&filename) - .unwrap(); - w.set_len(0x1_0000_0000).unwrap(); - w.seek(std::io::SeekFrom::Start(0x1_0000_0000)).unwrap(); - w.write_all(b"This is a Large File").unwrap(); - } - - br.clear(); - preadv2_safe(&f, &mut br, 0x1_0000_0008, 0).unwrap(); - assert_eq!(br.filled(), b"a Large File"); - } - } - - fn pos_to_lohi(offset: off_t) -> (c_long, c_long) { - // 64-bit offset is split over high and low 32-bits on 32-bit architectures. - // 64-bit architectures still have high and low arguments, but only the low - // one is inspected. See pos_from_hilo in linux/fs/read_write.c. - const HALF_LONG_BITS: usize = core::mem::size_of::<c_long>() * 8 / 2; - ( - offset as c_long, - // We want to shift this off_t value by size_of::<c_long>(). We can't do - // it in one shift because if they're both 64-bits we'd be doing u64 >> 64 - // which is implementation defined. Instead do it in two halves: - ((offset >> HALF_LONG_BITS) >> HALF_LONG_BITS) as c_long, - ) - } - - pub(crate) const RWF_NOWAIT: c_int = 0x00000008; - unsafe fn preadv2( - fd: c_int, - iov: *const iovec, - iovcnt: c_int, - offset: off_t, - flags: c_int, - ) -> ssize_t { - // Call via libc::syscall rather than libc::preadv2. preadv2 is only supported by glibc - // and only since v2.26. By using syscall we don't need to worry about compatiblity with - // old glibc versions and it will work on Android and musl too. The downside is that you - // can't use `LD_PRELOAD` tricks any more to intercept these calls. - let (lo, hi) = pos_to_lohi(offset); - libc::syscall(libc::SYS_preadv2, fd, iov, iovcnt, lo, hi, flags) as ssize_t - } -} diff --git a/src/future/mod.rs b/src/future/mod.rs index f7d93c9..96483ac 100644 --- a/src/future/mod.rs +++ b/src/future/mod.rs @@ -22,3 +22,14 @@ cfg_sync! { mod block_on; pub(crate) use block_on::block_on; } + +cfg_trace! { + mod trace; + pub(crate) use trace::InstrumentedFuture as Future; +} + +cfg_not_trace! { + cfg_rt! { + pub(crate) use std::future::Future; + } +} diff --git a/src/future/trace.rs b/src/future/trace.rs new file mode 100644 index 0000000..28789a6 --- /dev/null +++ b/src/future/trace.rs @@ -0,0 +1,11 @@ +use std::future::Future; + +pub(crate) trait InstrumentedFuture: Future { + fn id(&self) -> Option<tracing::Id>; +} + +impl<F: Future> InstrumentedFuture for tracing::instrument::Instrumented<F> { + fn id(&self) -> Option<tracing::Id> { + self.span().id() + } +} diff --git a/src/io/async_write.rs b/src/io/async_write.rs index 569fb9c..7ec1a30 100644 --- a/src/io/async_write.rs +++ b/src/io/async_write.rs @@ -45,7 +45,11 @@ use std::task::{Context, Poll}; pub trait AsyncWrite { /// Attempt to write bytes from `buf` into the object. /// - /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. If successful, + /// then it must be guaranteed that `n <= buf.len()`. A return value of `0` + /// typically means that the underlying object is no longer able to accept + /// bytes and will likely not be able to in the future as well, or that the + /// buffer provided is empty. /// /// If the object is not ready for writing, the method returns /// `Poll::Pending` and arranges for the current task (via diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index e715f9d..878676f 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -108,6 +108,8 @@ cfg_io_util! { /// This function does not provide any guarantees about whether it /// completes immediately or asynchronously /// + /// # Return + /// /// If the return value of this method is `Ok(n)`, then it must be /// guaranteed that `0 <= n <= buf.len()`. A nonzero `n` value indicates /// that the buffer `buf` has been filled in with `n` bytes of data from @@ -180,9 +182,14 @@ cfg_io_util! { /// /// # Return /// - /// On a successful read, the number of read bytes is returned. If the - /// supplied buffer is not empty and the function returns `Ok(0)` then - /// the source has reached an "end-of-file" event. + /// A nonzero `n` value indicates that the buffer `buf` has been filled + /// in with `n` bytes of data from this source. If `n` is `0`, then it + /// can indicate one of two scenarios: + /// + /// 1. This reader has reached its "end of file" and will likely no longer + /// be able to produce bytes. Note that this does not mean that the + /// reader will *always* no longer be able to produce bytes. + /// 2. The buffer specified had a remaining capacity of zero. /// /// # Errors /// @@ -579,7 +586,7 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]); /// - /// assert_eq!(i64::min_value(), reader.read_i64().await?); + /// assert_eq!(i64::MIN, reader.read_i64().await?); /// Ok(()) /// } /// ``` @@ -659,7 +666,7 @@ cfg_io_util! { /// 0, 0, 0, 0, 0, 0, 0, 0 /// ]); /// - /// assert_eq!(i128::min_value(), reader.read_i128().await?); + /// assert_eq!(i128::MIN, reader.read_i128().await?); /// Ok(()) /// } /// ``` diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index 2510ccd..4651a99 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -621,8 +621,8 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i64(i64::min_value()).await?; - /// writer.write_i64(i64::max_value()).await?; + /// writer.write_i64(i64::MIN).await?; + /// writer.write_i64(i64::MAX).await?; /// /// assert_eq!(writer, b"\x80\x00\x00\x00\x00\x00\x00\x00\x7f\xff\xff\xff\xff\xff\xff\xff"); /// Ok(()) @@ -699,7 +699,7 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i128(i128::min_value()).await?; + /// writer.write_i128(i128::MIN).await?; /// /// assert_eq!(writer, vec![ /// 0x80, 0, 0, 0, 0, 0, 0, 0, @@ -930,8 +930,8 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i64_le(i64::min_value()).await?; - /// writer.write_i64_le(i64::max_value()).await?; + /// writer.write_i64_le(i64::MIN).await?; + /// writer.write_i64_le(i64::MAX).await?; /// /// assert_eq!(writer, b"\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f"); /// Ok(()) @@ -1008,7 +1008,7 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i128_le(i128::min_value()).await?; + /// writer.write_i128_le(i128::MIN).await?; /// /// assert_eq!(writer, vec![ /// 0, 0, 0, 0, 0, 0, 0, diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index cc65ef2..c4d6842 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -198,7 +198,7 @@ impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> { // it should be safe to assume that remainder fits within an i64 as the alternative // means we managed to allocate 8 exbibytes and that's absurd. // But it's not out of the realm of possibility for some weird underlying reader to - // support seeking by i64::min_value() so we need to handle underflow when subtracting + // support seeking by i64::MIN so we need to handle underflow when subtracting // remainder. if let Some(offset) = n.checked_sub(remainder) { self.as_mut() diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs index 9238665..ff3d9db 100644 --- a/src/io/util/buf_stream.rs +++ b/src/io/util/buf_stream.rs @@ -1,8 +1,8 @@ use crate::io::util::{BufReader, BufWriter}; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io; +use std::io::{self, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -146,6 +146,34 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> { } } +/// Seek to an offset, in bytes, in the underlying stream. +/// +/// The position used for seeking with `SeekFrom::Current(_)` is the +/// position the underlying stream would be at if the `BufStream` had no +/// internal buffer. +/// +/// Seeking always discards the internal buffer, even if the seek position +/// would otherwise fall within it. This guarantees that calling +/// `.into_inner()` immediately after a seek yields the underlying reader +/// at the same position. +/// +/// See [`AsyncSeek`] for more details. +/// +/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` +/// where `n` minus the internal buffer length overflows an `i64`, two +/// seeks will be performed instead of one. If the second seek returns +/// `Err`, the underlying reader will be left at the same position it would +/// have if you called `seek` with `SeekFrom::Current(0)`. +impl<RW: AsyncRead + AsyncWrite + AsyncSeek> AsyncSeek for BufStream<RW> { + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> { + self.project().inner.start_seek(position) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + self.project().inner.poll_complete(cx) + } +} + impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { self.project().inner.poll_fill_buf(cx) @@ -9,6 +9,7 @@ rust_2018_idioms, unreachable_pub )] +#![deny(unused_must_use)] #![cfg_attr(docsrs, deny(broken_intra_doc_links))] #![doc(test( no_crate_inject, @@ -442,6 +443,28 @@ mod util; /// ``` pub mod stream {} +// local re-exports of platform specific things, allowing for decent +// documentation to be shimmed in on docs.rs + +#[cfg(docsrs)] +pub mod doc; + +#[cfg(docsrs)] +#[allow(unused)] +pub(crate) use self::doc::os; + +#[cfg(not(docsrs))] +#[allow(unused)] +pub(crate) use std::os; + +#[cfg(docsrs)] +#[allow(unused)] +pub(crate) use self::doc::winapi; + +#[cfg(all(not(docsrs), windows, feature = "net"))] +#[allow(unused)] +pub(crate) use ::winapi; + cfg_macros! { /// Implementation detail of the `select!` macro. This macro is **not** /// intended to be used as part of the public API and is permitted to @@ -453,15 +476,20 @@ cfg_macros! { #[cfg(feature = "rt-multi-thread")] #[cfg(not(test))] // Work around for rust-lang/rust#62127 #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + #[doc(inline)] pub use tokio_macros::main; #[cfg(feature = "rt-multi-thread")] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + #[doc(inline)] pub use tokio_macros::test; cfg_not_rt_multi_thread! { #[cfg(not(test))] // Work around for rust-lang/rust#62127 + #[doc(inline)] pub use tokio_macros::main_rt as main; + + #[doc(inline)] pub use tokio_macros::test_rt as test; } } @@ -469,7 +497,10 @@ cfg_macros! { // Always fail if rt is not enabled. cfg_not_rt! { #[cfg(not(test))] + #[doc(inline)] pub use tokio_macros::main_fail as main; + + #[doc(inline)] pub use tokio_macros::test_fail as test; } } diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs index 3442612..1e77556 100644 --- a/src/macros/cfg.rs +++ b/src/macros/cfg.rs @@ -157,7 +157,6 @@ macro_rules! cfg_macros { $( #[cfg(feature = "macros")] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] - #[doc(inline)] $item )* } @@ -183,6 +182,16 @@ macro_rules! cfg_net_unix { } } +macro_rules! cfg_net_windows { + ($($item:item)*) => { + $( + #[cfg(all(any(docsrs, windows), feature = "net"))] + #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "net"))))] + $item + )* + } +} + macro_rules! cfg_process { ($($item:item)*) => { $( diff --git a/src/macros/select.rs b/src/macros/select.rs index f98ebff..371a3de 100644 --- a/src/macros/select.rs +++ b/src/macros/select.rs @@ -398,7 +398,7 @@ macro_rules! select { // set the appropriate bit in `disabled`. $( if !$c { - let mask = 1 << $crate::count!( $($skip)* ); + let mask: util::Mask = 1 << $crate::count!( $($skip)* ); disabled |= mask; } )* diff --git a/src/net/mod.rs b/src/net/mod.rs index 2f17f9e..0b8c1ec 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -46,3 +46,7 @@ cfg_net_unix! { pub use unix::listener::UnixListener; pub use unix::stream::UnixStream; } + +cfg_net_windows! { + pub mod windows; +} diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index 4bcbe3f..02cb637 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -482,6 +482,48 @@ impl TcpSocket { let mio = self.inner.listen(backlog)?; TcpListener::new(mio) } + + /// Converts a [`std::net::TcpStream`] into a `TcpSocket`. The provided + /// socket must not have been connected prior to calling this function. This + /// function is typically used together with crates such as [`socket2`] to + /// configure socket options that are not available on `TcpSocket`. + /// + /// [`std::net::TcpStream`]: struct@std::net::TcpStream + /// [`socket2`]: https://docs.rs/socket2/ + /// + /// # Examples + /// + /// ``` + /// use tokio::net::TcpSocket; + /// use socket2::{Domain, Socket, Type}; + /// + /// #[tokio::main] + /// async fn main() -> std::io::Result<()> { + /// + /// let socket2_socket = Socket::new(Domain::IPV4, Type::STREAM, None)?; + /// + /// let socket = TcpSocket::from_std_stream(socket2_socket.into()); + /// + /// Ok(()) + /// } + /// ``` + pub fn from_std_stream(std_stream: std::net::TcpStream) -> TcpSocket { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = std_stream.into_raw_fd(); + unsafe { TcpSocket::from_raw_fd(raw_fd) } + } + + #[cfg(windows)] + { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let raw_socket = std_stream.into_raw_socket(); + unsafe { TcpSocket::from_raw_socket(raw_socket) } + } + } } impl fmt::Debug for TcpSocket { diff --git a/src/net/unix/ucred.rs b/src/net/unix/ucred.rs index 5c7c198..b95a8f6 100644 --- a/src/net/unix/ucred.rs +++ b/src/net/unix/ucred.rs @@ -73,7 +73,7 @@ pub(crate) mod impl_linux { // These paranoid checks should be optimized-out assert!(mem::size_of::<u32>() <= mem::size_of::<usize>()); - assert!(ucred_size <= u32::max_value() as usize); + assert!(ucred_size <= u32::MAX as usize); let mut ucred_size = ucred_size as socklen_t; diff --git a/src/net/windows/mod.rs b/src/net/windows/mod.rs new file mode 100644 index 0000000..060b68e --- /dev/null +++ b/src/net/windows/mod.rs @@ -0,0 +1,3 @@ +//! Windows specific network types. + +pub mod named_pipe; diff --git a/src/net/windows/named_pipe.rs b/src/net/windows/named_pipe.rs new file mode 100644 index 0000000..8013d6f --- /dev/null +++ b/src/net/windows/named_pipe.rs @@ -0,0 +1,1199 @@ +//! Tokio support for [Windows named pipes]. +//! +//! [Windows named pipes]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes + +use std::ffi::c_void; +use std::ffi::OsStr; +use std::io; +use std::pin::Pin; +use std::ptr; +use std::task::{Context, Poll}; + +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf}; +use crate::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle}; + +// Hide imports which are not used when generating documentation. +#[cfg(not(docsrs))] +mod doc { + pub(super) use crate::os::windows::ffi::OsStrExt; + pub(super) use crate::winapi::shared::minwindef::{DWORD, FALSE}; + pub(super) use crate::winapi::um::fileapi; + pub(super) use crate::winapi::um::handleapi; + pub(super) use crate::winapi::um::namedpipeapi; + pub(super) use crate::winapi::um::winbase; + pub(super) use crate::winapi::um::winnt; + + pub(super) use mio::windows as mio_windows; +} + +// NB: none of these shows up in public API, so don't document them. +#[cfg(docsrs)] +mod doc { + pub type DWORD = crate::doc::NotDefinedHere; + + pub(super) mod mio_windows { + pub type NamedPipe = crate::doc::NotDefinedHere; + } +} + +use self::doc::*; + +/// A [Windows named pipe] server. +/// +/// Accepting client connections involves creating a server with +/// [`ServerOptions::create`] and waiting for clients to connect using +/// [`NamedPipeServer::connect`]. +/// +/// To avoid having clients sporadically fail with +/// [`std::io::ErrorKind::NotFound`] when they connect to a server, we must +/// ensure that at least one server instance is available at all times. This +/// means that the typical listen loop for a server is a bit involved, because +/// we have to ensure that we never drop a server accidentally while a client +/// might connect. +/// +/// So a correctly implemented server looks like this: +/// +/// ```no_run +/// use std::io; +/// use tokio::net::windows::named_pipe::ServerOptions; +/// +/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-server"; +/// +/// # #[tokio::main] async fn main() -> std::io::Result<()> { +/// // The first server needs to be constructed early so that clients can +/// // be correctly connected. Otherwise calling .wait will cause the client to +/// // error. +/// // +/// // Here we also make use of `first_pipe_instance`, which will ensure that +/// // there are no other servers up and running already. +/// let mut server = ServerOptions::new() +/// .first_pipe_instance(true) +/// .create(PIPE_NAME)?; +/// +/// // Spawn the server loop. +/// let server = tokio::spawn(async move { +/// loop { +/// // Wait for a client to connect. +/// let connected = server.connect().await?; +/// +/// // Construct the next server to be connected before sending the one +/// // we already have of onto a task. This ensures that the server +/// // isn't closed (after it's done in the task) before a new one is +/// // available. Otherwise the client might error with +/// // `io::ErrorKind::NotFound`. +/// server = ServerOptions::new().create(PIPE_NAME)?; +/// +/// let client = tokio::spawn(async move { +/// /* use the connected client */ +/// # Ok::<_, std::io::Error>(()) +/// }); +/// # if true { break } // needed for type inference to work +/// } +/// +/// Ok::<_, io::Error>(()) +/// }); +/// +/// /* do something else not server related here */ +/// # Ok(()) } +/// ``` +/// +/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY +/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes +#[derive(Debug)] +pub struct NamedPipeServer { + io: PollEvented<mio_windows::NamedPipe>, +} + +impl NamedPipeServer { + /// Construct a new named pipe server from the specified raw handle. + /// + /// This function will consume ownership of the handle given, passing + /// responsibility for closing the handle to the returned object. + /// + /// This function is also unsafe as the primitives currently returned have + /// the contract that they are the sole owner of the file descriptor they + /// are wrapping. Usage of this function could accidentally allow violating + /// this contract which can cause memory unsafety in code that relies on it + /// being true. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> { + let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle); + + Ok(Self { + io: PollEvented::new(named_pipe)?, + }) + } + + /// Retrieves information about the named pipe the server is associated + /// with. + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::{PipeEnd, PipeMode, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-info"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new() + /// .pipe_mode(PipeMode::Message) + /// .max_instances(5) + /// .create(PIPE_NAME)?; + /// + /// let server_info = server.info()?; + /// + /// assert_eq!(server_info.end, PipeEnd::Server); + /// assert_eq!(server_info.mode, PipeMode::Message); + /// assert_eq!(server_info.max_instances, 5); + /// # Ok(()) } + /// ``` + pub fn info(&self) -> io::Result<PipeInfo> { + // Safety: we're ensuring the lifetime of the named pipe. + unsafe { named_pipe_info(self.io.as_raw_handle()) } + } + + /// Enables a named pipe server process to wait for a client process to + /// connect to an instance of a named pipe. A client process connects by + /// creating a named pipe with the same name. + /// + /// This corresponds to the [`ConnectNamedPipe`] system call. + /// + /// [`ConnectNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-connectnamedpipe + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let pipe = ServerOptions::new().create(PIPE_NAME)?; + /// + /// // Wait for a client to connect. + /// pipe.connect().await?; + /// + /// // Use the connected client... + /// # Ok(()) } + /// ``` + pub async fn connect(&self) -> io::Result<()> { + loop { + match self.io.connect() { + Ok(()) => break, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.registration().readiness(Interest::WRITABLE).await?; + } + Err(e) => return Err(e), + } + } + + Ok(()) + } + + /// Disconnects the server end of a named pipe instance from a client + /// process. + /// + /// ``` + /// use tokio::io::AsyncWriteExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-disconnect"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .open(PIPE_NAME)?; + /// + /// // Wait for a client to become connected. + /// server.connect().await?; + /// + /// // Forcibly disconnect the client. + /// server.disconnect()?; + /// + /// // Write fails with an OS-specific error after client has been + /// // disconnected. + /// let e = client.write(b"ping").await.unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_NOT_CONNECTED as i32)); + /// # Ok(()) } + /// ``` + pub fn disconnect(&self) -> io::Result<()> { + self.io.disconnect() + } +} + +impl AsyncRead for NamedPipeServer { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unsafe { self.io.poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipeServer { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsRawHandle for NamedPipeServer { + fn as_raw_handle(&self) -> RawHandle { + self.io.as_raw_handle() + } +} + +/// A [Windows named pipe] client. +/// +/// Constructed using [`ClientOptions::open`]. +/// +/// Connecting a client correctly involves a few steps. When connecting through +/// [`ClientOptions::open`], it might error indicating one of two things: +/// +/// * [`std::io::ErrorKind::NotFound`] - There is no server available. +/// * [`ERROR_PIPE_BUSY`] - There is a server available, but it is busy. Sleep +/// for a while and try again. +/// +/// So a correctly implemented client looks like this: +/// +/// ```no_run +/// use std::time::Duration; +/// use tokio::net::windows::named_pipe::ClientOptions; +/// use tokio::time; +/// use winapi::shared::winerror; +/// +/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-client"; +/// +/// # #[tokio::main] async fn main() -> std::io::Result<()> { +/// let client = loop { +/// match ClientOptions::new().open(PIPE_NAME) { +/// Ok(client) => break client, +/// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), +/// Err(e) => return Err(e), +/// } +/// +/// time::sleep(Duration::from_millis(50)).await; +/// }; +/// +/// /* use the connected client */ +/// # Ok(()) } +/// ``` +/// +/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY +/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes +#[derive(Debug)] +pub struct NamedPipeClient { + io: PollEvented<mio_windows::NamedPipe>, +} + +impl NamedPipeClient { + /// Construct a new named pipe client from the specified raw handle. + /// + /// This function will consume ownership of the handle given, passing + /// responsibility for closing the handle to the returned object. + /// + /// This function is also unsafe as the primitives currently returned have + /// the contract that they are the sole owner of the file descriptor they + /// are wrapping. Usage of this function could accidentally allow violating + /// this contract which can cause memory unsafety in code that relies on it + /// being true. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> { + let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle); + + Ok(Self { + io: PollEvented::new(named_pipe)?, + }) + } + + /// Retrieves information about the named pipe the client is associated + /// with. + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::{ClientOptions, PipeEnd, PipeMode}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-info"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let client = ClientOptions::new() + /// .open(PIPE_NAME)?; + /// + /// let client_info = client.info()?; + /// + /// assert_eq!(client_info.end, PipeEnd::Client); + /// assert_eq!(client_info.mode, PipeMode::Message); + /// assert_eq!(client_info.max_instances, 5); + /// # Ok(()) } + /// ``` + pub fn info(&self) -> io::Result<PipeInfo> { + // Safety: we're ensuring the lifetime of the named pipe. + unsafe { named_pipe_info(self.io.as_raw_handle()) } + } +} + +impl AsyncRead for NamedPipeClient { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unsafe { self.io.poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipeClient { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsRawHandle for NamedPipeClient { + fn as_raw_handle(&self) -> RawHandle { + self.io.as_raw_handle() + } +} + +// Helper to set a boolean flag as a bitfield. +macro_rules! bool_flag { + ($f:expr, $t:expr, $flag:expr) => {{ + let current = $f; + + if $t { + $f = current | $flag; + } else { + $f = current & !$flag; + }; + }}; +} + +/// A builder structure for construct a named pipe with named pipe-specific +/// options. This is required to use for named pipe servers who wants to modify +/// pipe-related options. +/// +/// See [`ServerOptions::create`]. +#[derive(Debug, Clone)] +pub struct ServerOptions { + open_mode: DWORD, + pipe_mode: DWORD, + max_instances: DWORD, + out_buffer_size: DWORD, + in_buffer_size: DWORD, + default_timeout: DWORD, +} + +impl ServerOptions { + /// Creates a new named pipe builder with the default settings. + /// + /// ``` + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-new"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn new() -> ServerOptions { + ServerOptions { + open_mode: winbase::PIPE_ACCESS_DUPLEX | winbase::FILE_FLAG_OVERLAPPED, + pipe_mode: winbase::PIPE_TYPE_BYTE | winbase::PIPE_REJECT_REMOTE_CLIENTS, + max_instances: winbase::PIPE_UNLIMITED_INSTANCES, + out_buffer_size: 65536, + in_buffer_size: 65536, + default_timeout: 0, + } + } + + /// The pipe mode. + /// + /// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for + /// documentation of what each mode means. + /// + /// This corresponding to specifying [`dwPipeMode`]. + /// + /// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self { + self.pipe_mode = match pipe_mode { + PipeMode::Byte => winbase::PIPE_TYPE_BYTE, + PipeMode::Message => winbase::PIPE_TYPE_MESSAGE, + }; + + self + } + + /// The flow of data in the pipe goes from client to server only. + /// + /// This corresponds to setting [`PIPE_ACCESS_INBOUND`]. + /// + /// [`PIPE_ACCESS_INBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_inbound + /// + /// # Errors + /// + /// Server side prevents connecting by denying inbound access, client errors + /// with [`std::io::ErrorKind::PermissionDenied`] when attempting to create + /// the connection. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err1"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let _server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let e = ClientOptions::new() + /// .open(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// Disabling writing allows a client to connect, but errors with + /// [`std::io::ErrorKind::PermissionDenied`] if a write is attempted. + /// + /// ``` + /// use std::io; + /// use tokio::io::AsyncWriteExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err2"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .write(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let e = client.write(b"ping").await.unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// A unidirectional named pipe that only supports server-to-client + /// communication. + /// + /// ``` + /// use std::io; + /// use tokio::io::{AsyncReadExt, AsyncWriteExt}; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .write(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let write = server.write_all(b"ping"); + /// + /// let mut buf = [0u8; 4]; + /// let read = client.read_exact(&mut buf); + /// + /// let ((), read) = tokio::try_join!(write, read)?; + /// + /// assert_eq!(read, 4); + /// assert_eq!(&buf[..], b"ping"); + /// # Ok(()) } + /// ``` + pub fn access_inbound(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_INBOUND); + self + } + + /// The flow of data in the pipe goes from server to client only. + /// + /// This corresponds to setting [`PIPE_ACCESS_OUTBOUND`]. + /// + /// [`PIPE_ACCESS_OUTBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_outbound + /// + /// # Errors + /// + /// Server side prevents connecting by denying outbound access, client + /// errors with [`std::io::ErrorKind::PermissionDenied`] when attempting to + /// create the connection. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err1"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let e = ClientOptions::new() + /// .open(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// Disabling reading allows a client to connect, but attempting to read + /// will error with [`std::io::ErrorKind::PermissionDenied`]. + /// + /// ``` + /// use std::io; + /// use tokio::io::AsyncReadExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err2"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .read(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let mut buf = [0u8; 4]; + /// let e = client.read(&mut buf).await.unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// A unidirectional named pipe that only supports client-to-server + /// communication. + /// + /// ``` + /// use tokio::io::{AsyncReadExt, AsyncWriteExt}; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let mut server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .read(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let write = client.write_all(b"ping"); + /// + /// let mut buf = [0u8; 4]; + /// let read = server.read_exact(&mut buf); + /// + /// let ((), read) = tokio::try_join!(write, read)?; + /// + /// println!("done reading and writing"); + /// + /// assert_eq!(read, 4); + /// assert_eq!(&buf[..], b"ping"); + /// # Ok(()) } + /// ``` + pub fn access_outbound(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_OUTBOUND); + self + } + + /// If you attempt to create multiple instances of a pipe with this flag + /// set, creation of the first server instance succeeds, but creation of any + /// subsequent instances will fail with + /// [`std::io::ErrorKind::PermissionDenied`]. + /// + /// This option is intended to be used with servers that want to ensure that + /// they are the only process listening for clients on a given named pipe. + /// This is accomplished by enabling it for the first server instance + /// created in a process. + /// + /// This corresponds to setting [`FILE_FLAG_FIRST_PIPE_INSTANCE`]. + /// + /// # Errors + /// + /// If this option is set and more than one instance of the server for a + /// given named pipe exists, calling [`create`] will fail with + /// [`std::io::ErrorKind::PermissionDenied`]. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance-error"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server1 = ServerOptions::new() + /// .first_pipe_instance(true) + /// .create(PIPE_NAME)?; + /// + /// // Second server errs, since it's not the first instance. + /// let e = ServerOptions::new() + /// .first_pipe_instance(true) + /// .create(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut builder = ServerOptions::new(); + /// builder.first_pipe_instance(true); + /// + /// let server = builder.create(PIPE_NAME)?; + /// let e = builder.create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// drop(server); + /// + /// // OK: since, we've closed the other instance. + /// let _server2 = builder.create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + /// + /// [`create`]: ServerOptions::create + /// [`FILE_FLAG_FIRST_PIPE_INSTANCE`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_first_pipe_instance + pub fn first_pipe_instance(&mut self, first: bool) -> &mut Self { + bool_flag!( + self.open_mode, + first, + winbase::FILE_FLAG_FIRST_PIPE_INSTANCE + ); + self + } + + /// Indicates whether this server can accept remote clients or not. Remote + /// clients are disabled by default. + /// + /// This corresponds to setting [`PIPE_REJECT_REMOTE_CLIENTS`]. + /// + /// [`PIPE_REJECT_REMOTE_CLIENTS`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_reject_remote_clients + pub fn reject_remote_clients(&mut self, reject: bool) -> &mut Self { + bool_flag!(self.pipe_mode, reject, winbase::PIPE_REJECT_REMOTE_CLIENTS); + self + } + + /// The maximum number of instances that can be created for this pipe. The + /// first instance of the pipe can specify this value; the same number must + /// be specified for other instances of the pipe. Acceptable values are in + /// the range 1 through 254. The default value is unlimited. + /// + /// This corresponds to specifying [`nMaxInstances`]. + /// + /// [`nMaxInstances`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + /// + /// # Errors + /// + /// The same numbers of `max_instances` have to be used by all servers. Any + /// additional servers trying to be built which uses a mismatching value + /// might error. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions}; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-max-instances"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut server = ServerOptions::new(); + /// server.max_instances(2); + /// + /// let s1 = server.create(PIPE_NAME)?; + /// let c1 = ClientOptions::new().open(PIPE_NAME); + /// + /// let s2 = server.create(PIPE_NAME)?; + /// let c2 = ClientOptions::new().open(PIPE_NAME); + /// + /// // Too many servers! + /// let e = server.create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32)); + /// + /// // Still too many servers even if we specify a higher value! + /// let e = server.max_instances(100).create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32)); + /// # Ok(()) } + /// ``` + /// + /// # Panics + /// + /// This function will panic if more than 254 instances are specified. If + /// you do not wish to set an instance limit, leave it unspecified. + /// + /// ```should_panic + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let builder = ServerOptions::new().max_instances(255); + /// # Ok(()) } + /// ``` + pub fn max_instances(&mut self, instances: usize) -> &mut Self { + assert!(instances < 255, "cannot specify more than 254 instances"); + self.max_instances = instances as DWORD; + self + } + + /// The number of bytes to reserve for the output buffer. + /// + /// This corresponds to specifying [`nOutBufferSize`]. + /// + /// [`nOutBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn out_buffer_size(&mut self, buffer: u32) -> &mut Self { + self.out_buffer_size = buffer as DWORD; + self + } + + /// The number of bytes to reserve for the input buffer. + /// + /// This corresponds to specifying [`nInBufferSize`]. + /// + /// [`nInBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn in_buffer_size(&mut self, buffer: u32) -> &mut Self { + self.in_buffer_size = buffer as DWORD; + self + } + + /// Create the named pipe identified by `addr` for use as a server. + /// + /// This uses the [`CreateNamedPipe`] function. + /// + /// [`CreateNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// + /// # Examples + /// + /// ``` + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-create"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn create(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeServer> { + // Safety: We're calling create_with_security_attributes_raw w/ a null + // pointer which disables it. + unsafe { self.create_with_security_attributes_raw(addr, ptr::null_mut()) } + } + + /// Create the named pipe identified by `addr` for use as a server. + /// + /// This is the same as [`create`] except that it supports providing the raw + /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed + /// as the `lpSecurityAttributes` argument to [`CreateFile`]. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// + /// # Safety + /// + /// The `attrs` argument must either be null or point at a valid instance of + /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the + /// behavior is identical to calling the [`create`] method. + /// + /// [`create`]: ServerOptions::create + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES + pub unsafe fn create_with_security_attributes_raw( + &self, + addr: impl AsRef<OsStr>, + attrs: *mut c_void, + ) -> io::Result<NamedPipeServer> { + let addr = encode_addr(addr); + + let h = namedpipeapi::CreateNamedPipeW( + addr.as_ptr(), + self.open_mode, + self.pipe_mode, + self.max_instances, + self.out_buffer_size, + self.in_buffer_size, + self.default_timeout, + attrs as *mut _, + ); + + if h == handleapi::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + NamedPipeServer::from_raw_handle(h) + } +} + +/// A builder suitable for building and interacting with named pipes from the +/// client side. +/// +/// See [`ClientOptions::open`]. +#[derive(Debug, Clone)] +pub struct ClientOptions { + desired_access: DWORD, + security_qos_flags: DWORD, +} + +impl ClientOptions { + /// Creates a new named pipe builder with the default settings. + /// + /// ``` + /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-new"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// // Server must be created in order for the client creation to succeed. + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// let client = ClientOptions::new().open(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn new() -> Self { + Self { + desired_access: winnt::GENERIC_READ | winnt::GENERIC_WRITE, + security_qos_flags: winbase::SECURITY_IDENTIFICATION | winbase::SECURITY_SQOS_PRESENT, + } + } + + /// If the client supports reading data. This is enabled by default. + /// + /// This corresponds to setting [`GENERIC_READ`] in the call to [`CreateFile`]. + /// + /// [`GENERIC_READ`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + pub fn read(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.desired_access, allowed, winnt::GENERIC_READ); + self + } + + /// If the created pipe supports writing data. This is enabled by default. + /// + /// This corresponds to setting [`GENERIC_WRITE`] in the call to [`CreateFile`]. + /// + /// [`GENERIC_WRITE`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + pub fn write(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.desired_access, allowed, winnt::GENERIC_WRITE); + self + } + + /// Sets qos flags which are combined with other flags and attributes in the + /// call to [`CreateFile`]. + /// + /// By default `security_qos_flags` is set to [`SECURITY_IDENTIFICATION`], + /// calling this function would override that value completely with the + /// argument specified. + /// + /// When `security_qos_flags` is not set, a malicious program can gain the + /// elevated privileges of a privileged Rust process when it allows opening + /// user-specified paths, by tricking it into opening a named pipe. So + /// arguably `security_qos_flags` should also be set when opening arbitrary + /// paths. However the bits can then conflict with other flags, specifically + /// `FILE_FLAG_OPEN_NO_RECALL`. + /// + /// For information about possible values, see [Impersonation Levels] on the + /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set + /// automatically when using this method. + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`SECURITY_IDENTIFICATION`]: crate::winapi::um::winbase::SECURITY_IDENTIFICATION + /// [Impersonation Levels]: https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level + pub fn security_qos_flags(&mut self, flags: u32) -> &mut Self { + // See: https://github.com/rust-lang/rust/pull/58216 + self.security_qos_flags = flags | winbase::SECURITY_SQOS_PRESENT; + self + } + + /// Open the named pipe identified by `addr`. + /// + /// This opens the client using [`CreateFile`] with the + /// `dwCreationDisposition` option set to `OPEN_EXISTING`. + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// There are a few errors you need to take into account when creating a + /// named pipe on the client side: + /// + /// * [`std::io::ErrorKind::NotFound`] - This indicates that the named pipe + /// does not exist. Presumably the server is not up. + /// * [`ERROR_PIPE_BUSY`] - This error is raised when the named pipe exists, + /// but the server is not currently waiting for a connection. Please see the + /// examples for how to check for this error. + /// + /// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY + /// [`winapi`]: crate::winapi + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// [Tokio Runtime]: crate::runtime::Runtime + /// + /// A connect loop that waits until a socket becomes available looks like + /// this: + /// + /// ```no_run + /// use std::time::Duration; + /// use tokio::net::windows::named_pipe::ClientOptions; + /// use tokio::time; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let client = loop { + /// match ClientOptions::new().open(PIPE_NAME) { + /// Ok(client) => break client, + /// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + /// Err(e) => return Err(e), + /// } + /// + /// time::sleep(Duration::from_millis(50)).await; + /// }; + /// + /// // use the connected client. + /// # Ok(()) } + /// ``` + pub fn open(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeClient> { + // Safety: We're calling open_with_security_attributes_raw w/ a null + // pointer which disables it. + unsafe { self.open_with_security_attributes_raw(addr, ptr::null_mut()) } + } + + /// Open the named pipe identified by `addr`. + /// + /// This is the same as [`open`] except that it supports providing the raw + /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed + /// as the `lpSecurityAttributes` argument to [`CreateFile`]. + /// + /// # Safety + /// + /// The `attrs` argument must either be null or point at a valid instance of + /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the + /// behavior is identical to calling the [`open`] method. + /// + /// [`open`]: ClientOptions::open + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES + pub unsafe fn open_with_security_attributes_raw( + &self, + addr: impl AsRef<OsStr>, + attrs: *mut c_void, + ) -> io::Result<NamedPipeClient> { + let addr = encode_addr(addr); + + // NB: We could use a platform specialized `OpenOptions` here, but since + // we have access to winapi it ultimately doesn't hurt to use + // `CreateFile` explicitly since it allows the use of our already + // well-structured wide `addr` to pass into CreateFileW. + let h = fileapi::CreateFileW( + addr.as_ptr(), + self.desired_access, + 0, + attrs as *mut _, + fileapi::OPEN_EXISTING, + self.get_flags(), + ptr::null_mut(), + ); + + if h == handleapi::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + NamedPipeClient::from_raw_handle(h) + } + + fn get_flags(&self) -> u32 { + self.security_qos_flags | winbase::FILE_FLAG_OVERLAPPED + } +} + +/// The pipe mode of a named pipe. +/// +/// Set through [`ServerOptions::pipe_mode`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum PipeMode { + /// Data is written to the pipe as a stream of bytes. The pipe does not + /// distinguish bytes written during different write operations. + /// + /// Corresponds to [`PIPE_TYPE_BYTE`][crate::winapi::um::winbase::PIPE_TYPE_BYTE]. + Byte, + /// Data is written to the pipe as a stream of messages. The pipe treats the + /// bytes written during each write operation as a message unit. Any reading + /// on a named pipe returns [`ERROR_MORE_DATA`] when a message is not read + /// completely. + /// + /// Corresponds to [`PIPE_TYPE_MESSAGE`][crate::winapi::um::winbase::PIPE_TYPE_MESSAGE]. + /// + /// [`ERROR_MORE_DATA`]: crate::winapi::shared::winerror::ERROR_MORE_DATA + Message, +} + +/// Indicates the end of a named pipe. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum PipeEnd { + /// The named pipe refers to the client end of a named pipe instance. + /// + /// Corresponds to [`PIPE_CLIENT_END`][crate::winapi::um::winbase::PIPE_CLIENT_END]. + Client, + /// The named pipe refers to the server end of a named pipe instance. + /// + /// Corresponds to [`PIPE_SERVER_END`][crate::winapi::um::winbase::PIPE_SERVER_END]. + Server, +} + +/// Information about a named pipe. +/// +/// Constructed through [`NamedPipeServer::info`] or [`NamedPipeClient::info`]. +#[derive(Debug)] +#[non_exhaustive] +pub struct PipeInfo { + /// Indicates the mode of a named pipe. + pub mode: PipeMode, + /// Indicates the end of a named pipe. + pub end: PipeEnd, + /// The maximum number of instances that can be created for this pipe. + pub max_instances: u32, + /// The number of bytes to reserve for the output buffer. + pub out_buffer_size: u32, + /// The number of bytes to reserve for the input buffer. + pub in_buffer_size: u32, +} + +/// Encode an address so that it is a null-terminated wide string. +fn encode_addr(addr: impl AsRef<OsStr>) -> Box<[u16]> { + let len = addr.as_ref().encode_wide().count(); + let mut vec = Vec::with_capacity(len + 1); + vec.extend(addr.as_ref().encode_wide()); + vec.push(0); + vec.into_boxed_slice() +} + +/// Internal function to get the info out of a raw named pipe. +unsafe fn named_pipe_info(handle: RawHandle) -> io::Result<PipeInfo> { + let mut flags = 0; + let mut out_buffer_size = 0; + let mut in_buffer_size = 0; + let mut max_instances = 0; + + let result = namedpipeapi::GetNamedPipeInfo( + handle, + &mut flags, + &mut out_buffer_size, + &mut in_buffer_size, + &mut max_instances, + ); + + if result == FALSE { + return Err(io::Error::last_os_error()); + } + + let mut end = PipeEnd::Client; + let mut mode = PipeMode::Byte; + + if flags & winbase::PIPE_SERVER_END != 0 { + end = PipeEnd::Server; + } + + if flags & winbase::PIPE_TYPE_MESSAGE != 0 { + mode = PipeMode::Message; + } + + Ok(PipeInfo { + end, + mode, + out_buffer_size, + in_buffer_size, + max_instances, + }) +} diff --git a/src/runtime/basic_scheduler.rs b/src/runtime/basic_scheduler.rs index ffe0bca..13dfb69 100644 --- a/src/runtime/basic_scheduler.rs +++ b/src/runtime/basic_scheduler.rs @@ -84,13 +84,13 @@ unsafe impl Send for Entry {} /// Scheduler state shared between threads. struct Shared { - /// Remote run queue - queue: Mutex<VecDeque<Entry>>, + /// Remote run queue. None if the `Runtime` has been dropped. + queue: Mutex<Option<VecDeque<Entry>>>, - /// Unpark the blocked thread + /// Unpark the blocked thread. unpark: Box<dyn Unpark>, - // indicates whether the blocked on thread was woken + /// Indicates whether the blocked on thread was woken. woken: AtomicBool, } @@ -124,7 +124,7 @@ impl<P: Park> BasicScheduler<P> { let spawner = Spawner { shared: Arc::new(Shared { - queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), unpark: unpark as Box<dyn Unpark>, woken: AtomicBool::new(false), }), @@ -351,18 +351,29 @@ impl<P: Park> Drop for BasicScheduler<P> { task.shutdown(); } - // Drain remote queue - for entry in scheduler.spawner.shared.queue.lock().drain(..) { - match entry { - Entry::Schedule(task) => { - task.shutdown(); - } - Entry::Release(..) => { - // Do nothing, each entry in the linked list was *just* - // dropped by the scheduler above. + // Drain remote queue and set it to None + let mut remote_queue = scheduler.spawner.shared.queue.lock(); + + // Using `Option::take` to replace the shared queue with `None`. + if let Some(remote_queue) = remote_queue.take() { + for entry in remote_queue { + match entry { + Entry::Schedule(task) => { + task.shutdown(); + } + Entry::Release(..) => { + // Do nothing, each entry in the linked list was *just* + // dropped by the scheduler above. + } } } } + // By dropping the mutex lock after the full duration of the above loop, + // any thread that sees the queue in the `None` state is guaranteed that + // the runtime has fully shut down. + // + // The assert below is unrelated to this mutex. + drop(remote_queue); assert!(context.tasks.borrow().owned.is_empty()); }); @@ -381,7 +392,7 @@ impl Spawner { /// Spawns a future onto the thread pool pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> where - F: Future + Send + 'static, + F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { let (task, handle) = task::joinable(future); @@ -390,7 +401,10 @@ impl Spawner { } fn pop(&self) -> Option<Entry> { - self.shared.queue.lock().pop_front() + match self.shared.queue.lock().as_mut() { + Some(queue) => queue.pop_front(), + None => None, + } } fn waker_ref(&self) -> WakerRef<'_> { @@ -429,7 +443,19 @@ impl Schedule for Arc<Shared> { // safety: the task is inserted in the list in `bind`. unsafe { cx.tasks.borrow_mut().owned.remove(ptr) } } else { - self.queue.lock().push_back(Entry::Release(ptr)); + // By sending an `Entry::Release` to the runtime, we ask the + // runtime to remove this task from the linked list in + // `Tasks::owned`. + // + // If the queue is `None`, then the task was already removed + // from that list in the destructor of `BasicScheduler`. We do + // not do anything in this case for the same reason that + // `Entry::Release` messages are ignored in the remote queue + // drain loop of `BasicScheduler`'s destructor. + if let Some(queue) = self.queue.lock().as_mut() { + queue.push_back(Entry::Release(ptr)); + } + self.unpark.unpark(); // Returning `None` here prevents the task plumbing from being // freed. It is then up to the scheduler through the queue we @@ -445,8 +471,17 @@ impl Schedule for Arc<Shared> { cx.tasks.borrow_mut().queue.push_back(task); } _ => { - self.queue.lock().push_back(Entry::Schedule(task)); - self.unpark.unpark(); + let mut guard = self.queue.lock(); + if let Some(queue) = guard.as_mut() { + queue.push_back(Entry::Schedule(task)); + drop(guard); + self.unpark.unpark(); + } else { + // The runtime has shut down. We drop the new task + // immediately. + drop(guard); + task.shutdown(); + } } }); } diff --git a/src/runtime/blocking/pool.rs b/src/runtime/blocking/pool.rs index 5c9b8ed..b7d7251 100644 --- a/src/runtime/blocking/pool.rs +++ b/src/runtime/blocking/pool.rs @@ -4,7 +4,6 @@ use crate::loom::sync::{Arc, Condvar, Mutex}; use crate::loom::thread; use crate::runtime::blocking::schedule::NoopSchedule; use crate::runtime::blocking::shutdown; -use crate::runtime::blocking::task::BlockingTask; use crate::runtime::builder::ThreadNameFn; use crate::runtime::context; use crate::runtime::task::{self, JoinHandle}; @@ -86,18 +85,6 @@ where rt.spawn_blocking(func) } -#[allow(dead_code)] -pub(crate) fn try_spawn_blocking<F, R>(func: F) -> Result<(), ()> -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - let rt = context::current().expect(CONTEXT_MISSING_ERROR); - - let (task, _handle) = task::joinable(BlockingTask::new(func)); - rt.blocking_spawner.spawn(task, &rt) -} - // ===== impl BlockingPool ===== impl BlockingPool { diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs index 4f1b4c5..173f0ca 100644 --- a/src/runtime/handle.rs +++ b/src/runtime/handle.rs @@ -174,8 +174,11 @@ impl Handle { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + let fut = BlockingTask::new(func); + #[cfg(all(tokio_unstable, feature = "tracing"))] - let func = { + let fut = { + use tracing::Instrument; #[cfg(tokio_track_caller)] let location = std::panic::Location::caller(); #[cfg(tokio_track_caller)] @@ -193,12 +196,9 @@ impl Handle { kind = %"blocking", function = %std::any::type_name::<F>(), ); - move || { - let _g = span.enter(); - func() - } + fut.instrument(span) }; - let (task, handle) = task::joinable(BlockingTask::new(func)); + let (task, handle) = task::joinable(fut); let _ = self.blocking_spawner.spawn(task, &self); handle } diff --git a/src/runtime/queue.rs b/src/runtime/queue.rs index 6ea23c9..3df7bba 100644 --- a/src/runtime/queue.rs +++ b/src/runtime/queue.rs @@ -109,7 +109,10 @@ impl<T> Local<T> { } /// Pushes a task to the back of the local queue, skipping the LIFO slot. - pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) { + pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) + where + T: crate::runtime::task::Schedule, + { let tail = loop { let head = self.inner.head.load(Acquire); let (steal, real) = unpack(head); @@ -121,9 +124,14 @@ impl<T> Local<T> { // There is capacity for the task break tail; } else if steal != real { - // Concurrently stealing, this will free up capacity, so - // only push the new task onto the inject queue - inject.push(task); + // Concurrently stealing, this will free up capacity, so only + // push the new task onto the inject queue + // + // If the task failes to be pushed on the injection queue, there + // is nothing to be done at this point as the task cannot be a + // newly spawned task. Shutting down this task is handled by the + // worker shutdown process. + let _ = inject.push(task); return; } else { // Push the current task and half of the queue into the @@ -504,16 +512,19 @@ impl<T: 'static> Inject<T> { } /// Pushes a value into the queue. - pub(super) fn push(&self, task: task::Notified<T>) { + /// + /// Returns `Err(task)` if pushing fails due to the queue being shutdown. + /// The caller is expected to call `shutdown()` on the task **if and only + /// if** it is a newly spawned task. + pub(super) fn push(&self, task: task::Notified<T>) -> Result<(), task::Notified<T>> + where + T: crate::runtime::task::Schedule, + { // Acquire queue lock let mut p = self.pointers.lock(); if p.is_closed { - // Drop the mutex to avoid a potential deadlock when - // re-entering. - drop(p); - drop(task); - return; + return Err(task); } // safety: only mutated with the lock held @@ -532,6 +543,7 @@ impl<T: 'static> Inject<T> { p.tail = Some(task); self.len.store(len + 1, Release); + Ok(()) } pub(super) fn push_batch( @@ -617,7 +629,7 @@ fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { /// Split the head value into the real head and the index a stealer is working /// on. fn unpack(n: u32) -> (u16, u16) { - let real = n & u16::max_value() as u32; + let real = n & u16::MAX as u32; let steal = n >> 16; (steal as u16, real as u16) @@ -630,5 +642,5 @@ fn pack(steal: u16, real: u16) -> u32 { #[test] fn test_local_queue_capacity() { - assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::max_value() as usize); + assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize); } diff --git a/src/runtime/spawner.rs b/src/runtime/spawner.rs index a37c667..fbcde2c 100644 --- a/src/runtime/spawner.rs +++ b/src/runtime/spawner.rs @@ -1,8 +1,7 @@ cfg_rt! { + use crate::future::Future; use crate::runtime::basic_scheduler; use crate::task::JoinHandle; - - use std::future::Future; } cfg_rt_multi_thread! { diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs index fb6dafd..026a6dc 100644 --- a/src/runtime/task/core.rs +++ b/src/runtime/task/core.rs @@ -9,13 +9,13 @@ //! Make sure to consult the relevant safety section of each function before //! use. +use crate::future::Future; use crate::loom::cell::UnsafeCell; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; use crate::runtime::task::{Notified, Schedule, Task}; use crate::util::linked_list; -use std::future::Future; use std::pin::Pin; use std::ptr::NonNull; use std::task::{Context, Poll, Waker}; @@ -71,6 +71,10 @@ pub(crate) struct Header { /// Table of function pointers for executing actions on the task. pub(super) vtable: &'static Vtable, + + /// The tracing ID for this instrumented task. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) id: Option<tracing::Id>, } unsafe impl Send for Header {} @@ -93,6 +97,8 @@ impl<T: Future, S: Schedule> Cell<T, S> { /// Allocates a new task cell, containing the header, trailer, and core /// structures. pub(super) fn new(future: T, state: State) -> Box<Cell<T, S>> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let id = future.id(); Box::new(Cell { header: Header { state, @@ -100,6 +106,8 @@ impl<T: Future, S: Schedule> Cell<T, S> { queue_next: UnsafeCell::new(None), stack_next: UnsafeCell::new(None), vtable: raw::vtable::<T, S>(), + #[cfg(all(tokio_unstable, feature = "tracing"))] + id, }, core: Core { scheduler: Scheduler { diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs index 7d596e3..47bbcc1 100644 --- a/src/runtime/task/harness.rs +++ b/src/runtime/task/harness.rs @@ -1,9 +1,9 @@ +use crate::future::Future; use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Scheduler, Trailer}; use crate::runtime::task::state::Snapshot; use crate::runtime::task::waker::waker_ref; use crate::runtime::task::{JoinError, Notified, Schedule, Task}; -use std::future::Future; use std::mem; use std::panic; use std::ptr::NonNull; @@ -146,6 +146,11 @@ where } } + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) fn id(&self) -> Option<&tracing::Id> { + self.header().id.as_ref() + } + /// Forcibly shutdown the task /// /// Attempt to transition to `Running` in order to forcibly shutdown the diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs index 7b49e95..58b8c2a 100644 --- a/src/runtime/task/mod.rs +++ b/src/runtime/task/mod.rs @@ -26,9 +26,9 @@ cfg_rt_multi_thread! { pub(crate) use self::stack::TransferStack; } +use crate::future::Future; use crate::util::linked_list; -use std::future::Future; use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; diff --git a/src/runtime/task/raw.rs b/src/runtime/task/raw.rs index cae56d0..a9cd4e6 100644 --- a/src/runtime/task/raw.rs +++ b/src/runtime/task/raw.rs @@ -1,6 +1,6 @@ +use crate::future::Future; use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; -use std::future::Future; use std::ptr::NonNull; use std::task::{Poll, Waker}; diff --git a/src/runtime/task/state.rs b/src/runtime/task/state.rs index 21e9043..1f08d6d 100644 --- a/src/runtime/task/state.rs +++ b/src/runtime/task/state.rs @@ -29,12 +29,15 @@ const LIFECYCLE_MASK: usize = 0b11; const NOTIFIED: usize = 0b100; /// The join handle is still around +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 const JOIN_INTEREST: usize = 0b1_000; /// A join handle waker has been set +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 const JOIN_WAKER: usize = 0b10_000; /// The task has been forcibly cancelled. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 const CANCELLED: usize = 0b100_000; /// All bits @@ -52,7 +55,7 @@ const REF_ONE: usize = 1 << REF_COUNT_SHIFT; /// State a task is initialized with /// /// A task is initialized with two references: one for the scheduler and one for -/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTERST` is +/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is /// set. A new task is immediately pushed into the run queue for execution and /// starts with the `NOTIFIED` flag set. const INITIAL_STATE: usize = (REF_ONE * 2) | JOIN_INTEREST | NOTIFIED; @@ -64,7 +67,7 @@ impl State { pub(super) fn new() -> State { // A task is initialized with three references: one for the scheduler, // one for the `JoinHandle`, one for the task handle made available in - // release. As the task starts with a `JoinHandle`, `JOIN_INTERST` is + // release. As the task starts with a `JoinHandle`, `JOIN_INTEREST` is // set. A new task is immediately pushed into the run queue for // execution and starts with the `NOTIFIED` flag set. State { @@ -306,7 +309,7 @@ impl State { let prev = self.val.fetch_add(REF_ONE, Relaxed); // If the reference count overflowed, abort. - if prev > isize::max_value() as usize { + if prev > isize::MAX as usize { process::abort(); } } @@ -410,7 +413,7 @@ impl Snapshot { } fn ref_inc(&mut self) { - assert!(self.0 <= isize::max_value() as usize); + assert!(self.0 <= isize::MAX as usize); self.0 += REF_ONE; } diff --git a/src/runtime/task/waker.rs b/src/runtime/task/waker.rs index 5c2d478..b7313b4 100644 --- a/src/runtime/task/waker.rs +++ b/src/runtime/task/waker.rs @@ -1,7 +1,7 @@ +use crate::future::Future; use crate::runtime::task::harness::Harness; use crate::runtime::task::{Header, Schedule}; -use std::future::Future; use std::marker::PhantomData; use std::mem::ManuallyDrop; use std::ops; @@ -44,12 +44,38 @@ impl<S> ops::Deref for WakerRef<'_, S> { } } +cfg_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + if let Some(id) = $harness.id() { + tracing::trace!( + target: "tokio::task::waker", + op = $op, + task.id = id.into_u64(), + ); + } + } + } +} + +cfg_not_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + // noop + let _ = &$harness; + } + } +} + unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker where T: Future, S: Schedule, { let header = ptr as *const Header; + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.clone"); (*header).state.ref_inc(); raw_waker::<T, S>(header) } @@ -61,6 +87,7 @@ where { let ptr = NonNull::new_unchecked(ptr as *mut Header); let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.drop"); harness.drop_reference(); } @@ -71,6 +98,7 @@ where { let ptr = NonNull::new_unchecked(ptr as *mut Header); let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake"); harness.wake_by_val(); } @@ -82,6 +110,7 @@ where { let ptr = NonNull::new_unchecked(ptr as *mut Header); let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake_by_ref"); harness.wake_by_ref(); } diff --git a/src/runtime/tests/loom_shutdown_join.rs b/src/runtime/tests/loom_shutdown_join.rs new file mode 100644 index 0000000..6fbc4bf --- /dev/null +++ b/src/runtime/tests/loom_shutdown_join.rs @@ -0,0 +1,28 @@ +use crate::runtime::{Builder, Handle}; + +#[test] +fn join_handle_cancel_on_shutdown() { + let mut builder = loom::model::Builder::new(); + builder.preemption_bound = Some(2); + builder.check(|| { + use futures::future::FutureExt; + + let rt = Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(); + + let handle = rt.block_on(async move { Handle::current() }); + + let jh1 = handle.spawn(futures::future::pending::<()>()); + + drop(rt); + + let jh2 = handle.spawn(futures::future::pending::<()>()); + + let err1 = jh1.now_or_never().unwrap().unwrap_err(); + let err2 = jh2.now_or_never().unwrap().unwrap_err(); + assert!(err1.is_cancelled()); + assert!(err2.is_cancelled()); + }); +} diff --git a/src/runtime/tests/mod.rs b/src/runtime/tests/mod.rs index ebb48de..c84ba1b 100644 --- a/src/runtime/tests/mod.rs +++ b/src/runtime/tests/mod.rs @@ -4,6 +4,7 @@ cfg_loom! { mod loom_oneshot; mod loom_pool; mod loom_queue; + mod loom_shutdown_join; } cfg_not_loom! { diff --git a/src/runtime/tests/task.rs b/src/runtime/tests/task.rs index a34526f..45a3e99 100644 --- a/src/runtime/tests/task.rs +++ b/src/runtime/tests/task.rs @@ -79,7 +79,7 @@ static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); impl Runtime { fn tick(&self) -> usize { - self.tick_max(usize::max_value()) + self.tick_max(usize::MAX) } fn tick_max(&self, max: usize) -> usize { diff --git a/src/runtime/thread_pool/mod.rs b/src/runtime/thread_pool/mod.rs index 47f8ee3..96312d3 100644 --- a/src/runtime/thread_pool/mod.rs +++ b/src/runtime/thread_pool/mod.rs @@ -90,11 +90,17 @@ impl Spawner { /// Spawns a future onto the thread pool pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> where - F: Future + Send + 'static, + F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { let (task, handle) = task::joinable(future); - self.shared.schedule(task, false); + + if let Err(task) = self.shared.schedule(task, false) { + // The newly spawned task could not be scheduled because the runtime + // is shutting down. The task must be explicitly shutdown at this point. + task.shutdown(); + } + handle } diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs index 86d3f91..70cbddb 100644 --- a/src/runtime/thread_pool/worker.rs +++ b/src/runtime/thread_pool/worker.rs @@ -709,16 +709,22 @@ impl task::Schedule for Arc<Worker> { } fn schedule(&self, task: Notified) { - self.shared.schedule(task, false); + // Because this is not a newly spawned task, if scheduling fails due to + // the runtime shutting down, there is no special work that must happen + // here. + let _ = self.shared.schedule(task, false); } fn yield_now(&self, task: Notified) { - self.shared.schedule(task, true); + // Because this is not a newly spawned task, if scheduling fails due to + // the runtime shutting down, there is no special work that must happen + // here. + let _ = self.shared.schedule(task, true); } } impl Shared { - pub(super) fn schedule(&self, task: Notified, is_yield: bool) { + pub(super) fn schedule(&self, task: Notified, is_yield: bool) -> Result<(), Notified> { CURRENT.with(|maybe_cx| { if let Some(cx) = maybe_cx { // Make sure the task is part of the **current** scheduler. @@ -726,15 +732,16 @@ impl Shared { // And the current thread still holds a core if let Some(core) = cx.core.borrow_mut().as_mut() { self.schedule_local(core, task, is_yield); - return; + return Ok(()); } } } // Otherwise, use the inject queue - self.inject.push(task); + self.inject.push(task)?; self.notify_parked(); - }); + Ok(()) + }) } fn schedule_local(&self, core: &mut Core, task: Notified, is_yield: bool) { @@ -823,7 +830,9 @@ impl Shared { } // Drain the injection queue - while self.inject.pop().is_some() {} + while let Some(task) = self.inject.pop() { + task.shutdown(); + } } fn ptr_eq(&self, other: &Shared) -> bool { diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 5f97c1a..457e6ab 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -428,6 +428,11 @@ //! bounding of any kind. cfg_sync! { + /// Named future types. + pub mod futures { + pub use super::notify::Notified; + } + mod barrier; pub use barrier::{Barrier, BarrierWaitResult}; diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index ce857d7..cfd8da0 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -65,7 +65,7 @@ pub struct Receiver<T> { /// with backpressure. /// /// The channel will buffer up to the provided number of messages. Once the -/// buffer is full, attempts to `send` new messages will wait until a message is +/// buffer is full, attempts to send new messages will wait until a message is /// received from the channel. The provided buffer capacity must be at least 1. /// /// All data sent on `Sender` will become available on `Receiver` in the same @@ -76,7 +76,7 @@ pub struct Receiver<T> { /// /// If the `Receiver` is disconnected while trying to `send`, the `send` method /// will return a `SendError`. Similarly, if `Sender` is disconnected while -/// trying to `recv`, the `recv` method will return a `RecvError`. +/// trying to `recv`, the `recv` method will return `None`. /// /// # Panics /// @@ -887,7 +887,7 @@ impl<T> Sender<T> { /// let permit = tx.reserve().await.unwrap(); /// assert_eq!(tx.capacity(), 4); /// - /// // Sending and receiving a value increases the caapcity by one. + /// // Sending and receiving a value increases the capacity by one. /// permit.send(()); /// rx.recv().await.unwrap(); /// assert_eq!(tx.capacity(), 5); diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs index a2d2824..0d25ad3 100644 --- a/src/sync/mpsc/error.rs +++ b/src/sync/mpsc/error.rs @@ -55,14 +55,18 @@ impl<T> From<SendError<T>> for TrySendError<T> { /// Error returned by `Receiver`. #[derive(Debug)] +#[doc(hidden)] +#[deprecated(note = "This type is unused because recv returns an Option.")] pub struct RecvError(()); +#[allow(deprecated)] impl fmt::Display for RecvError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "channel closed") } } +#[allow(deprecated)] impl Error for RecvError {} cfg_time! { diff --git a/src/sync/notify.rs b/src/sync/notify.rs index 5d2132f..07be759 100644 --- a/src/sync/notify.rs +++ b/src/sync/notify.rs @@ -140,7 +140,7 @@ struct Waiter { _p: PhantomPinned, } -/// Future returned from `notified()` +/// Future returned from [`Notify::notified()`] #[derive(Debug)] pub struct Notified<'a> { /// The `Notify` being received on. diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs index af75042..5d42d1c 100644 --- a/src/sync/semaphore.rs +++ b/src/sync/semaphore.rs @@ -24,7 +24,55 @@ use std::sync::Arc; /// To use the `Semaphore` in a poll function, you can use the [`PollSemaphore`] /// utility. /// +/// # Examples +/// +/// Basic usage: +/// +/// ``` +/// use tokio::sync::{Semaphore, TryAcquireError}; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Semaphore::new(3); +/// +/// let a_permit = semaphore.acquire().await.unwrap(); +/// let two_permits = semaphore.acquire_many(2).await.unwrap(); +/// +/// assert_eq!(semaphore.available_permits(), 0); +/// +/// let permit_attempt = semaphore.try_acquire(); +/// assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits)); +/// } +/// ``` +/// +/// Use [`Semaphore::acquire_owned`] to move permits across tasks: +/// +/// ``` +/// use std::sync::Arc; +/// use tokio::sync::Semaphore; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Arc::new(Semaphore::new(3)); +/// let mut join_handles = Vec::new(); +/// +/// for _ in 0..5 { +/// let permit = semaphore.clone().acquire_owned().await.unwrap(); +/// join_handles.push(tokio::spawn(async move { +/// // perform task... +/// // explicitly own `permit` in the task +/// drop(permit); +/// })); +/// } +/// +/// for handle in join_handles { +/// handle.await.unwrap(); +/// } +/// } +/// ``` +/// /// [`PollSemaphore`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSemaphore.html +/// [`Semaphore::acquire_owned`]: crate::sync::Semaphore::acquire_owned #[derive(Debug)] pub struct Semaphore { /// The low level semaphore @@ -79,6 +127,15 @@ impl Semaphore { } /// Creates a new semaphore with the initial number of permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// static SEM: Semaphore = Semaphore::const_new(10); + /// ``` + /// #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new(permits: usize) -> Self { @@ -105,6 +162,26 @@ impl Semaphore { /// Otherwise, this returns a [`SemaphorePermit`] representing the /// acquired permit. /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// drop(permit_1); + /// assert_eq!(semaphore.available_permits(), 1); + /// } + /// ``` + /// /// [`AcquireError`]: crate::sync::AcquireError /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { @@ -121,6 +198,20 @@ impl Semaphore { /// Otherwise, this returns a [`SemaphorePermit`] representing the /// acquired permits. /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(5); + /// + /// let permit = semaphore.acquire_many(3).await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 2); + /// } + /// ``` + /// /// [`AcquireError`]: crate::sync::AcquireError /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { @@ -137,6 +228,25 @@ impl Semaphore { /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, /// this returns a [`SemaphorePermit`] representing the acquired permits. /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits /// [`SemaphorePermit`]: crate::sync::SemaphorePermit @@ -153,8 +263,24 @@ impl Semaphore { /// Tries to acquire `n` permits from the semaphore. /// /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] - /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, - /// this returns a [`SemaphorePermit`] representing the acquired permits. + /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left. + /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(4); + /// + /// let permit_1 = semaphore.try_acquire_many(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` /// /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits @@ -176,6 +302,32 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(3)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_owned().await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`AcquireError`]: crate::sync::AcquireError /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit @@ -194,6 +346,32 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(10)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_many_owned(2).await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`AcquireError`]: crate::sync::AcquireError /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit @@ -216,6 +394,26 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(2)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire_owned(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits @@ -238,6 +436,23 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(4)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_many_owned(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many_owned(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs index 5917204..8616007 100644 --- a/src/sync/task/atomic_waker.rs +++ b/src/sync/task/atomic_waker.rs @@ -29,7 +29,7 @@ pub(crate) struct AtomicWaker { // `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell // stores a `Waker` value produced by calls to `register` and many threads can -// race to take the waker by calling `wake. +// race to take the waker by calling `wake`. // // If a new `Waker` instance is produced by calling `register` before an existing // one is consumed, then the existing one is overwritten. diff --git a/src/sync/watch.rs b/src/sync/watch.rs index db65e5a..42d417a 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -417,6 +417,28 @@ impl<T> Sender<T> { Receiver::from_shared(version, shared) } } + + /// Returns the number of receivers that currently exist + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx1) = watch::channel("hello"); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = rx1.clone(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + self.shared.ref_count_rx.load(Relaxed) + } } impl<T> Drop for Sender<T> { diff --git a/src/task/mod.rs b/src/task/mod.rs index 7255535..25dab0c 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -86,7 +86,7 @@ //! ``` //! //! Again, like `std::thread`'s [`JoinHandle` type][thread_join], if the spawned -//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]`. For +//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]. For //! example: //! //! ``` diff --git a/src/time/clock.rs b/src/time/clock.rs index c5ef86b..a0ff621 100644 --- a/src/time/clock.rs +++ b/src/time/clock.rs @@ -7,7 +7,7 @@ //! configurable. cfg_not_test_util! { - use crate::time::{Duration, Instant}; + use crate::time::{Instant}; #[derive(Debug, Clone)] pub(crate) struct Clock {} @@ -24,14 +24,6 @@ cfg_not_test_util! { pub(crate) fn now(&self) -> Instant { now() } - - pub(crate) fn is_paused(&self) -> bool { - false - } - - pub(crate) fn advance(&self, _dur: Duration) { - unreachable!(); - } } } @@ -121,10 +113,9 @@ cfg_test_util! { /// runtime. pub async fn advance(duration: Duration) { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); - let until = clock.now() + duration; clock.advance(duration); - crate::time::sleep_until(until).await; + crate::task::yield_now().await; } /// Return the current instant, factoring in frozen time. diff --git a/src/time/driver/entry.rs b/src/time/driver/entry.rs index 08edab3..168e0b9 100644 --- a/src/time/driver/entry.rs +++ b/src/time/driver/entry.rs @@ -68,7 +68,7 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; type TimerResult = Result<(), crate::time::error::Error>; -const STATE_DEREGISTERED: u64 = u64::max_value(); +const STATE_DEREGISTERED: u64 = u64::MAX; const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; @@ -85,10 +85,10 @@ const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; /// requires only the driver lock. pub(super) struct StateCell { /// Holds either the scheduled expiration time for this timer, or (if the - /// timer has been fired and is unregistered), `u64::max_value()`. + /// timer has been fired and is unregistered), `u64::MAX`. state: AtomicU64, /// If the timer is fired (an Acquire order read on state shows - /// `u64::max_value()`), holds the result that should be returned from + /// `u64::MAX`), holds the result that should be returned from /// polling the timer. Otherwise, the contents are unspecified and reading /// without holding the driver lock is undefined behavior. result: UnsafeCell<TimerResult>, @@ -125,7 +125,7 @@ impl StateCell { fn when(&self) -> Option<u64> { let cur_state = self.state.load(Ordering::Relaxed); - if cur_state == u64::max_value() { + if cur_state == u64::MAX { None } else { Some(cur_state) @@ -271,7 +271,7 @@ impl StateCell { /// ordering, but is conservative - if it returns false, the timer is /// definitely _not_ registered. pub(super) fn might_be_registered(&self) -> bool { - self.state.load(Ordering::Relaxed) != u64::max_value() + self.state.load(Ordering::Relaxed) != u64::MAX } } @@ -591,7 +591,7 @@ impl TimerHandle { match self.inner.as_ref().state.mark_pending(not_after) { Ok(()) => { // mark this as being on the pending queue in cached_when - self.inner.as_ref().set_cached_when(u64::max_value()); + self.inner.as_ref().set_cached_when(u64::MAX); Ok(()) } Err(tick) => { diff --git a/src/time/driver/mod.rs b/src/time/driver/mod.rs index 3eb1004..37d2231 100644 --- a/src/time/driver/mod.rs +++ b/src/time/driver/mod.rs @@ -91,6 +91,15 @@ pub(crate) struct Driver<P: Park + 'static> { /// Parker to delegate to park: P, + + // When `true`, a call to `park_timeout` should immediately return and time + // should not advance. One reason for this to be `true` is if the task + // passed to `Runtime::block_on` called `task::yield_now()`. + // + // While it may look racy, it only has any effect when the clock is paused + // and pausing the clock is restricted to a single-threaded runtime. + #[cfg(feature = "test-util")] + did_wake: Arc<AtomicBool>, } /// A structure which handles conversion from Instants to u64 timestamps. @@ -178,6 +187,8 @@ where time_source, handle: Handle::new(Arc::new(inner)), park, + #[cfg(feature = "test-util")] + did_wake: Arc::new(AtomicBool::new(false)), } } @@ -192,8 +203,6 @@ where } fn park_internal(&mut self, limit: Option<Duration>) -> Result<(), P::Error> { - let clock = &self.time_source.clock; - let mut lock = self.handle.get().state.lock(); assert!(!self.handle.is_shutdown()); @@ -217,26 +226,14 @@ where duration = std::cmp::min(limit, duration); } - if clock.is_paused() { - self.park.park_timeout(Duration::from_secs(0))?; - - // Simulate advancing time - clock.advance(duration); - } else { - self.park.park_timeout(duration)?; - } + self.park_timeout(duration)?; } else { self.park.park_timeout(Duration::from_secs(0))?; } } None => { if let Some(duration) = limit { - if clock.is_paused() { - self.park.park_timeout(Duration::from_secs(0))?; - clock.advance(duration); - } else { - self.park.park_timeout(duration)?; - } + self.park_timeout(duration)?; } else { self.park.park()?; } @@ -248,6 +245,39 @@ where Ok(()) } + + cfg_test_util! { + fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> { + let clock = &self.time_source.clock; + + if clock.is_paused() { + self.park.park_timeout(Duration::from_secs(0))?; + + // If the time driver was woken, then the park completed + // before the "duration" elapsed (usually caused by a + // yield in `Runtime::block_on`). In this case, we don't + // advance the clock. + if !self.did_wake() { + // Simulate advancing time + clock.advance(duration); + } + } else { + self.park.park_timeout(duration)?; + } + + Ok(()) + } + + fn did_wake(&self) -> bool { + self.did_wake.swap(false, Ordering::SeqCst) + } + } + + cfg_not_test_util! { + fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> { + self.park.park_timeout(duration) + } + } } impl Handle { @@ -387,11 +417,11 @@ impl<P> Park for Driver<P> where P: Park + 'static, { - type Unpark = P::Unpark; + type Unpark = TimerUnpark<P>; type Error = P::Error; fn unpark(&self) -> Self::Unpark { - self.park.unpark() + TimerUnpark::new(self) } fn park(&mut self) -> Result<(), Self::Error> { @@ -426,6 +456,33 @@ where } } +pub(crate) struct TimerUnpark<P: Park + 'static> { + inner: P::Unpark, + + #[cfg(feature = "test-util")] + did_wake: Arc<AtomicBool>, +} + +impl<P: Park + 'static> TimerUnpark<P> { + fn new(driver: &Driver<P>) -> TimerUnpark<P> { + TimerUnpark { + inner: driver.park.unpark(), + + #[cfg(feature = "test-util")] + did_wake: driver.did_wake.clone(), + } + } +} + +impl<P: Park + 'static> Unpark for TimerUnpark<P> { + fn unpark(&self) { + #[cfg(feature = "test-util")] + self.did_wake.store(true, Ordering::SeqCst); + + self.inner.unpark(); + } +} + // ===== impl Inner ===== impl Inner { diff --git a/src/time/driver/wheel/mod.rs b/src/time/driver/wheel/mod.rs index 24bf517..5a40f6d 100644 --- a/src/time/driver/wheel/mod.rs +++ b/src/time/driver/wheel/mod.rs @@ -119,7 +119,7 @@ impl Wheel { pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) { unsafe { let when = item.as_ref().cached_when(); - if when == u64::max_value() { + if when == u64::MAX { self.pending.remove(item); } else { debug_assert!( diff --git a/src/util/linked_list.rs b/src/util/linked_list.rs index 480ea09..a74f562 100644 --- a/src/util/linked_list.rs +++ b/src/util/linked_list.rs @@ -50,6 +50,7 @@ pub(crate) unsafe trait Link { type Target; /// Convert the handle to a raw pointer without consuming the handle + #[allow(clippy::wrong_self_convention)] fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target>; /// Convert the raw pointer to a handle diff --git a/src/util/wake.rs b/src/util/wake.rs index 001577d..5773937 100644 --- a/src/util/wake.rs +++ b/src/util/wake.rs @@ -54,11 +54,7 @@ unsafe fn inc_ref_count<T: Wake>(data: *const ()) { let arc = ManuallyDrop::new(Arc::<T>::from_raw(data as *const T)); // Now increase refcount, but don't drop new refcount either - let arc_clone: ManuallyDrop<_> = arc.clone(); - - // Drop explicitly to avoid clippy warnings - drop(arc); - drop(arc_clone); + let _arc_clone: ManuallyDrop<_> = arc.clone(); } unsafe fn clone_arc_raw<T: Wake>(data: *const ()) -> RawWaker { |