aboutsummaryrefslogtreecommitdiff
path: root/src/io
diff options
context:
space:
mode:
Diffstat (limited to 'src/io')
-rw-r--r--src/io/async_read.rs162
-rw-r--r--src/io/async_seek.rs45
-rw-r--r--src/io/async_write.rs22
-rw-r--r--src/io/blocking.rs24
-rw-r--r--src/io/driver/mod.rs351
-rw-r--r--src/io/driver/ready.rs187
-rw-r--r--src/io/driver/scheduled_io.rs501
-rw-r--r--src/io/mod.rs38
-rw-r--r--src/io/poll_evented.rs337
-rw-r--r--src/io/read_buf.rs261
-rw-r--r--src/io/registration.rs286
-rw-r--r--src/io/seek.rs55
-rw-r--r--src/io/split.rs25
-rw-r--r--src/io/stderr.rs7
-rw-r--r--src/io/stdin.rs11
-rw-r--r--src/io/stdio_common.rs220
-rw-r--r--src/io/stdout.rs6
-rw-r--r--src/io/util/async_buf_read_ext.rs2
-rw-r--r--src/io/util/async_read_ext.rs14
-rw-r--r--src/io/util/async_seek_ext.rs108
-rw-r--r--src/io/util/async_write_ext.rs5
-rw-r--r--src/io/util/buf_reader.rs59
-rw-r--r--src/io/util/buf_stream.rs12
-rw-r--r--src/io/util/buf_writer.rs12
-rw-r--r--src/io/util/chain.rs24
-rw-r--r--src/io/util/copy.rs54
-rw-r--r--src/io/util/copy_buf.rs102
-rw-r--r--src/io/util/empty.rs11
-rw-r--r--src/io/util/flush.rs28
-rw-r--r--src/io/util/lines.rs3
-rw-r--r--src/io/util/mem.rs223
-rw-r--r--src/io/util/mod.rs13
-rw-r--r--src/io/util/read.rs34
-rw-r--r--src/io/util/read_buf.rs44
-rw-r--r--src/io/util/read_exact.rs46
-rw-r--r--src/io/util/read_int.rs48
-rw-r--r--src/io/util/read_line.rs86
-rw-r--r--src/io/util/read_to_end.rs153
-rw-r--r--src/io/util/read_to_string.rs96
-rw-r--r--src/io/util/read_until.rs34
-rw-r--r--src/io/util/repeat.rs16
-rw-r--r--src/io/util/shutdown.rs30
-rw-r--r--src/io/util/split.rs3
-rw-r--r--src/io/util/stream_reader.rs184
-rw-r--r--src/io/util/take.rs26
-rw-r--r--src/io/util/write.rs17
-rw-r--r--src/io/util/write_all.rs34
-rw-r--r--src/io/util/write_buf.rs25
-rw-r--r--src/io/util/write_int.rs18
49 files changed, 2374 insertions, 1728 deletions
diff --git a/src/io/async_read.rs b/src/io/async_read.rs
index 1aef415..d075443 100644
--- a/src/io/async_read.rs
+++ b/src/io/async_read.rs
@@ -1,6 +1,5 @@
-use bytes::BufMut;
+use super::ReadBuf;
use std::io;
-use std::mem::MaybeUninit;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -16,9 +15,10 @@ use std::task::{Context, Poll};
/// Specifically, this means that the `poll_read` function will return one of
/// the following:
///
-/// * `Poll::Ready(Ok(n))` means that `n` bytes of data was immediately read
-/// and placed into the output buffer, where `n` == 0 implies that EOF has
-/// been reached.
+/// * `Poll::Ready(Ok(()))` means that data was immediately read and placed into
+/// the output buffer. The amount of data read can be determined by the
+/// increase in the length of the slice returned by `ReadBuf::filled`. If the
+/// difference is 0, EOF has been reached.
///
/// * `Poll::Pending` means that no data was read into the buffer
/// provided. The I/O object is not currently readable but may become readable
@@ -41,110 +41,29 @@ use std::task::{Context, Poll};
/// [`Read::read`]: std::io::Read::read
/// [`AsyncReadExt`]: crate::io::AsyncReadExt
pub trait AsyncRead {
- /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns
- /// `true` if the supplied buffer was zeroed out.
- ///
- /// While it would be highly unusual, implementations of [`io::Read`] are
- /// able to read data from the buffer passed as an argument. Because of
- /// this, the buffer passed to [`io::Read`] must be initialized memory. In
- /// situations where large numbers of buffers are used, constantly having to
- /// zero out buffers can be expensive.
- ///
- /// This function does any necessary work to prepare an uninitialized buffer
- /// to be safe to pass to `read`. If `read` guarantees to never attempt to
- /// read data out of the supplied buffer, then `prepare_uninitialized_buffer`
- /// doesn't need to do any work.
- ///
- /// If this function returns `true`, then the memory has been zeroed out.
- /// This allows implementations of `AsyncRead` which are composed of
- /// multiple subimplementations to efficiently implement
- /// `prepare_uninitialized_buffer`.
- ///
- /// This function isn't actually `unsafe` to call but `unsafe` to implement.
- /// The implementer must ensure that either the whole `buf` has been zeroed
- /// or `poll_read_buf()` overwrites the buffer without reading it and returns
- /// correct value.
- ///
- /// This function is called from [`poll_read_buf`].
- ///
- /// # Safety
- ///
- /// Implementations that return `false` must never read from data slices
- /// that they did not write to.
- ///
- /// [`io::Read`]: std::io::Read
- /// [`poll_read_buf`]: method@Self::poll_read_buf
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- for x in buf {
- *x = MaybeUninit::new(0);
- }
-
- true
- }
-
/// Attempts to read from the `AsyncRead` into `buf`.
///
- /// On success, returns `Poll::Ready(Ok(num_bytes_read))`.
+ /// On success, returns `Poll::Ready(Ok(()))` and fills `buf` with data
+ /// read. If no data was read (`buf.filled().is_empty()`) it implies that
+ /// EOF has been reached.
///
- /// If no data is available for reading, the method returns
- /// `Poll::Pending` and arranges for the current task (via
- /// `cx.waker()`) to receive a notification when the object becomes
- /// readable or is closed.
+ /// If no data is available for reading, the method returns `Poll::Pending`
+ /// and arranges for the current task (via `cx.waker()`) to receive a
+ /// notification when the object becomes readable or is closed.
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>>;
-
- /// Pulls some bytes from this source into the specified `BufMut`, returning
- /// how many bytes were read.
- ///
- /// The `buf` provided will have bytes read into it and the internal cursor
- /// will be advanced if any bytes were read. Note that this method typically
- /// will not reallocate the buffer provided.
- fn poll_read_buf<B: BufMut>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>>
- where
- Self: Sized,
- {
- if !buf.has_remaining_mut() {
- return Poll::Ready(Ok(0));
- }
-
- unsafe {
- let n = {
- let b = buf.bytes_mut();
-
- self.prepare_uninitialized_buffer(b);
-
- // Convert to `&mut [u8]`
- let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]);
-
- let n = ready!(self.poll_read(cx, b))?;
- assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold");
- n
- };
-
- buf.advance_mut(n);
- Poll::Ready(Ok(n))
- }
- }
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>>;
}
macro_rules! deref_async_read {
() => {
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
@@ -163,43 +82,50 @@ where
P: DerefMut + Unpin,
P::Target: AsyncRead,
{
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- (**self).prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}
impl AsyncRead for &[u8] {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let amt = std::cmp::min(self.len(), buf.remaining());
+ let (a, b) = self.split_at(amt);
+ buf.put_slice(a);
+ *self = b;
+ Poll::Ready(Ok(()))
}
}
impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-
fn poll_read(
- self: Pin<&mut Self>,
+ mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(io::Read::read(self.get_mut(), buf))
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let pos = self.position();
+ let slice: &[u8] = (*self).get_ref().as_ref();
+
+ // The position could technically be out of bounds, so don't panic...
+ if pos > slice.len() as u64 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let start = pos as usize;
+ let amt = std::cmp::min(slice.len() - start, buf.remaining());
+ // Add won't overflow because of pos check above.
+ let end = start + amt;
+ buf.put_slice(&slice[start..end]);
+ self.set_position(end as u64);
+
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/async_seek.rs b/src/io/async_seek.rs
index 32ed0a2..bd7a992 100644
--- a/src/io/async_seek.rs
+++ b/src/io/async_seek.rs
@@ -23,36 +23,33 @@ pub trait AsyncSeek {
///
/// If this function returns successfully, then the job has been submitted.
/// To find out when it completes, call `poll_complete`.
- fn start_seek(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- position: SeekFrom,
- ) -> Poll<io::Result<()>>;
+ ///
+ /// # Errors
+ ///
+ /// This function can return [`io::ErrorKind::Other`] in case there is
+ /// another seek in progress. To avoid this, it is advisable that any call
+ /// to `start_seek` is preceded by a call to `poll_complete` to ensure all
+ /// pending seeks have completed.
+ fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()>;
/// Waits for a seek operation to complete.
///
/// If the seek operation completed successfully,
/// this method returns the new position from the start of the stream.
- /// That position can be used later with [`SeekFrom::Start`].
+ /// That position can be used later with [`SeekFrom::Start`]. Repeatedly
+ /// calling this function without calling `start_seek` might return the
+ /// same result.
///
/// # Errors
///
/// Seeking to a negative offset is considered an error.
- ///
- /// # Panics
- ///
- /// Calling this method without calling `start_seek` first is an error.
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>;
}
macro_rules! deref_async_seek {
() => {
- fn start_seek(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
- Pin::new(&mut **self).start_seek(cx, pos)
+ fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
+ Pin::new(&mut **self).start_seek(pos)
}
fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
@@ -74,12 +71,8 @@ where
P: DerefMut + Unpin,
P::Target: AsyncSeek,
{
- fn start_seek(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
- self.get_mut().as_mut().start_seek(cx, pos)
+ fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
+ self.get_mut().as_mut().start_seek(pos)
}
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
@@ -88,12 +81,8 @@ where
}
impl<T: AsRef<[u8]> + Unpin> AsyncSeek for io::Cursor<T> {
- fn start_seek(
- mut self: Pin<&mut Self>,
- _: &mut Context<'_>,
- pos: SeekFrom,
- ) -> Poll<io::Result<()>> {
- Poll::Ready(io::Seek::seek(&mut *self, pos).map(drop))
+ fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
+ io::Seek::seek(&mut *self, pos).map(drop)
}
fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> {
Poll::Ready(Ok(self.get_mut().position()))
diff --git a/src/io/async_write.rs b/src/io/async_write.rs
index ecf7575..66ba4bf 100644
--- a/src/io/async_write.rs
+++ b/src/io/async_write.rs
@@ -1,4 +1,3 @@
-use bytes::Buf;
use std::io;
use std::ops::DerefMut;
use std::pin::Pin;
@@ -128,27 +127,6 @@ pub trait AsyncWrite {
/// This function will panic if not called within the context of a future's
/// task.
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>;
-
- /// Writes a `Buf` into this value, returning how many bytes were written.
- ///
- /// Note that this method will advance the `buf` provided automatically by
- /// the number of bytes written.
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<Result<usize, io::Error>>
- where
- Self: Sized,
- {
- if !buf.has_remaining() {
- return Poll::Ready(Ok(0));
- }
-
- let n = ready!(self.poll_write(cx, buf.bytes()))?;
- buf.advance(n);
- Poll::Ready(Ok(n))
- }
}
macro_rules! deref_async_write {
diff --git a/src/io/blocking.rs b/src/io/blocking.rs
index 2491039..430801e 100644
--- a/src/io/blocking.rs
+++ b/src/io/blocking.rs
@@ -1,5 +1,5 @@
use crate::io::sys;
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::cmp;
use std::future::Future;
@@ -53,17 +53,17 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- dst: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ dst: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
loop {
match self.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
if !buf.is_empty() {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
*buf_cell = Some(buf);
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
buf.ensure_capacity_for(dst);
@@ -80,9 +80,9 @@ where
match res {
Ok(_) => {
- let n = buf.copy_to(dst);
+ buf.copy_to(dst);
self.state = Idle(Some(buf));
- return Ready(Ok(n));
+ return Ready(Ok(()));
}
Err(e) => {
assert!(buf.is_empty());
@@ -203,9 +203,9 @@ impl Buf {
self.buf.len() - self.pos
}
- pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize {
- let n = cmp::min(self.len(), dst.len());
- dst[..n].copy_from_slice(&self.bytes()[..n]);
+ pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
+ let n = cmp::min(self.len(), dst.remaining());
+ dst.put_slice(&self.bytes()[..n]);
self.pos += n;
if self.pos == self.buf.len() {
@@ -229,10 +229,10 @@ impl Buf {
&self.buf[self.pos..]
}
- pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) {
+ pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
assert!(self.is_empty());
- let len = cmp::min(bytes.len(), MAX_BUF);
+ let len = cmp::min(bytes.remaining(), MAX_BUF);
if self.buf.len() < len {
self.buf.reserve(len - self.buf.len());
diff --git a/src/io/driver/mod.rs b/src/io/driver/mod.rs
index dbfb6e1..cd82b26 100644
--- a/src/io/driver/mod.rs
+++ b/src/io/driver/mod.rs
@@ -1,30 +1,38 @@
-pub(crate) mod platform;
+#![cfg_attr(not(feature = "rt"), allow(dead_code))]
+
+mod ready;
+use ready::Ready;
mod scheduled_io;
pub(crate) use scheduled_io::ScheduledIo; // pub(crate) for tests
-use crate::loom::sync::atomic::AtomicUsize;
use crate::park::{Park, Unpark};
-use crate::runtime::context;
-use crate::util::slab::{Address, Slab};
+use crate::util::bit;
+use crate::util::slab::{self, Slab};
-use mio::event::Evented;
use std::fmt;
use std::io;
-use std::sync::atomic::Ordering::SeqCst;
use std::sync::{Arc, Weak};
-use std::task::Waker;
use std::time::Duration;
/// I/O driver, backed by Mio
pub(crate) struct Driver {
+ /// Tracks the number of times `turn` is called. It is safe for this to wrap
+ /// as it is mostly used to determine when to call `compact()`
+ tick: u8,
+
/// Reuse the `mio::Events` value across calls to poll.
- events: mio::Events,
+ events: Option<mio::Events>,
+
+ /// Primary slab handle containing the state for each resource registered
+ /// with this driver.
+ resources: Slab<ScheduledIo>,
+
+ /// The system event queue
+ poll: mio::Poll,
/// State shared between the reactor and the handles.
inner: Arc<Inner>,
-
- _wakeup_registration: mio::Registration,
}
/// A reference to an I/O driver
@@ -33,18 +41,20 @@ pub(crate) struct Handle {
inner: Weak<Inner>,
}
-pub(super) struct Inner {
- /// The underlying system event queue.
- io: mio::Poll,
+pub(crate) struct ReadyEvent {
+ tick: u8,
+ ready: Ready,
+}
- /// Dispatch slabs for I/O and futures events
- pub(super) io_dispatch: Slab<ScheduledIo>,
+pub(super) struct Inner {
+ /// Registers I/O resources
+ registry: mio::Registry,
- /// The number of sources in `io_dispatch`.
- n_sources: AtomicUsize,
+ /// Allocates `ScheduledIo` handles when creating new resources.
+ pub(super) io_dispatch: slab::Allocator<ScheduledIo>,
/// Used to wake up the reactor from a call to `turn`
- wakeup: mio::SetReadiness,
+ waker: mio::Waker,
}
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
@@ -53,7 +63,24 @@ pub(super) enum Direction {
Write,
}
-const TOKEN_WAKEUP: mio::Token = mio::Token(Address::NULL);
+enum Tick {
+ Set(u8),
+ Clear(u8),
+}
+
+// TODO: Don't use a fake token. Instead, reserve a slot entry for the wakeup
+// token.
+const TOKEN_WAKEUP: mio::Token = mio::Token(1 << 31);
+
+const ADDRESS: bit::Pack = bit::Pack::least_significant(24);
+
+// Packs the generation value in the `readiness` field.
+//
+// The generation prevents a race condition where a slab slot is reused for a
+// new socket while the I/O driver is about to apply a readiness event. The
+// generaton value is checked when setting new readiness. If the generation do
+// not match, then the readiness event is discarded.
+const GENERATION: bit::Pack = ADDRESS.then(7);
fn _assert_kinds() {
fn _assert<T: Send + Sync>() {}
@@ -67,24 +94,22 @@ impl Driver {
/// Creates a new event loop, returning any error that happened during the
/// creation.
pub(crate) fn new() -> io::Result<Driver> {
- let io = mio::Poll::new()?;
- let wakeup_pair = mio::Registration::new2();
+ let poll = mio::Poll::new()?;
+ let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?;
+ let registry = poll.registry().try_clone()?;
- io.register(
- &wakeup_pair.0,
- TOKEN_WAKEUP,
- mio::Ready::readable(),
- mio::PollOpt::level(),
- )?;
+ let slab = Slab::new();
+ let allocator = slab.allocator();
Ok(Driver {
- events: mio::Events::with_capacity(1024),
- _wakeup_registration: wakeup_pair.0,
+ tick: 0,
+ events: Some(mio::Events::with_capacity(1024)),
+ resources: slab,
+ poll,
inner: Arc::new(Inner {
- io,
- io_dispatch: Slab::new(),
- n_sources: AtomicUsize::new(0),
- wakeup: wakeup_pair.1,
+ registry,
+ io_dispatch: allocator,
+ waker,
}),
})
}
@@ -102,65 +127,66 @@ impl Driver {
}
fn turn(&mut self, max_wait: Option<Duration>) -> io::Result<()> {
+ // How often to call `compact()` on the resource slab
+ const COMPACT_INTERVAL: u8 = 255;
+
+ self.tick = self.tick.wrapping_add(1);
+
+ if self.tick == COMPACT_INTERVAL {
+ self.resources.compact();
+ }
+
+ let mut events = self.events.take().expect("i/o driver event store missing");
+
// Block waiting for an event to happen, peeling out how many events
// happened.
- match self.inner.io.poll(&mut self.events, max_wait) {
+ match self.poll.poll(&mut events, max_wait) {
Ok(_) => {}
+ Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
// Process all the events that came in, dispatching appropriately
-
- for event in self.events.iter() {
+ for event in events.iter() {
let token = event.token();
- if token == TOKEN_WAKEUP {
- self.inner
- .wakeup
- .set_readiness(mio::Ready::empty())
- .unwrap();
- } else {
- self.dispatch(token, event.readiness());
+ if token != TOKEN_WAKEUP {
+ self.dispatch(token, Ready::from_mio(event));
}
}
+ self.events = Some(events);
+
Ok(())
}
- fn dispatch(&self, token: mio::Token, ready: mio::Ready) {
- let mut rd = None;
- let mut wr = None;
+ fn dispatch(&mut self, token: mio::Token, ready: Ready) {
+ let addr = slab::Address::from_usize(ADDRESS.unpack(token.0));
- let address = Address::from_usize(token.0);
-
- let io = match self.inner.io_dispatch.get(address) {
+ let io = match self.resources.get(addr) {
Some(io) => io,
None => return,
};
- if io
- .set_readiness(address, |curr| curr | ready.as_usize())
- .is_err()
- {
+ let res = io.set_readiness(Some(token.0), Tick::Set(self.tick), |curr| curr | ready);
+
+ if res.is_err() {
// token no longer valid!
return;
}
- if ready.is_writable() || platform::is_hup(ready) || platform::is_error(ready) {
- wr = io.writer.take_waker();
- }
-
- if !(ready & (!mio::Ready::writable())).is_empty() {
- rd = io.reader.take_waker();
- }
-
- if let Some(w) = rd {
- w.wake();
- }
+ io.wake(ready);
+ }
+}
- if let Some(w) = wr {
- w.wake();
- }
+impl Drop for Driver {
+ fn drop(&mut self) {
+ self.resources.for_each(|io| {
+ // If a task is waiting on the I/O resource, notify it. The task
+ // will then attempt to use the I/O resource and fail due to the
+ // driver being shutdown.
+ io.wake(Ready::ALL);
+ })
}
}
@@ -181,6 +207,8 @@ impl Park for Driver {
self.turn(Some(duration))?;
Ok(())
}
+
+ fn shutdown(&mut self) {}
}
impl fmt::Debug for Driver {
@@ -191,17 +219,36 @@ impl fmt::Debug for Driver {
// ===== impl Handle =====
-impl Handle {
- /// Returns a handle to the current reactor
- ///
- /// # Panics
- ///
- /// This function panics if there is no current reactor set.
- pub(super) fn current() -> Self {
- context::io_handle()
- .expect("there is no reactor running, must be called from the context of Tokio runtime")
+cfg_rt! {
+ impl Handle {
+ /// Returns a handle to the current reactor
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current reactor set and `rt` feature
+ /// flag is not enabled.
+ pub(super) fn current() -> Self {
+ crate::runtime::context::io_handle()
+ .expect("there is no reactor running, must be called from the context of Tokio runtime")
+ }
+ }
+}
+
+cfg_not_rt! {
+ impl Handle {
+ /// Returns a handle to the current reactor
+ ///
+ /// # Panics
+ ///
+ /// This function panics if there is no current reactor set, or if the `rt`
+ /// feature flag is not enabled.
+ pub(super) fn current() -> Self {
+ panic!("there is no reactor running, must be called from the context of Tokio runtime with `rt` enabled.")
+ }
}
+}
+impl Handle {
/// Forces a reactor blocked in a call to `turn` to wakeup, or otherwise
/// makes the next call to `turn` return immediately.
///
@@ -213,7 +260,7 @@ impl Handle {
/// return immediately.
fn wakeup(&self) {
if let Some(inner) = self.inner() {
- inner.wakeup.set_readiness(mio::Ready::readable()).unwrap();
+ inner.waker.wake().expect("failed to wake I/O driver");
}
}
@@ -242,159 +289,35 @@ impl Inner {
/// The registration token is returned.
pub(super) fn add_source(
&self,
- source: &dyn Evented,
- ready: mio::Ready,
- ) -> io::Result<Address> {
- let address = self.io_dispatch.alloc().ok_or_else(|| {
+ source: &mut impl mio::event::Source,
+ interest: mio::Interest,
+ ) -> io::Result<slab::Ref<ScheduledIo>> {
+ let (address, shared) = self.io_dispatch.allocate().ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"reactor at max registered I/O resources",
)
})?;
- self.n_sources.fetch_add(1, SeqCst);
+ let token = GENERATION.pack(shared.generation(), ADDRESS.pack(address.as_usize(), 0));
- self.io.register(
- source,
- mio::Token(address.to_usize()),
- ready,
- mio::PollOpt::edge(),
- )?;
+ self.registry
+ .register(source, mio::Token(token), interest)?;
- Ok(address)
+ Ok(shared)
}
/// Deregisters an I/O resource from the reactor.
- pub(super) fn deregister_source(&self, source: &dyn Evented) -> io::Result<()> {
- self.io.deregister(source)
- }
-
- pub(super) fn drop_source(&self, address: Address) {
- self.io_dispatch.remove(address);
- self.n_sources.fetch_sub(1, SeqCst);
- }
-
- /// Registers interest in the I/O resource associated with `token`.
- pub(super) fn register(&self, token: Address, dir: Direction, w: Waker) {
- let sched = self
- .io_dispatch
- .get(token)
- .unwrap_or_else(|| panic!("IO resource for token {:?} does not exist!", token));
-
- let waker = match dir {
- Direction::Read => &sched.reader,
- Direction::Write => &sched.writer,
- };
-
- waker.register(w);
+ pub(super) fn deregister_source(&self, source: &mut impl mio::event::Source) -> io::Result<()> {
+ self.registry.deregister(source)
}
}
impl Direction {
- pub(super) fn mask(self) -> mio::Ready {
+ pub(super) fn mask(self) -> Ready {
match self {
- Direction::Read => {
- // Everything except writable is signaled through read.
- mio::Ready::all() - mio::Ready::writable()
- }
- Direction::Write => mio::Ready::writable() | platform::hup() | platform::error(),
- }
- }
-}
-
-#[cfg(all(test, loom))]
-mod tests {
- use super::*;
- use loom::thread;
-
- // No-op `Evented` impl just so we can have something to pass to `add_source`.
- struct NotEvented;
-
- impl Evented for NotEvented {
- fn register(
- &self,
- _: &mio::Poll,
- _: mio::Token,
- _: mio::Ready,
- _: mio::PollOpt,
- ) -> io::Result<()> {
- Ok(())
- }
-
- fn reregister(
- &self,
- _: &mio::Poll,
- _: mio::Token,
- _: mio::Ready,
- _: mio::PollOpt,
- ) -> io::Result<()> {
- Ok(())
- }
-
- fn deregister(&self, _: &mio::Poll) -> io::Result<()> {
- Ok(())
+ Direction::Read => Ready::READABLE | Ready::READ_CLOSED,
+ Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED,
}
}
-
- #[test]
- fn tokens_unique_when_dropped() {
- loom::model(|| {
- let reactor = Driver::new().unwrap();
- let inner = reactor.inner;
- let inner2 = inner.clone();
-
- let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- let thread = thread::spawn(move || {
- inner2.drop_source(token_1);
- });
-
- let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- thread.join().unwrap();
-
- assert!(token_1 != token_2);
- })
- }
-
- #[test]
- fn tokens_unique_when_dropped_on_full_page() {
- loom::model(|| {
- let reactor = Driver::new().unwrap();
- let inner = reactor.inner;
- let inner2 = inner.clone();
- // add sources to fill up the first page so that the dropped index
- // may be reused.
- for _ in 0..31 {
- inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- }
-
- let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- let thread = thread::spawn(move || {
- inner2.drop_source(token_1);
- });
-
- let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- thread.join().unwrap();
-
- assert!(token_1 != token_2);
- })
- }
-
- #[test]
- fn tokens_unique_concurrent_add() {
- loom::model(|| {
- let reactor = Driver::new().unwrap();
- let inner = reactor.inner;
- let inner2 = inner.clone();
-
- let thread = thread::spawn(move || {
- let token_2 = inner2.add_source(&NotEvented, mio::Ready::all()).unwrap();
- token_2
- });
-
- let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap();
- let token_2 = thread.join().unwrap();
-
- assert!(token_1 != token_2);
- })
- }
}
diff --git a/src/io/driver/ready.rs b/src/io/driver/ready.rs
new file mode 100644
index 0000000..8b556e9
--- /dev/null
+++ b/src/io/driver/ready.rs
@@ -0,0 +1,187 @@
+use std::fmt;
+use std::ops;
+
+const READABLE: usize = 0b0_01;
+const WRITABLE: usize = 0b0_10;
+const READ_CLOSED: usize = 0b0_0100;
+const WRITE_CLOSED: usize = 0b0_1000;
+
+/// A set of readiness event kinds.
+///
+/// `Ready` is set of operation descriptors indicating which kind of an
+/// operation is ready to be performed.
+///
+/// This struct only represents portable event kinds. Portable events are
+/// events that can be raised on any platform while guaranteeing no false
+/// positives.
+#[derive(Clone, Copy, PartialEq, PartialOrd)]
+pub(crate) struct Ready(usize);
+
+impl Ready {
+ /// Returns the empty `Ready` set.
+ pub(crate) const EMPTY: Ready = Ready(0);
+
+ /// Returns a `Ready` representing readable readiness.
+ pub(crate) const READABLE: Ready = Ready(READABLE);
+
+ /// Returns a `Ready` representing writable readiness.
+ pub(crate) const WRITABLE: Ready = Ready(WRITABLE);
+
+ /// Returns a `Ready` representing read closed readiness.
+ pub(crate) const READ_CLOSED: Ready = Ready(READ_CLOSED);
+
+ /// Returns a `Ready` representing write closed readiness.
+ pub(crate) const WRITE_CLOSED: Ready = Ready(WRITE_CLOSED);
+
+ /// Returns a `Ready` representing readiness for all operations.
+ pub(crate) const ALL: Ready = Ready(READABLE | WRITABLE | READ_CLOSED | WRITE_CLOSED);
+
+ pub(crate) fn from_mio(event: &mio::event::Event) -> Ready {
+ let mut ready = Ready::EMPTY;
+
+ if event.is_readable() {
+ ready |= Ready::READABLE;
+ }
+
+ if event.is_writable() {
+ ready |= Ready::WRITABLE;
+ }
+
+ if event.is_read_closed() {
+ ready |= Ready::READ_CLOSED;
+ }
+
+ if event.is_write_closed() {
+ ready |= Ready::WRITE_CLOSED;
+ }
+
+ ready
+ }
+
+ /// Returns true if `Ready` is the empty set
+ pub(crate) fn is_empty(self) -> bool {
+ self == Ready::EMPTY
+ }
+
+ /// Returns true if the value includes readable readiness
+ pub(crate) fn is_readable(self) -> bool {
+ self.contains(Ready::READABLE) || self.is_read_closed()
+ }
+
+ /// Returns true if the value includes writable readiness
+ pub(crate) fn is_writable(self) -> bool {
+ self.contains(Ready::WRITABLE) || self.is_write_closed()
+ }
+
+ /// Returns true if the value includes read closed readiness
+ pub(crate) fn is_read_closed(self) -> bool {
+ self.contains(Ready::READ_CLOSED)
+ }
+
+ /// Returns true if the value includes write closed readiness
+ pub(crate) fn is_write_closed(self) -> bool {
+ self.contains(Ready::WRITE_CLOSED)
+ }
+
+ /// Returns true if `self` is a superset of `other`.
+ ///
+ /// `other` may represent more than one readiness operations, in which case
+ /// the function only returns true if `self` contains all readiness
+ /// specified in `other`.
+ pub(crate) fn contains<T: Into<Self>>(self, other: T) -> bool {
+ let other = other.into();
+ (self & other) == other
+ }
+
+ /// Create a `Ready` instance using the given `usize` representation.
+ ///
+ /// The `usize` representation must have been obtained from a call to
+ /// `Readiness::as_usize`.
+ ///
+ /// This function is mainly provided to allow the caller to get a
+ /// readiness value from an `AtomicUsize`.
+ pub(crate) fn from_usize(val: usize) -> Ready {
+ Ready(val & Ready::ALL.as_usize())
+ }
+
+ /// Returns a `usize` representation of the `Ready` value.
+ ///
+ /// This function is mainly provided to allow the caller to store a
+ /// readiness value in an `AtomicUsize`.
+ pub(crate) fn as_usize(self) -> usize {
+ self.0
+ }
+}
+
+cfg_io_readiness! {
+ impl Ready {
+ pub(crate) fn from_interest(interest: mio::Interest) -> Ready {
+ let mut ready = Ready::EMPTY;
+
+ if interest.is_readable() {
+ ready |= Ready::READABLE;
+ ready |= Ready::READ_CLOSED;
+ }
+
+ if interest.is_writable() {
+ ready |= Ready::WRITABLE;
+ ready |= Ready::WRITE_CLOSED;
+ }
+
+ ready
+ }
+
+ pub(crate) fn intersection(self, interest: mio::Interest) -> Ready {
+ Ready(self.0 & Ready::from_interest(interest).0)
+ }
+
+ pub(crate) fn satisfies(self, interest: mio::Interest) -> bool {
+ self.0 & Ready::from_interest(interest).0 != 0
+ }
+ }
+}
+
+impl<T: Into<Ready>> ops::BitOr<T> for Ready {
+ type Output = Ready;
+
+ #[inline]
+ fn bitor(self, other: T) -> Ready {
+ Ready(self.0 | other.into().0)
+ }
+}
+
+impl<T: Into<Ready>> ops::BitOrAssign<T> for Ready {
+ #[inline]
+ fn bitor_assign(&mut self, other: T) {
+ self.0 |= other.into().0;
+ }
+}
+
+impl<T: Into<Ready>> ops::BitAnd<T> for Ready {
+ type Output = Ready;
+
+ #[inline]
+ fn bitand(self, other: T) -> Ready {
+ Ready(self.0 & other.into().0)
+ }
+}
+
+impl<T: Into<Ready>> ops::Sub<T> for Ready {
+ type Output = Ready;
+
+ #[inline]
+ fn sub(self, other: T) -> Ready {
+ Ready(self.0 & !other.into().0)
+ }
+}
+
+impl fmt::Debug for Ready {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Ready")
+ .field("is_readable", &self.is_readable())
+ .field("is_writable", &self.is_writable())
+ .field("is_read_closed", &self.is_read_closed())
+ .field("is_write_closed", &self.is_write_closed())
+ .finish()
+ }
+}
diff --git a/src/io/driver/scheduled_io.rs b/src/io/driver/scheduled_io.rs
index 7f6446e..b1354a0 100644
--- a/src/io/driver/scheduled_io.rs
+++ b/src/io/driver/scheduled_io.rs
@@ -1,47 +1,109 @@
-use crate::loom::future::AtomicWaker;
+use super::{Direction, Ready, ReadyEvent, Tick};
use crate::loom::sync::atomic::AtomicUsize;
+use crate::loom::sync::Mutex;
use crate::util::bit;
-use crate::util::slab::{Address, Entry, Generation};
+use crate::util::slab::Entry;
-use std::sync::atomic::Ordering::{AcqRel, Acquire, SeqCst};
+use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
+use std::task::{Context, Poll, Waker};
+cfg_io_readiness! {
+ use crate::util::linked_list::{self, LinkedList};
+
+ use std::cell::UnsafeCell;
+ use std::future::Future;
+ use std::marker::PhantomPinned;
+ use std::pin::Pin;
+ use std::ptr::NonNull;
+}
+
+/// Stored in the I/O driver resource slab.
#[derive(Debug)]
pub(crate) struct ScheduledIo {
+ /// Packs the resource's readiness with the resource's generation.
readiness: AtomicUsize,
- pub(crate) reader: AtomicWaker,
- pub(crate) writer: AtomicWaker,
+
+ waiters: Mutex<Waiters>,
}
-const PACK: bit::Pack = bit::Pack::most_significant(Generation::WIDTH);
+cfg_io_readiness! {
+ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
+}
-impl Entry for ScheduledIo {
- fn generation(&self) -> Generation {
- unpack_generation(self.readiness.load(SeqCst))
+#[derive(Debug, Default)]
+struct Waiters {
+ #[cfg(feature = "net")]
+ /// List of all current waiters
+ list: WaitList,
+
+ /// Waker used for AsyncRead
+ reader: Option<Waker>,
+
+ /// Waker used for AsyncWrite
+ writer: Option<Waker>,
+}
+
+cfg_io_readiness! {
+ #[derive(Debug)]
+ struct Waiter {
+ pointers: linked_list::Pointers<Waiter>,
+
+ /// The waker for this task
+ waker: Option<Waker>,
+
+ /// The interest this waiter is waiting on
+ interest: mio::Interest,
+
+ is_ready: bool,
+
+ /// Should never be `!Unpin`
+ _p: PhantomPinned,
}
- fn reset(&self, generation: Generation) -> bool {
- let mut current = self.readiness.load(Acquire);
+ /// Future returned by `readiness()`
+ struct Readiness<'a> {
+ scheduled_io: &'a ScheduledIo,
- loop {
- if unpack_generation(current) != generation {
- return false;
- }
+ state: State,
- let next = PACK.pack(generation.next().to_usize(), 0);
+ /// Entry in the waiter `LinkedList`.
+ waiter: UnsafeCell<Waiter>,
+ }
- match self
- .readiness
- .compare_exchange(current, next, AcqRel, Acquire)
- {
- Ok(_) => break,
- Err(actual) => current = actual,
- }
- }
+ enum State {
+ Init,
+ Waiting,
+ Done,
+ }
+}
+
+// The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness.
+//
+// | reserved | generation | driver tick | readinesss |
+// |----------+------------+--------------+------------|
+// | 1 bit | 7 bits + 8 bits + 16 bits |
+
+const READINESS: bit::Pack = bit::Pack::least_significant(16);
+
+const TICK: bit::Pack = READINESS.then(8);
- drop(self.reader.take_waker());
- drop(self.writer.take_waker());
+const GENERATION: bit::Pack = TICK.then(7);
- true
+#[test]
+fn test_generations_assert_same() {
+ assert_eq!(super::GENERATION, GENERATION);
+}
+
+// ===== impl ScheduledIo =====
+
+impl Entry for ScheduledIo {
+ fn reset(&self) {
+ let state = self.readiness.load(Acquire);
+
+ let generation = GENERATION.unpack(state);
+ let next = GENERATION.pack_lossy(generation + 1, 0);
+
+ self.readiness.store(next, Release);
}
}
@@ -49,31 +111,14 @@ impl Default for ScheduledIo {
fn default() -> ScheduledIo {
ScheduledIo {
readiness: AtomicUsize::new(0),
- reader: AtomicWaker::new(),
- writer: AtomicWaker::new(),
+ waiters: Mutex::new(Default::default()),
}
}
}
impl ScheduledIo {
- #[cfg(all(test, loom))]
- /// Returns the current readiness value of this `ScheduledIo`, if the
- /// provided `token` is still a valid access.
- ///
- /// # Returns
- ///
- /// If the given token's generation no longer matches the `ScheduledIo`'s
- /// generation, then the corresponding IO resource has been removed and
- /// replaced with a new resource. In that case, this method returns `None`.
- /// Otherwise, this returns the current readiness.
- pub(crate) fn get_readiness(&self, address: Address) -> Option<usize> {
- let ready = self.readiness.load(Acquire);
-
- if unpack_generation(ready) != address.generation() {
- return None;
- }
-
- Some(ready & !PACK.mask())
+ pub(crate) fn generation(&self) -> usize {
+ GENERATION.unpack(self.readiness.load(Acquire))
}
/// Sets the readiness on this `ScheduledIo` by invoking the given closure on
@@ -81,6 +126,8 @@ impl ScheduledIo {
///
/// # Arguments
/// - `token`: the token for this `ScheduledIo`.
+ /// - `tick`: whether setting the tick or trying to clear readiness for a
+ /// specific tick.
/// - `f`: a closure returning a new readiness value given the previous
/// readiness.
///
@@ -90,52 +137,354 @@ impl ScheduledIo {
/// generation, then the corresponding IO resource has been removed and
/// replaced with a new resource. In that case, this method returns `Err`.
/// Otherwise, this returns the previous readiness.
- pub(crate) fn set_readiness(
+ pub(super) fn set_readiness(
&self,
- address: Address,
- f: impl Fn(usize) -> usize,
- ) -> Result<usize, ()> {
- let generation = address.generation();
-
+ token: Option<usize>,
+ tick: Tick,
+ f: impl Fn(Ready) -> Ready,
+ ) -> Result<(), ()> {
let mut current = self.readiness.load(Acquire);
loop {
- // Check that the generation for this access is still the current
- // one.
- if unpack_generation(current) != generation {
- return Err(());
+ let current_generation = GENERATION.unpack(current);
+
+ if let Some(token) = token {
+ // Check that the generation for this access is still the
+ // current one.
+ if GENERATION.unpack(token) != current_generation {
+ return Err(());
+ }
}
- // Mask out the generation bits so that the modifying function
- // doesn't see them.
- let current_readiness = current & mio::Ready::all().as_usize();
+
+ // Mask out the tick/generation bits so that the modifying
+ // function doesn't see them.
+ let current_readiness = Ready::from_usize(current);
let new = f(current_readiness);
- debug_assert!(
- new <= !PACK.max_value(),
- "new readiness value would overwrite generation bits!"
- );
-
- match self.readiness.compare_exchange(
- current,
- PACK.pack(generation.to_usize(), new),
- AcqRel,
- Acquire,
- ) {
- Ok(_) => return Ok(current),
+ let packed = match tick {
+ Tick::Set(t) => TICK.pack(t as usize, new.as_usize()),
+ Tick::Clear(t) => {
+ if TICK.unpack(current) as u8 != t {
+ // Trying to clear readiness with an old event!
+ return Err(());
+ }
+
+ TICK.pack(t as usize, new.as_usize())
+ }
+ };
+
+ let next = GENERATION.pack(current_generation, packed);
+
+ match self
+ .readiness
+ .compare_exchange(current, next, AcqRel, Acquire)
+ {
+ Ok(_) => return Ok(()),
// we lost the race, retry!
Err(actual) => current = actual,
}
}
}
+
+ /// Notifies all pending waiters that have registered interest in `ready`.
+ ///
+ /// There may be many waiters to notify. Waking the pending task **must** be
+ /// done from outside of the lock otherwise there is a potential for a
+ /// deadlock.
+ ///
+ /// A stack array of wakers is created and filled with wakers to notify, the
+ /// lock is released, and the wakers are notified. Because there may be more
+ /// than 32 wakers to notify, if the stack array fills up, the lock is
+ /// released, the array is cleared, and the iteration continues.
+ pub(super) fn wake(&self, ready: Ready) {
+ const NUM_WAKERS: usize = 32;
+
+ let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default();
+ let mut curr = 0;
+
+ let mut waiters = self.waiters.lock();
+
+ // check for AsyncRead slot
+ if ready.is_readable() {
+ if let Some(waker) = waiters.reader.take() {
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+
+ // check for AsyncWrite slot
+ if ready.is_writable() {
+ if let Some(waker) = waiters.writer.take() {
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+
+ #[cfg(feature = "net")]
+ 'outer: loop {
+ let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest));
+
+ while curr < NUM_WAKERS {
+ match iter.next() {
+ Some(waiter) => {
+ let waiter = unsafe { &mut *waiter.as_ptr() };
+
+ if let Some(waker) = waiter.waker.take() {
+ waiter.is_ready = true;
+ wakers[curr] = Some(waker);
+ curr += 1;
+ }
+ }
+ None => {
+ break 'outer;
+ }
+ }
+ }
+
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr) {
+ waker.take().unwrap().wake();
+ }
+
+ curr = 0;
+
+ // Acquire the lock again.
+ waiters = self.waiters.lock();
+ }
+
+ // Release the lock before notifying
+ drop(waiters);
+
+ for waker in wakers.iter_mut().take(curr) {
+ waker.take().unwrap().wake();
+ }
+ }
+
+ /// Poll version of checking readiness for a certain direction.
+ ///
+ /// These are to support `AsyncRead` and `AsyncWrite` polling methods,
+ /// which cannot use the `async fn` version. This uses reserved reader
+ /// and writer slots.
+ pub(in crate::io) fn poll_readiness(
+ &self,
+ cx: &mut Context<'_>,
+ direction: Direction,
+ ) -> Poll<ReadyEvent> {
+ let curr = self.readiness.load(Acquire);
+
+ let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr));
+
+ if ready.is_empty() {
+ // Update the task info
+ let mut waiters = self.waiters.lock();
+ let slot = match direction {
+ Direction::Read => &mut waiters.reader,
+ Direction::Write => &mut waiters.writer,
+ };
+ *slot = Some(cx.waker().clone());
+
+ // Try again, in case the readiness was changed while we were
+ // taking the waiters lock
+ let curr = self.readiness.load(Acquire);
+ let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr));
+ if ready.is_empty() {
+ Poll::Pending
+ } else {
+ Poll::Ready(ReadyEvent {
+ tick: TICK.unpack(curr) as u8,
+ ready,
+ })
+ }
+ } else {
+ Poll::Ready(ReadyEvent {
+ tick: TICK.unpack(curr) as u8,
+ ready,
+ })
+ }
+ }
+
+ pub(crate) fn clear_readiness(&self, event: ReadyEvent) {
+ // This consumes the current readiness state **except** for closed
+ // states. Closed states are excluded because they are final states.
+ let mask_no_closed = event.ready - Ready::READ_CLOSED - Ready::WRITE_CLOSED;
+
+ // result isn't important
+ let _ = self.set_readiness(None, Tick::Clear(event.tick), |curr| curr - mask_no_closed);
+ }
}
impl Drop for ScheduledIo {
fn drop(&mut self) {
- self.writer.wake();
- self.reader.wake();
+ self.wake(Ready::ALL);
}
}
-fn unpack_generation(src: usize) -> Generation {
- Generation::new(PACK.unpack(src))
+unsafe impl Send for ScheduledIo {}
+unsafe impl Sync for ScheduledIo {}
+
+cfg_io_readiness! {
+ impl ScheduledIo {
+ /// An async version of `poll_readiness` which uses a linked list of wakers
+ pub(crate) async fn readiness(&self, interest: mio::Interest) -> ReadyEvent {
+ self.readiness_fut(interest).await
+ }
+
+ // This is in a separate function so that the borrow checker doesn't think
+ // we are borrowing the `UnsafeCell` possibly over await boundaries.
+ //
+ // Go figure.
+ fn readiness_fut(&self, interest: mio::Interest) -> Readiness<'_> {
+ Readiness {
+ scheduled_io: self,
+ state: State::Init,
+ waiter: UnsafeCell::new(Waiter {
+ pointers: linked_list::Pointers::new(),
+ waker: None,
+ is_ready: false,
+ interest,
+ _p: PhantomPinned,
+ }),
+ }
+ }
+ }
+
+ unsafe impl linked_list::Link for Waiter {
+ type Handle = NonNull<Waiter>;
+ type Target = Waiter;
+
+ fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
+ *handle
+ }
+
+ unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
+ ptr
+ }
+
+ unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
+ NonNull::from(&mut target.as_mut().pointers)
+ }
+ }
+
+ // ===== impl Readiness =====
+
+ impl Future for Readiness<'_> {
+ type Output = ReadyEvent;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ use std::sync::atomic::Ordering::SeqCst;
+
+ let (scheduled_io, state, waiter) = unsafe {
+ let me = self.get_unchecked_mut();
+ (&me.scheduled_io, &mut me.state, &me.waiter)
+ };
+
+ loop {
+ match *state {
+ State::Init => {
+ // Optimistically check existing readiness
+ let curr = scheduled_io.readiness.load(SeqCst);
+ let ready = Ready::from_usize(READINESS.unpack(curr));
+
+ // Safety: `waiter.interest` never changes
+ let interest = unsafe { (*waiter.get()).interest };
+ let ready = ready.intersection(interest);
+
+ if !ready.is_empty() {
+ // Currently ready!
+ let tick = TICK.unpack(curr) as u8;
+ *state = State::Done;
+ return Poll::Ready(ReadyEvent { ready, tick });
+ }
+
+ // Wasn't ready, take the lock (and check again while locked).
+ let mut waiters = scheduled_io.waiters.lock();
+
+ let curr = scheduled_io.readiness.load(SeqCst);
+ let ready = Ready::from_usize(READINESS.unpack(curr));
+ let ready = ready.intersection(interest);
+
+ if !ready.is_empty() {
+ // Currently ready!
+ let tick = TICK.unpack(curr) as u8;
+ *state = State::Done;
+ return Poll::Ready(ReadyEvent { ready, tick });
+ }
+
+ // Not ready even after locked, insert into list...
+
+ // Safety: called while locked
+ unsafe {
+ (*waiter.get()).waker = Some(cx.waker().clone());
+ }
+
+ // Insert the waiter into the linked list
+ //
+ // safety: pointers from `UnsafeCell` are never null.
+ waiters
+ .list
+ .push_front(unsafe { NonNull::new_unchecked(waiter.get()) });
+ *state = State::Waiting;
+ }
+ State::Waiting => {
+ // Currently in the "Waiting" state, implying the caller has
+ // a waiter stored in the waiter list (guarded by
+ // `notify.waiters`). In order to access the waker fields,
+ // we must hold the lock.
+
+ let waiters = scheduled_io.waiters.lock();
+
+ // Safety: called while locked
+ let w = unsafe { &mut *waiter.get() };
+
+ if w.is_ready {
+ // Our waker has been notified.
+ *state = State::Done;
+ } else {
+ // Update the waker, if necessary.
+ if !w.waker.as_ref().unwrap().will_wake(cx.waker()) {
+ w.waker = Some(cx.waker().clone());
+ }
+
+ return Poll::Pending;
+ }
+
+ // Explicit drop of the lock to indicate the scope that the
+ // lock is held. Because holding the lock is required to
+ // ensure safe access to fields not held within the lock, it
+ // is helpful to visualize the scope of the critical
+ // section.
+ drop(waiters);
+ }
+ State::Done => {
+ let tick = TICK.unpack(scheduled_io.readiness.load(Acquire)) as u8;
+
+ // Safety: State::Done means it is no longer shared
+ let w = unsafe { &mut *waiter.get() };
+
+ return Poll::Ready(ReadyEvent {
+ tick,
+ ready: Ready::from_interest(w.interest),
+ });
+ }
+ }
+ }
+ }
+ }
+
+ impl Drop for Readiness<'_> {
+ fn drop(&mut self) {
+ let mut waiters = self.scheduled_io.waiters.lock();
+
+ // Safety: `waiter` is only ever stored in `waiters`
+ unsafe {
+ waiters
+ .list
+ .remove(NonNull::new_unchecked(self.waiter.get()))
+ };
+ }
+ }
+
+ unsafe impl Send for Readiness<'_> {}
+ unsafe impl Sync for Readiness<'_> {}
}
diff --git a/src/io/mod.rs b/src/io/mod.rs
index 7b00556..9191bbc 100644
--- a/src/io/mod.rs
+++ b/src/io/mod.rs
@@ -162,8 +162,8 @@
//!
//! # `std` re-exports
//!
-//! Additionally, [`Error`], [`ErrorKind`], and [`Result`] are re-exported
-//! from `std::io` for ease of use.
+//! Additionally, [`Error`], [`ErrorKind`], [`Result`], and [`SeekFrom`] are
+//! re-exported from `std::io` for ease of use.
//!
//! [`AsyncRead`]: trait@AsyncRead
//! [`AsyncWrite`]: trait@AsyncWrite
@@ -176,6 +176,7 @@
//! [`ErrorKind`]: enum@ErrorKind
//! [`Result`]: type@Result
//! [`Read`]: std::io::Read
+//! [`SeekFrom`]: enum@SeekFrom
//! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html
//! [`Stream`]: crate::stream::Stream
//! [`Write`]: std::io::Write
@@ -187,7 +188,6 @@ mod async_buf_read;
pub use self::async_buf_read::AsyncBufRead;
mod async_read;
-
pub use self::async_read::AsyncRead;
mod async_seek;
@@ -196,17 +196,27 @@ pub use self::async_seek::AsyncSeek;
mod async_write;
pub use self::async_write::AsyncWrite;
+mod read_buf;
+pub use self::read_buf::ReadBuf;
+
+// Re-export some types from `std::io` so that users don't have to deal
+// with conflicts when `use`ing `tokio::io` and `std::io`.
+#[doc(no_inline)]
+pub use std::io::{Error, ErrorKind, Result, SeekFrom};
+
cfg_io_driver! {
pub(crate) mod driver;
mod poll_evented;
- pub use poll_evented::PollEvented;
+ #[cfg(not(loom))]
+ pub(crate) use poll_evented::PollEvented;
mod registration;
- pub use registration::Registration;
}
cfg_io_std! {
+ mod stdio_common;
+
mod stderr;
pub use stderr::{stderr, Stderr};
@@ -222,21 +232,11 @@ cfg_io_util! {
pub use split::{split, ReadHalf, WriteHalf};
pub(crate) mod seek;
- pub use self::seek::Seek;
-
pub(crate) mod util;
pub use util::{
- copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
- BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take,
+ copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
+ BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take,
};
-
- cfg_stream! {
- pub use util::{stream_reader, StreamReader};
- }
-
- // Re-export io::Error so that users don't have to deal with conflicts when
- // `use`ing `tokio::io` and `std::io`.
- pub use std::io::{Error, ErrorKind, Result};
}
cfg_not_io_util! {
@@ -249,7 +249,7 @@ cfg_io_blocking! {
/// Types in this module can be mocked out in tests.
mod sys {
// TODO: don't rename
- pub(crate) use crate::runtime::spawn_blocking as run;
- pub(crate) use crate::task::JoinHandle as Blocking;
+ pub(crate) use crate::blocking::spawn_blocking as run;
+ pub(crate) use crate::blocking::JoinHandle as Blocking;
}
}
diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs
index 5295bd7..66a2634 100644
--- a/src/io/poll_evented.rs
+++ b/src/io/poll_evented.rs
@@ -1,13 +1,12 @@
-use crate::io::driver::platform;
-use crate::io::{AsyncRead, AsyncWrite, Registration};
+use crate::io::driver::{Direction, Handle, ReadyEvent};
+use crate::io::registration::Registration;
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
-use mio::event::Evented;
+use mio::event::Source;
use std::fmt;
use std::io::{self, Read, Write};
use std::marker::Unpin;
use std::pin::Pin;
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering::Relaxed;
use std::task::{Context, Poll};
cfg_io_driver! {
@@ -53,37 +52,6 @@ cfg_io_driver! {
/// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and
/// [`clear_read_ready`].
///
- /// ```rust
- /// use tokio::io::PollEvented;
- ///
- /// use futures::ready;
- /// use mio::Ready;
- /// use mio::net::{TcpStream, TcpListener};
- /// use std::io;
- /// use std::task::{Context, Poll};
- ///
- /// struct MyListener {
- /// poll_evented: PollEvented<TcpListener>,
- /// }
- ///
- /// impl MyListener {
- /// pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<TcpStream, io::Error>> {
- /// let ready = Ready::readable();
- ///
- /// ready!(self.poll_evented.poll_read_ready(cx, ready))?;
- ///
- /// match self.poll_evented.get_ref().accept() {
- /// Ok((socket, _)) => Poll::Ready(Ok(socket)),
- /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- /// self.poll_evented.clear_read_ready(cx, ready)?;
- /// Poll::Pending
- /// }
- /// Err(e) => Poll::Ready(Err(e)),
- /// }
- /// }
- /// }
- /// ```
- ///
/// ## Platform-specific events
///
/// `PollEvented` also allows receiving platform-specific `mio::Ready` events.
@@ -101,70 +69,15 @@ cfg_io_driver! {
/// [`clear_write_ready`]: method@Self::clear_write_ready
/// [`poll_read_ready`]: method@Self::poll_read_ready
/// [`poll_write_ready`]: method@Self::poll_write_ready
- pub struct PollEvented<E: Evented> {
+ pub(crate) struct PollEvented<E: Source> {
io: Option<E>,
- inner: Inner,
+ registration: Registration,
}
}
-struct Inner {
- registration: Registration,
-
- /// Currently visible read readiness
- read_readiness: AtomicUsize,
-
- /// Currently visible write readiness
- write_readiness: AtomicUsize,
-}
-
// ===== impl PollEvented =====
-macro_rules! poll_ready {
- ($me:expr, $mask:expr, $cache:ident, $take:ident, $poll:expr) => {{
- // Load cached & encoded readiness.
- let mut cached = $me.inner.$cache.load(Relaxed);
- let mask = $mask | platform::hup() | platform::error();
-
- // See if the current readiness matches any bits.
- let mut ret = mio::Ready::from_usize(cached) & $mask;
-
- if ret.is_empty() {
- // Readiness does not match, consume the registration's readiness
- // stream. This happens in a loop to ensure that the stream gets
- // drained.
- loop {
- let ready = match $poll? {
- Poll::Ready(v) => v,
- Poll::Pending => return Poll::Pending,
- };
- cached |= ready.as_usize();
-
- // Update the cache store
- $me.inner.$cache.store(cached, Relaxed);
-
- ret |= ready & mask;
-
- if !ret.is_empty() {
- return Poll::Ready(Ok(ret));
- }
- }
- } else {
- // Check what's new with the registration stream. This will not
- // request to be notified
- if let Some(ready) = $me.inner.registration.$take()? {
- cached |= ready.as_usize();
- $me.inner.$cache.store(cached, Relaxed);
- }
-
- Poll::Ready(Ok(mio::Ready::from_usize(cached)))
- }
- }};
-}
-
-impl<E> PollEvented<E>
-where
- E: Evented,
-{
+impl<E: Source> PollEvented<E> {
/// Creates a new `PollEvented` associated with the default reactor.
///
/// # Panics
@@ -173,71 +86,57 @@ where
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new(io: E) -> io::Result<Self> {
- PollEvented::new_with_ready(io, mio::Ready::all())
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ #[cfg_attr(feature = "signal", allow(unused))]
+ pub(crate) fn new(io: E) -> io::Result<Self> {
+ PollEvented::new_with_interest(io, mio::Interest::READABLE | mio::Interest::WRITABLE)
}
- /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Ready`
- /// state. `new_with_ready` should be used over `new` when you need control over the readiness
+ /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Interest`
+ /// state. `new_with_interest` should be used over `new` when you need control over the readiness
/// state, such as when a file descriptor only allows reads. This does not add `hup` or `error`
/// so if you are interested in those states, you will need to add them to the readiness state
/// passed to this function.
///
- /// An example to listen to read only
- ///
- /// ```rust
- /// ##[cfg(unix)]
- /// mio::Ready::from_usize(
- /// mio::Ready::readable().as_usize()
- /// | mio::unix::UnixReady::error().as_usize()
- /// | mio::unix::UnixReady::hup().as_usize()
- /// );
- /// ```
- ///
/// # Panics
///
/// This function panics if thread-local runtime is not set.
///
/// The runtime is usually set implicitly when this function is called
/// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new_with_ready(io: E, ready: mio::Ready) -> io::Result<Self> {
- let registration = Registration::new_with_ready(&io, ready)?;
+ /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function.
+ #[cfg_attr(feature = "signal", allow(unused))]
+ pub(crate) fn new_with_interest(io: E, interest: mio::Interest) -> io::Result<Self> {
+ Self::new_with_interest_and_handle(io, interest, Handle::current())
+ }
+
+ pub(crate) fn new_with_interest_and_handle(
+ mut io: E,
+ interest: mio::Interest,
+ handle: Handle,
+ ) -> io::Result<Self> {
+ let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?;
Ok(Self {
io: Some(io),
- inner: Inner {
- registration,
- read_readiness: AtomicUsize::new(0),
- write_readiness: AtomicUsize::new(0),
- },
+ registration,
})
}
/// Returns a shared reference to the underlying I/O object this readiness
/// stream is wrapping.
- pub fn get_ref(&self) -> &E {
+ #[cfg(any(feature = "net", feature = "process", feature = "signal"))]
+ pub(crate) fn get_ref(&self) -> &E {
self.io.as_ref().unwrap()
}
/// Returns a mutable reference to the underlying I/O object this readiness
/// stream is wrapping.
- pub fn get_mut(&mut self) -> &mut E {
+ pub(crate) fn get_mut(&mut self) -> &mut E {
self.io.as_mut().unwrap()
}
- /// Consumes self, returning the inner I/O object
- ///
- /// This function will deregister the I/O resource from the reactor before
- /// returning. If the deregistration operation fails, an error is returned.
- ///
- /// Note that deregistering does not guarantee that the I/O resource can be
- /// registered with a different reactor. Some I/O resource types can only be
- /// associated with a single reactor instance for their lifetime.
- pub fn into_inner(mut self) -> io::Result<E> {
- let io = self.io.take().unwrap();
- self.inner.registration.deregister(&io)?;
- Ok(io)
+ pub(crate) fn clear_readiness(&self, event: ReadyEvent) {
+ self.registration.clear_readiness(event);
}
/// Checks the I/O resource's read readiness state.
@@ -266,51 +165,8 @@ where
///
/// This method may not be called concurrently. It takes `&self` to allow
/// calling it concurrently with `poll_write_ready`.
- pub fn poll_read_ready(
- &self,
- cx: &mut Context<'_>,
- mask: mio::Ready,
- ) -> Poll<io::Result<mio::Ready>> {
- assert!(!mask.is_writable(), "cannot poll for write readiness");
- poll_ready!(
- self,
- mask,
- read_readiness,
- take_read_ready,
- self.inner.registration.poll_read_ready(cx)
- )
- }
-
- /// Clears the I/O resource's read readiness state and registers the current
- /// task to be notified once a read readiness event is received.
- ///
- /// After calling this function, `poll_read_ready` will return
- /// `Poll::Pending` until a new read readiness event has been received.
- ///
- /// The `mask` argument specifies the readiness bits to clear. This may not
- /// include `writable` or `hup`.
- ///
- /// # Panics
- ///
- /// This function panics if:
- ///
- /// * `ready` includes writable or HUP
- /// * called from outside of a task context.
- pub fn clear_read_ready(&self, cx: &mut Context<'_>, ready: mio::Ready) -> io::Result<()> {
- // Cannot clear write readiness
- assert!(!ready.is_writable(), "cannot clear write readiness");
- assert!(!platform::is_hup(ready), "cannot clear HUP readiness");
-
- self.inner
- .read_readiness
- .fetch_and(!ready.as_usize(), Relaxed);
-
- if self.poll_read_ready(cx, ready)?.is_ready() {
- // Notify the current task
- cx.waker().wake_by_ref();
- }
-
- Ok(())
+ pub(crate) fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> {
+ self.registration.poll_readiness(cx, Direction::Read)
}
/// Checks the I/O resource's write readiness state.
@@ -337,100 +193,95 @@ where
///
/// This method may not be called concurrently. It takes `&self` to allow
/// calling it concurrently with `poll_read_ready`.
- pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
- poll_ready!(
- self,
- mio::Ready::writable(),
- write_readiness,
- take_write_ready,
- self.inner.registration.poll_write_ready(cx)
- )
+ pub(crate) fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> {
+ self.registration.poll_readiness(cx, Direction::Write)
}
+}
- /// Resets the I/O resource's write readiness state and registers the current
- /// task to be notified once a write readiness event is received.
- ///
- /// This only clears writable readiness. HUP (on platforms that support HUP)
- /// cannot be cleared as it is a final state.
- ///
- /// After calling this function, `poll_write_ready(Ready::writable())` will
- /// return `NotReady` until a new write readiness event has been received.
- ///
- /// # Panics
- ///
- /// This function will panic if called from outside of a task context.
- pub fn clear_write_ready(&self, cx: &mut Context<'_>) -> io::Result<()> {
- let ready = mio::Ready::writable();
+cfg_io_readiness! {
+ impl<E: Source> PollEvented<E> {
+ pub(crate) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> {
+ self.registration.readiness(interest).await
+ }
- self.inner
- .write_readiness
- .fetch_and(!ready.as_usize(), Relaxed);
+ pub(crate) async fn async_io<F, R>(&self, interest: mio::Interest, mut op: F) -> io::Result<R>
+ where
+ F: FnMut(&E) -> io::Result<R>,
+ {
+ loop {
+ let event = self.readiness(interest).await?;
- if self.poll_write_ready(cx)?.is_ready() {
- // Notify the current task
- cx.waker().wake_by_ref();
+ match op(self.get_ref()) {
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ self.clear_readiness(event);
+ }
+ x => return x,
+ }
+ }
}
-
- Ok(())
}
}
// ===== Read / Write impls =====
-impl<E> AsyncRead for PollEvented<E>
-where
- E: Evented + Read + Unpin,
-{
+impl<E: Source + Read + Unpin> AsyncRead for PollEvented<E> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- ready!(self.poll_read_ready(cx, mio::Ready::readable()))?;
-
- let r = (*self).get_mut().read(buf);
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ loop {
+ let ev = ready!(self.poll_read_ready(cx))?;
+
+ // We can't assume the `Read` won't look at the read buffer,
+ // so we have to force initialization here.
+ let r = (*self).get_mut().read(buf.initialize_unfilled());
+
+ if is_wouldblock(&r) {
+ self.clear_readiness(ev);
+ continue;
+ }
- if is_wouldblock(&r) {
- self.clear_read_ready(cx, mio::Ready::readable())?;
- return Poll::Pending;
+ return Poll::Ready(r.map(|n| {
+ buf.advance(n);
+ }));
}
-
- Poll::Ready(r)
}
}
-impl<E> AsyncWrite for PollEvented<E>
-where
- E: Evented + Write + Unpin,
-{
+impl<E: Source + Write + Unpin> AsyncWrite for PollEvented<E> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
- ready!(self.poll_write_ready(cx))?;
+ loop {
+ let ev = ready!(self.poll_write_ready(cx))?;
- let r = (*self).get_mut().write(buf);
+ let r = (*self).get_mut().write(buf);
- if is_wouldblock(&r) {
- self.clear_write_ready(cx)?;
- return Poll::Pending;
- }
+ if is_wouldblock(&r) {
+ self.clear_readiness(ev);
+ continue;
+ }
- Poll::Ready(r)
+ return Poll::Ready(r);
+ }
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- ready!(self.poll_write_ready(cx))?;
+ loop {
+ let ev = ready!(self.poll_write_ready(cx))?;
- let r = (*self).get_mut().flush();
+ let r = (*self).get_mut().flush();
- if is_wouldblock(&r) {
- self.clear_write_ready(cx)?;
- return Poll::Pending;
- }
+ if is_wouldblock(&r) {
+ self.clear_readiness(ev);
+ continue;
+ }
- Poll::Ready(r)
+ return Poll::Ready(r);
+ }
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
@@ -445,17 +296,17 @@ fn is_wouldblock<T>(r: &io::Result<T>) -> bool {
}
}
-impl<E: Evented + fmt::Debug> fmt::Debug for PollEvented<E> {
+impl<E: Source + fmt::Debug> fmt::Debug for PollEvented<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PollEvented").field("io", &self.io).finish()
}
}
-impl<E: Evented> Drop for PollEvented<E> {
+impl<E: Source> Drop for PollEvented<E> {
fn drop(&mut self) {
- if let Some(io) = self.io.take() {
+ if let Some(mut io) = self.io.take() {
// Ignore errors
- let _ = self.inner.registration.deregister(&io);
+ let _ = self.registration.deregister(&mut io);
}
}
}
diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs
new file mode 100644
index 0000000..b64d95c
--- /dev/null
+++ b/src/io/read_buf.rs
@@ -0,0 +1,261 @@
+// This lint claims ugly casting is somehow safer than transmute, but there's
+// no evidence that is the case. Shush.
+#![allow(clippy::transmute_ptr_to_ptr)]
+
+use std::fmt;
+use std::mem::{self, MaybeUninit};
+
+/// A wrapper around a byte buffer that is incrementally filled and initialized.
+///
+/// This type is a sort of "double cursor". It tracks three regions in the
+/// buffer: a region at the beginning of the buffer that has been logically
+/// filled with data, a region that has been initialized at some point but not
+/// yet logically filled, and a region at the end that is fully uninitialized.
+/// The filled region is guaranteed to be a subset of the initialized region.
+///
+/// In summary, the contents of the buffer can be visualized as:
+///
+/// ```not_rust
+/// [ capacity ]
+/// [ filled | unfilled ]
+/// [ initialized | uninitialized ]
+/// ```
+pub struct ReadBuf<'a> {
+ buf: &'a mut [MaybeUninit<u8>],
+ filled: usize,
+ initialized: usize,
+}
+
+impl<'a> ReadBuf<'a> {
+ /// Creates a new `ReadBuf` from a fully initialized buffer.
+ #[inline]
+ pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> {
+ let initialized = buf.len();
+ let buf = unsafe { mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buf) };
+ ReadBuf {
+ buf,
+ filled: 0,
+ initialized,
+ }
+ }
+
+ /// Creates a new `ReadBuf` from a fully uninitialized buffer.
+ ///
+ /// Use `assume_init` if part of the buffer is known to be already inintialized.
+ #[inline]
+ pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> {
+ ReadBuf {
+ buf,
+ filled: 0,
+ initialized: 0,
+ }
+ }
+
+ /// Returns the total capacity of the buffer.
+ #[inline]
+ pub fn capacity(&self) -> usize {
+ self.buf.len()
+ }
+
+ /// Returns a shared reference to the filled portion of the buffer.
+ #[inline]
+ pub fn filled(&self) -> &[u8] {
+ let slice = &self.buf[..self.filled];
+ // safety: filled describes how far into the buffer that the
+ // user has filled with bytes, so it's been initialized.
+ // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable.
+ unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the filled portion of the buffer.
+ #[inline]
+ pub fn filled_mut(&mut self) -> &mut [u8] {
+ let slice = &mut self.buf[..self.filled];
+ // safety: filled describes how far into the buffer that the
+ // user has filled with bytes, so it's been initialized.
+ // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns a new `ReadBuf` comprised of the unfilled section up to `n`.
+ #[inline]
+ pub fn take(&mut self, n: usize) -> ReadBuf<'_> {
+ let max = std::cmp::min(self.remaining(), n);
+ // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`.
+ unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) }
+ }
+
+ /// Returns a shared reference to the initialized portion of the buffer.
+ ///
+ /// This includes the filled portion.
+ #[inline]
+ pub fn initialized(&self) -> &[u8] {
+ let slice = &self.buf[..self.initialized];
+ // safety: initialized describes how far into the buffer that the
+ // user has at some point initialized with bytes.
+ // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable.
+ unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the initialized portion of the buffer.
+ ///
+ /// This includes the filled portion.
+ #[inline]
+ pub fn initialized_mut(&mut self) -> &mut [u8] {
+ let slice = &mut self.buf[..self.initialized];
+ // safety: initialized describes how far into the buffer that the
+ // user has at some point initialized with bytes.
+ // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns a mutable reference to the unfilled part of the buffer without ensuring that it has been fully
+ /// initialized.
+ ///
+ /// # Safety
+ ///
+ /// The caller must not de-initialize portions of the buffer that have already been initialized.
+ #[inline]
+ pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit<u8>] {
+ &mut self.buf[self.filled..]
+ }
+
+ /// Returns a mutable reference to the unfilled part of the buffer, ensuring it is fully initialized.
+ ///
+ /// Since `ReadBuf` tracks the region of the buffer that has been initialized, this is effectively "free" after
+ /// the first use.
+ #[inline]
+ pub fn initialize_unfilled(&mut self) -> &mut [u8] {
+ self.initialize_unfilled_to(self.remaining())
+ }
+
+ /// Returns a mutable reference to the first `n` bytes of the unfilled part of the buffer, ensuring it is
+ /// fully initialized.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `self.remaining()` is less than `n`.
+ #[inline]
+ pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] {
+ assert!(self.remaining() >= n, "n overflows remaining");
+
+ // This can't overflow, otherwise the assert above would have failed.
+ let end = self.filled + n;
+
+ if self.initialized < end {
+ unsafe {
+ self.buf[self.initialized..end]
+ .as_mut_ptr()
+ .write_bytes(0, end - self.initialized);
+ }
+ self.initialized = end;
+ }
+
+ let slice = &mut self.buf[self.filled..end];
+ // safety: just above, we checked that the end of the buf has
+ // been initialized to some value.
+ unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) }
+ }
+
+ /// Returns the number of bytes at the end of the slice that have not yet been filled.
+ #[inline]
+ pub fn remaining(&self) -> usize {
+ self.capacity() - self.filled
+ }
+
+ /// Clears the buffer, resetting the filled region to empty.
+ ///
+ /// The number of initialized bytes is not changed, and the contents of the buffer are not modified.
+ #[inline]
+ pub fn clear(&mut self) {
+ self.filled = 0;
+ }
+
+ /// Advances the size of the filled region of the buffer.
+ ///
+ /// The number of initialized bytes is not changed.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the filled region of the buffer would become larger than the initialized region.
+ #[inline]
+ pub fn advance(&mut self, n: usize) {
+ let new = self.filled.checked_add(n).expect("filled overflow");
+ self.set_filled(new);
+ }
+
+ /// Sets the size of the filled region of the buffer.
+ ///
+ /// The number of initialized bytes is not changed.
+ ///
+ /// Note that this can be used to *shrink* the filled region of the buffer in addition to growing it (for
+ /// example, by a `AsyncRead` implementation that compresses data in-place).
+ ///
+ /// # Panics
+ ///
+ /// Panics if the filled region of the buffer would become larger than the intialized region.
+ #[inline]
+ pub fn set_filled(&mut self, n: usize) {
+ assert!(
+ n <= self.initialized,
+ "filled must not become larger than initialized"
+ );
+ self.filled = n;
+ }
+
+ /// Asserts that the first `n` unfilled bytes of the buffer are initialized.
+ ///
+ /// `ReadBuf` assumes that bytes are never de-initialized, so this method does nothing when called with fewer
+ /// bytes than are already known to be initialized.
+ ///
+ /// # Safety
+ ///
+ /// The caller must ensure that `n` unfilled bytes of the buffer have already been initialized.
+ #[inline]
+ pub unsafe fn assume_init(&mut self, n: usize) {
+ let new = self.filled + n;
+ if new > self.initialized {
+ self.initialized = new;
+ }
+ }
+
+ /// Appends data to the buffer, advancing the written position and possibly also the initialized position.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `self.remaining()` is less than `buf.len()`.
+ #[inline]
+ pub fn put_slice(&mut self, buf: &[u8]) {
+ assert!(
+ self.remaining() >= buf.len(),
+ "buf.len() must fit in remaining()"
+ );
+
+ let amt = buf.len();
+ // Cannot overflow, asserted above
+ let end = self.filled + amt;
+
+ // Safety: the length is asserted above
+ unsafe {
+ self.buf[self.filled..end]
+ .as_mut_ptr()
+ .cast::<u8>()
+ .copy_from_nonoverlapping(buf.as_ptr(), amt);
+ }
+
+ if self.initialized < end {
+ self.initialized = end;
+ }
+ self.filled = end;
+ }
+}
+
+impl fmt::Debug for ReadBuf<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ReadBuf")
+ .field("filled", &self.filled)
+ .field("initialized", &self.initialized)
+ .field("capacity", &self.capacity())
+ .finish()
+ }
+}
diff --git a/src/io/registration.rs b/src/io/registration.rs
index 77fe6db..ce6cffd 100644
--- a/src/io/registration.rs
+++ b/src/io/registration.rs
@@ -1,7 +1,7 @@
-use crate::io::driver::{platform, Direction, Handle};
-use crate::util::slab::Address;
+use crate::io::driver::{Direction, Handle, ReadyEvent, ScheduledIo};
+use crate::util::slab;
-use mio::{self, Evented};
+use mio::event::Source;
use std::io;
use std::task::{Context, Poll};
@@ -38,74 +38,38 @@ cfg_io_driver! {
/// [`poll_read_ready`]: method@Self::poll_read_ready`
/// [`poll_write_ready`]: method@Self::poll_write_ready`
#[derive(Debug)]
- pub struct Registration {
+ pub(crate) struct Registration {
+ /// Handle to the associated driver.
handle: Handle,
- address: Address,
+
+ /// Reference to state stored by the driver.
+ shared: slab::Ref<ScheduledIo>,
}
}
+unsafe impl Send for Registration {}
+unsafe impl Sync for Registration {}
+
// ===== impl Registration =====
impl Registration {
- /// Registers the I/O resource with the default reactor.
- ///
- /// # Return
- ///
- /// - `Ok` if the registration happened successfully
- /// - `Err` if an error was encountered during registration
- ///
- ///
- /// # Panics
- ///
- /// This function panics if thread-local runtime is not set.
- ///
- /// The runtime is usually set implicitly when this function is called
- /// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new<T>(io: &T) -> io::Result<Registration>
- where
- T: Evented,
- {
- Registration::new_with_ready(io, mio::Ready::all())
- }
-
- /// Registers the I/O resource with the default reactor, for a specific `mio::Ready` state.
- /// `new_with_ready` should be used over `new` when you need control over the readiness state,
+ /// Registers the I/O resource with the default reactor, for a specific `mio::Interest`.
+ /// `new_with_interest` should be used over `new` when you need control over the readiness state,
/// such as when a file descriptor only allows reads. This does not add `hup` or `error` so if
/// you are interested in those states, you will need to add them to the readiness state passed
/// to this function.
///
- /// An example to listen to read only
- ///
- /// ```rust
- /// ##[cfg(unix)]
- /// mio::Ready::from_usize(
- /// mio::Ready::readable().as_usize()
- /// | mio::unix::UnixReady::error().as_usize()
- /// | mio::unix::UnixReady::hup().as_usize()
- /// );
- /// ```
- ///
/// # Return
///
/// - `Ok` if the registration happened successfully
/// - `Err` if an error was encountered during registration
- ///
- ///
- /// # Panics
- ///
- /// This function panics if thread-local runtime is not set.
- ///
- /// The runtime is usually set implicitly when this function is called
- /// from a future driven by a tokio runtime, otherwise runtime can be set
- /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function.
- pub fn new_with_ready<T>(io: &T, ready: mio::Ready) -> io::Result<Registration>
- where
- T: Evented,
- {
- let handle = Handle::current();
- let address = if let Some(inner) = handle.inner() {
- inner.add_source(io, ready)?
+ pub(crate) fn new_with_interest_and_handle(
+ io: &mut impl Source,
+ interest: mio::Interest,
+ handle: Handle,
+ ) -> io::Result<Registration> {
+ let shared = if let Some(inner) = handle.inner() {
+ inner.add_source(io, interest)?
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
@@ -113,7 +77,7 @@ impl Registration {
));
};
- Ok(Registration { handle, address })
+ Ok(Registration { handle, shared })
}
/// Deregisters the I/O resource from the reactor it is associated with.
@@ -132,10 +96,7 @@ impl Registration {
/// no longer result in notifications getting sent for this registration.
///
/// `Err` is returned if an error is encountered.
- pub fn deregister<T>(&mut self, io: &T) -> io::Result<()>
- where
- T: Evented,
- {
+ pub(super) fn deregister(&mut self, io: &mut impl Source) -> io::Result<()> {
let inner = match self.handle.inner() {
Some(inner) => inner,
None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")),
@@ -143,198 +104,47 @@ impl Registration {
inner.deregister_source(io)
}
- /// Polls for events on the I/O resource's read readiness stream.
- ///
- /// If the I/O resource receives a new read readiness event since the last
- /// call to `poll_read_ready`, it is returned. If it has not, the current
- /// task is notified once a new event is received.
- ///
- /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned,
- /// the function will always return `Ready(HUP)`. This should be treated as
- /// the end of the readiness stream.
- ///
- /// # Return value
- ///
- /// There are several possible return values:
- ///
- /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received
- /// a new readiness event. The readiness value is included.
- ///
- /// * `Poll::Pending` means that no new readiness events have been received
- /// since the last call to `poll_read_ready`.
- ///
- /// * `Poll::Ready(Err(err))` means that the registration has encountered an
- /// error. This could represent a permanent internal error for example.
- ///
- /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered
- ///
- /// # Panics
- ///
- /// This function will panic if called from outside of a task context.
- pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| {
- coop.made_progress();
- e
- })?;
- match v {
- Some(v) => {
- coop.made_progress();
- Poll::Ready(Ok(v))
- }
- None => Poll::Pending,
- }
- }
-
- /// Consume any pending read readiness event.
- ///
- /// This function is identical to [`poll_read_ready`] **except** that it
- /// will not notify the current task when a new event is received. As such,
- /// it is safe to call this function from outside of a task context.
- ///
- /// [`poll_read_ready`]: method@Self::poll_read_ready
- pub fn take_read_ready(&self) -> io::Result<Option<mio::Ready>> {
- self.poll_ready(Direction::Read, None)
- }
-
- /// Polls for events on the I/O resource's write readiness stream.
- ///
- /// If the I/O resource receives a new write readiness event since the last
- /// call to `poll_write_ready`, it is returned. If it has not, the current
- /// task is notified once a new event is received.
- ///
- /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned,
- /// the function will always return `Ready(HUP)`. This should be treated as
- /// the end of the readiness stream.
- ///
- /// # Return value
- ///
- /// There are several possible return values:
- ///
- /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received
- /// a new readiness event. The readiness value is included.
- ///
- /// * `Poll::Pending` means that no new readiness events have been received
- /// since the last call to `poll_write_ready`.
- ///
- /// * `Poll::Ready(Err(err))` means that the registration has encountered an
- /// error. This could represent a permanent internal error for example.
- ///
- /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered
- ///
- /// # Panics
- ///
- /// This function will panic if called from outside of a task context.
- pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> {
- // Keep track of task budget
- let coop = ready!(crate::coop::poll_proceed(cx));
-
- let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| {
- coop.made_progress();
- e
- })?;
- match v {
- Some(v) => {
- coop.made_progress();
- Poll::Ready(Ok(v))
- }
- None => Poll::Pending,
- }
- }
-
- /// Consumes any pending write readiness event.
- ///
- /// This function is identical to [`poll_write_ready`] **except** that it
- /// will not notify the current task when a new event is received. As such,
- /// it is safe to call this function from outside of a task context.
- ///
- /// [`poll_write_ready`]: method@Self::poll_write_ready
- pub fn take_write_ready(&self) -> io::Result<Option<mio::Ready>> {
- self.poll_ready(Direction::Write, None)
+ pub(super) fn clear_readiness(&self, event: ReadyEvent) {
+ self.shared.clear_readiness(event);
}
/// Polls for events on the I/O resource's `direction` readiness stream.
///
/// If called with a task context, notify the task when a new event is
/// received.
- fn poll_ready(
+ pub(super) fn poll_readiness(
&self,
+ cx: &mut Context<'_>,
direction: Direction,
- cx: Option<&mut Context<'_>>,
- ) -> io::Result<Option<mio::Ready>> {
- let inner = match self.handle.inner() {
- Some(inner) => inner,
- None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")),
- };
-
- // If the task should be notified about new events, ensure that it has
- // been registered
- if let Some(ref cx) = cx {
- inner.register(self.address, direction, cx.waker().clone())
+ ) -> Poll<io::Result<ReadyEvent>> {
+ if self.handle.inner().is_none() {
+ return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone")));
}
- let mask = direction.mask();
- let mask_no_hup = (mask - platform::hup() - platform::error()).as_usize();
-
- let sched = inner.io_dispatch.get(self.address).unwrap();
+ // Keep track of task budget
+ let coop = ready!(crate::coop::poll_proceed(cx));
+ let ev = ready!(self.shared.poll_readiness(cx, direction));
+ coop.made_progress();
+ Poll::Ready(Ok(ev))
+ }
+}
- // This consumes the current readiness state **except** for HUP and
- // error. HUP and error are excluded because a) they are final states
- // and never transitition out and b) both the read AND the write
- // directions need to be able to obvserve these states.
- //
- // # Platform-specific behavior
- //
- // HUP and error readiness are platform-specific. On epoll platforms,
- // HUP has specific conditions that must be met by both peers of a
- // connection in order to be triggered.
- //
- // On epoll platforms, `EPOLLERR` is signaled through
- // `UnixReady::error()` and is important to be observable by both read
- // AND write. A specific case that `EPOLLERR` occurs is when the read
- // end of a pipe is closed. When this occurs, a peer blocked by
- // writing to the pipe should be notified.
- let curr_ready = sched
- .set_readiness(self.address, |curr| curr & (!mask_no_hup))
- .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address));
+cfg_io_readiness! {
+ impl Registration {
+ pub(super) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> {
+ use std::future::Future;
+ use std::pin::Pin;
- let mut ready = mask & mio::Ready::from_usize(curr_ready);
+ let fut = self.shared.readiness(interest);
+ pin!(fut);
- if ready.is_empty() {
- if let Some(cx) = cx {
- // Update the task info
- match direction {
- Direction::Read => sched.reader.register_by_ref(cx.waker()),
- Direction::Write => sched.writer.register_by_ref(cx.waker()),
+ crate::future::poll_fn(|cx| {
+ if self.handle.inner().is_none() {
+ return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone")));
}
- // Try again
- let curr_ready = sched
- .set_readiness(self.address, |curr| curr & (!mask_no_hup))
- .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address));
- ready = mask & mio::Ready::from_usize(curr_ready);
- }
- }
-
- if ready.is_empty() {
- Ok(None)
- } else {
- Ok(Some(ready))
+ Pin::new(&mut fut).poll(cx).map(Ok)
+ }).await
}
}
}
-
-unsafe impl Send for Registration {}
-unsafe impl Sync for Registration {}
-
-impl Drop for Registration {
- fn drop(&mut self) {
- let inner = match self.handle.inner() {
- Some(inner) => inner,
- None => return,
- };
- inner.drop_source(self.address);
- }
-}
diff --git a/src/io/seek.rs b/src/io/seek.rs
index e3b5bf6..e64205d 100644
--- a/src/io/seek.rs
+++ b/src/io/seek.rs
@@ -1,15 +1,23 @@
use crate::io::AsyncSeek;
+
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io::{self, SeekFrom};
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-/// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method.
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-pub struct Seek<'a, S: ?Sized> {
- seek: &'a mut S,
- pos: Option<SeekFrom>,
+pin_project! {
+ /// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method.
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct Seek<'a, S: ?Sized> {
+ seek: &'a mut S,
+ pos: Option<SeekFrom>,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
pub(crate) fn seek<S>(seek: &mut S, pos: SeekFrom) -> Seek<'_, S>
@@ -19,6 +27,7 @@ where
Seek {
seek,
pos: Some(pos),
+ _pin: PhantomPinned,
}
}
@@ -28,29 +37,21 @@ where
{
type Output = io::Result<u64>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
match me.pos {
- Some(pos) => match Pin::new(&mut me.seek).start_seek(cx, pos) {
- Poll::Ready(Ok(())) => {
- me.pos = None;
- Pin::new(&mut me.seek).poll_complete(cx)
+ Some(pos) => {
+ // ensure no seek in progress
+ ready!(Pin::new(&mut *me.seek).poll_complete(cx))?;
+ match Pin::new(&mut *me.seek).start_seek(*pos) {
+ Ok(()) => {
+ *me.pos = None;
+ Pin::new(&mut *me.seek).poll_complete(cx)
+ }
+ Err(e) => Poll::Ready(Err(e)),
}
- Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
- Poll::Pending => Poll::Pending,
- },
- None => Pin::new(&mut me.seek).poll_complete(cx),
+ }
+ None => Pin::new(&mut *me.seek).poll_complete(cx),
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Seek<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/split.rs b/src/io/split.rs
index 134b937..fd3273e 100644
--- a/src/io/split.rs
+++ b/src/io/split.rs
@@ -4,9 +4,8 @@
//! To restore this read/write object from its `split::ReadHalf` and
//! `split::WriteHalf` use `unsplit`.
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
-use bytes::{Buf, BufMut};
use std::cell::UnsafeCell;
use std::fmt;
use std::io;
@@ -102,20 +101,11 @@ impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
let mut inner = ready!(self.inner.poll_lock(cx));
inner.stream_pin().poll_read(cx, buf)
}
-
- fn poll_read_buf<B: BufMut>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- let mut inner = ready!(self.inner.poll_lock(cx));
- inner.stream_pin().poll_read_buf(cx, buf)
- }
}
impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
@@ -137,15 +127,6 @@ impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
let mut inner = ready!(self.inner.poll_lock(cx));
inner.stream_pin().poll_shutdown(cx)
}
-
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<Result<usize, io::Error>> {
- let mut inner = ready!(self.inner.poll_lock(cx));
- inner.stream_pin().poll_write_buf(cx, buf)
- }
}
impl<T> Inner<T> {
diff --git a/src/io/stderr.rs b/src/io/stderr.rs
index 99607dc..2f624fb 100644
--- a/src/io/stderr.rs
+++ b/src/io/stderr.rs
@@ -1,4 +1,5 @@
use crate::io::blocking::Blocking;
+use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows;
use crate::io::AsyncWrite;
use std::io;
@@ -35,7 +36,7 @@ cfg_io_std! {
/// ```
#[derive(Debug)]
pub struct Stderr {
- std: Blocking<std::io::Stderr>,
+ std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stderr>>,
}
/// Constructs a new handle to the standard error of the current process.
@@ -59,7 +60,7 @@ cfg_io_std! {
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
- /// let mut stderr = io::stdout();
+ /// let mut stderr = io::stderr();
/// stderr.write_all(b"Print some error here.").await?;
/// Ok(())
/// }
@@ -67,7 +68,7 @@ cfg_io_std! {
pub fn stderr() -> Stderr {
let std = io::stderr();
Stderr {
- std: Blocking::new(std),
+ std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)),
}
}
}
diff --git a/src/io/stdin.rs b/src/io/stdin.rs
index 325b875..c9578f1 100644
--- a/src/io/stdin.rs
+++ b/src/io/stdin.rs
@@ -1,5 +1,5 @@
use crate::io::blocking::Blocking;
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
use std::io;
use std::pin::Pin;
@@ -63,16 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin {
}
impl AsyncRead for Stdin {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97
- false
- }
-
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
Pin::new(&mut self.std).poll_read(cx, buf)
}
}
diff --git a/src/io/stdio_common.rs b/src/io/stdio_common.rs
new file mode 100644
index 0000000..d21c842
--- /dev/null
+++ b/src/io/stdio_common.rs
@@ -0,0 +1,220 @@
+//! Contains utilities for stdout and stderr.
+use crate::io::AsyncWrite;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+/// # Windows
+/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest,
+/// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF.
+/// That's why, wrapped writer will always receive well-formed utf-8 bytes.
+/// # Other platforms
+/// passes data to `inner` as is
+#[derive(Debug)]
+pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
+ inner: W,
+}
+
+impl<W> SplitByUtf8BoundaryIfWindows<W> {
+ pub(crate) fn new(inner: W) -> Self {
+ Self { inner }
+ }
+}
+
+// this constant is defined by Unicode standard.
+const MAX_BYTES_PER_CHAR: usize = 4;
+
+// Subject for tweaking here
+const MAGIC_CONST: usize = 8;
+
+impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
+where
+ W: AsyncWrite + Unpin,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ mut buf: &[u8],
+ ) -> Poll<Result<usize, std::io::Error>> {
+ // just a closure to avoid repetitive code
+ let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
+
+ // 1. Only windows stdio can suffer from non-utf8.
+ // We also check for `test` so that we can write some tests
+ // for further code. Since `AsyncWrite` can always shrink
+ // buffer at its discretion, excessive (i.e. in tests) shrinking
+ // does not break correctness.
+ // 2. If buffer is small, it will not be shrinked.
+ // That's why, it's "textness" will not change, so we don't have
+ // to fixup it.
+ if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF
+ {
+ return call_inner(buf);
+ }
+
+ buf = &buf[..crate::io::blocking::MAX_BUF];
+
+ // Now there are two possibilites.
+ // If caller gave is binary buffer, we **should not** shrink it
+ // anymore, because excessive shrinking hits performance.
+ // If caller gave as binary buffer, we **must** additionaly
+ // shrink it to strip incomplete char at the end of buffer.
+ // that's why check we will perform now is allowed to have
+ // false-positive.
+
+ // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes.
+ // if they are (possibly incomplete) utf8, then we can be quite sure
+ // that input buffer was utf8.
+
+ let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
+ Ok(_) => true,
+ Err(err) => {
+ let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
+ incomplete_bytes < MAX_BYTES_PER_CHAR
+ }
+ };
+
+ if have_to_fix_up {
+ // We must pop several bytes at the end which form incomplete
+ // character. To achieve it, we exploit UTF8 encoding:
+ // for any code point, all bytes except first start with 0b10 prefix.
+ // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details
+ let trailing_incomplete_char_size = buf
+ .iter()
+ .rev()
+ .take(MAX_BYTES_PER_CHAR)
+ .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
+ .unwrap_or(0)
+ + 1;
+ buf = &buf[..buf.len() - trailing_incomplete_char_size];
+ }
+
+ call_inner(buf)
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ Pin::new(&mut self.inner).poll_flush(cx)
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Result<(), std::io::Error>> {
+ Pin::new(&mut self.inner).poll_shutdown(cx)
+ }
+}
+
+#[cfg(test)]
+#[cfg(not(loom))]
+mod tests {
+ use crate::io::AsyncWriteExt;
+ use std::io;
+ use std::pin::Pin;
+ use std::task::Context;
+ use std::task::Poll;
+
+ const MAX_BUF: usize = 16 * 1024;
+
+ struct TextMockWriter;
+
+ impl crate::io::AsyncWrite for TextMockWriter {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ assert!(buf.len() <= MAX_BUF);
+ assert!(std::str::from_utf8(buf).is_ok());
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ struct LoggingMockWriter {
+ write_history: Vec<usize>,
+ }
+
+ impl LoggingMockWriter {
+ fn new() -> Self {
+ LoggingMockWriter {
+ write_history: Vec::new(),
+ }
+ }
+ }
+
+ impl crate::io::AsyncWrite for LoggingMockWriter {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ assert!(buf.len() <= MAX_BUF);
+ self.write_history.push(buf.len());
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ #[test]
+ fn test_splitter() {
+ let data = str::repeat("â–ˆ", MAX_BUF);
+ let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
+ let fut = async move {
+ wr.write_all(data.as_bytes()).await.unwrap();
+ };
+ crate::runtime::Builder::new_current_thread()
+ .build()
+ .unwrap()
+ .block_on(fut);
+ }
+
+ #[test]
+ fn test_pseudo_text() {
+ // In this test we write a piece of binary data, whose beginning is
+ // text though. We then validate that even in this corner case buffer
+ // was not shrinked too much.
+ let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
+ let mut data: Vec<u8> = str::repeat("a", checked_count).into();
+ data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1));
+ let mut writer = LoggingMockWriter::new();
+ let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
+ crate::runtime::Builder::new_current_thread()
+ .build()
+ .unwrap()
+ .block_on(async {
+ splitter.write_all(&data).await.unwrap();
+ });
+ // Check that at most two writes were performed
+ assert!(writer.write_history.len() <= 2);
+ // Check that all has been written
+ assert_eq!(
+ writer.write_history.iter().copied().sum::<usize>(),
+ data.len()
+ );
+ // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked
+ // from the buffer: one because it was outside of MAX_BUF boundary, and
+ // up to one "utf8 code point".
+ assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
+ }
+}
diff --git a/src/io/stdout.rs b/src/io/stdout.rs
index 5377993..a08ed01 100644
--- a/src/io/stdout.rs
+++ b/src/io/stdout.rs
@@ -1,6 +1,6 @@
use crate::io::blocking::Blocking;
+use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows;
use crate::io::AsyncWrite;
-
use std::io;
use std::pin::Pin;
use std::task::Context;
@@ -35,7 +35,7 @@ cfg_io_std! {
/// ```
#[derive(Debug)]
pub struct Stdout {
- std: Blocking<std::io::Stdout>,
+ std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stdout>>,
}
/// Constructs a new handle to the standard output of the current process.
@@ -67,7 +67,7 @@ cfg_io_std! {
pub fn stdout() -> Stdout {
let std = io::stdout();
Stdout {
- std: Blocking::new(std),
+ std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)),
}
}
}
diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs
index 1bfab90..9e87f2f 100644
--- a/src/io/util/async_buf_read_ext.rs
+++ b/src/io/util/async_buf_read_ext.rs
@@ -14,7 +14,7 @@ cfg_io_util! {
/// Equivalent to:
///
/// ```ignore
- /// async fn read_until(&mut self, buf: &mut Vec<u8>) -> io::Result<usize>;
+ /// async fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize>;
/// ```
///
/// This function will read bytes from the underlying stream until the
diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs
index e848a5d..0ab66c2 100644
--- a/src/io/util/async_read_ext.rs
+++ b/src/io/util/async_read_ext.rs
@@ -986,10 +986,12 @@ cfg_io_util! {
///
/// All bytes read from this source will be appended to the specified
/// buffer `buf`. This function will continuously call [`read()`] to
- /// append more data to `buf` until [`read()`][read] returns `Ok(0)`.
+ /// append more data to `buf` until [`read()`] returns `Ok(0)`.
///
/// If successful, the total number of bytes read is returned.
///
+ /// [`read()`]: AsyncReadExt::read
+ ///
/// # Errors
///
/// If a read error is encountered then the `read_to_end` operation
@@ -1018,7 +1020,7 @@ cfg_io_util! {
/// (See also the [`tokio::fs::read`] convenience function for reading from a
/// file.)
///
- /// [`tokio::fs::read`]: crate::fs::read::read
+ /// [`tokio::fs::read`]: fn@crate::fs::read
fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self>
where
Self: Unpin,
@@ -1065,7 +1067,7 @@ cfg_io_util! {
/// (See also the [`crate::fs::read_to_string`] convenience function for
/// reading from a file.)
///
- /// [`crate::fs::read_to_string`]: crate::fs::read_to_string::read_to_string
+ /// [`crate::fs::read_to_string`]: fn@crate::fs::read_to_string
fn read_to_string<'a>(&'a mut self, dst: &'a mut String) -> ReadToString<'a, Self>
where
Self: Unpin,
@@ -1078,7 +1080,11 @@ cfg_io_util! {
/// This function returns a new instance of `AsyncRead` which will read
/// at most `limit` bytes, after which it will always return EOF
/// (`Ok(0)`). Any read errors will not count towards the number of
- /// bytes read and future calls to [`read()`][read] may succeed.
+ /// bytes read and future calls to [`read()`] may succeed.
+ ///
+ /// [`read()`]: fn@crate::io::AsyncReadExt::read
+ ///
+ /// [read]: AsyncReadExt::read
///
/// # Examples
///
diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs
index c7a0f72..351900b 100644
--- a/src/io/util/async_seek_ext.rs
+++ b/src/io/util/async_seek_ext.rs
@@ -2,65 +2,73 @@ use crate::io::seek::{seek, Seek};
use crate::io::AsyncSeek;
use std::io::SeekFrom;
-/// An extension trait which adds utility methods to [`AsyncSeek`] types.
-///
-/// As a convenience, this trait may be imported using the [`prelude`]:
-///
-/// # Examples
-///
-/// ```
-/// use std::io::{Cursor, SeekFrom};
-/// use tokio::prelude::*;
-///
-/// #[tokio::main]
-/// async fn main() -> io::Result<()> {
-/// let mut cursor = Cursor::new(b"abcdefg");
-///
-/// // the `seek` method is defined by this trait
-/// cursor.seek(SeekFrom::Start(3)).await?;
-///
-/// let mut buf = [0; 1];
-/// let n = cursor.read(&mut buf).await?;
-/// assert_eq!(n, 1);
-/// assert_eq!(buf, [b'd']);
-///
-/// Ok(())
-/// }
-/// ```
-///
-/// See [module][crate::io] documentation for more details.
-///
-/// [`AsyncSeek`]: AsyncSeek
-/// [`prelude`]: crate::prelude
-pub trait AsyncSeekExt: AsyncSeek {
- /// Creates a future which will seek an IO object, and then yield the
- /// new position in the object and the object itself.
+cfg_io_util! {
+ /// An extension trait which adds utility methods to [`AsyncSeek`] types.
///
- /// In the case of an error the buffer and the object will be discarded, with
- /// the error yielded.
+ /// As a convenience, this trait may be imported using the [`prelude`]:
///
/// # Examples
///
- /// ```no_run
- /// use tokio::fs::File;
+ /// ```
+ /// use std::io::{Cursor, SeekFrom};
/// use tokio::prelude::*;
///
- /// use std::io::SeekFrom;
+ /// #[tokio::main]
+ /// async fn main() -> io::Result<()> {
+ /// let mut cursor = Cursor::new(b"abcdefg");
+ ///
+ /// // the `seek` method is defined by this trait
+ /// cursor.seek(SeekFrom::Start(3)).await?;
///
- /// # async fn dox() -> std::io::Result<()> {
- /// let mut file = File::open("foo.txt").await?;
- /// file.seek(SeekFrom::Start(6)).await?;
+ /// let mut buf = [0; 1];
+ /// let n = cursor.read(&mut buf).await?;
+ /// assert_eq!(n, 1);
+ /// assert_eq!(buf, [b'd']);
///
- /// let mut contents = vec![0u8; 10];
- /// file.read_exact(&mut contents).await?;
- /// # Ok(())
- /// # }
+ /// Ok(())
+ /// }
/// ```
- fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self>
- where
- Self: Unpin,
- {
- seek(self, pos)
+ ///
+ /// See [module][crate::io] documentation for more details.
+ ///
+ /// [`AsyncSeek`]: AsyncSeek
+ /// [`prelude`]: crate::prelude
+ pub trait AsyncSeekExt: AsyncSeek {
+ /// Creates a future which will seek an IO object, and then yield the
+ /// new position in the object and the object itself.
+ ///
+ /// Equivalent to:
+ ///
+ /// ```ignore
+ /// async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64>;
+ /// ```
+ ///
+ /// In the case of an error the buffer and the object will be discarded, with
+ /// the error yielded.
+ ///
+ /// # Examples
+ ///
+ /// ```no_run
+ /// use tokio::fs::File;
+ /// use tokio::prelude::*;
+ ///
+ /// use std::io::SeekFrom;
+ ///
+ /// # async fn dox() -> std::io::Result<()> {
+ /// let mut file = File::open("foo.txt").await?;
+ /// file.seek(SeekFrom::Start(6)).await?;
+ ///
+ /// let mut contents = vec![0u8; 10];
+ /// file.read_exact(&mut contents).await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self>
+ where
+ Self: Unpin,
+ {
+ seek(self, pos)
+ }
}
}
diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs
index fa41097..e6ef5b2 100644
--- a/src/io/util/async_write_ext.rs
+++ b/src/io/util/async_write_ext.rs
@@ -119,6 +119,7 @@ cfg_io_util! {
write(self, src)
}
+
/// Writes a buffer into this writer, advancing the buffer's internal
/// cursor.
///
@@ -134,7 +135,7 @@ cfg_io_util! {
/// internal cursor is advanced by the number of bytes written. A
/// subsequent call to `write_buf` using the **same** `buf` value will
/// resume from the point that the first call to `write_buf` completed.
- /// A call to `write` represents *at most one* attempt to write to any
+ /// A call to `write_buf` represents *at most one* attempt to write to any
/// wrapped object.
///
/// # Return
@@ -976,6 +977,8 @@ cfg_io_util! {
/// no longer attempt to write to the stream. For example, the
/// `TcpStream` implementation will issue a `shutdown(Write)` sys call.
///
+ /// [`flush`]: fn@crate::io::AsyncWriteExt::flush
+ ///
/// # Examples
///
/// ```no_run
diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs
index a1c5990..271f61b 100644
--- a/src/io/util/buf_reader.rs
+++ b/src/io/util/buf_reader.rs
@@ -1,10 +1,8 @@
use crate::io::util::DEFAULT_BUF_SIZE;
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
-use bytes::Buf;
use pin_project_lite::pin_project;
-use std::io::{self, Read};
-use std::mem::MaybeUninit;
+use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{cmp, fmt};
@@ -44,21 +42,12 @@ impl<R: AsyncRead> BufReader<R> {
/// Creates a new `BufReader` with the specified buffer capacity.
pub fn with_capacity(capacity: usize, inner: R) -> Self {
- unsafe {
- let mut buffer = Vec::with_capacity(capacity);
- buffer.set_len(capacity);
-
- {
- // Convert to MaybeUninit
- let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit<u8>]);
- inner.prepare_uninitialized_buffer(b);
- }
- Self {
- inner,
- buf: buffer.into_boxed_slice(),
- pos: 0,
- cap: 0,
- }
+ let buffer = vec![0; capacity];
+ Self {
+ inner,
+ buf: buffer.into_boxed_slice(),
+ pos: 0,
+ cap: 0,
}
}
@@ -110,25 +99,21 @@ impl<R: AsyncRead> AsyncRead for BufReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
// If we don't have any buffered data and we're doing a massive read
// (larger than our internal buffer), bypass our internal buffer
// entirely.
- if self.pos == self.cap && buf.len() >= self.buf.len() {
+ if self.pos == self.cap && buf.remaining() >= self.buf.len() {
let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf));
self.discard_buffer();
return Poll::Ready(res);
}
- let mut rem = ready!(self.as_mut().poll_fill_buf(cx))?;
- let nread = rem.read(buf)?;
- self.consume(nread);
- Poll::Ready(Ok(nread))
- }
-
- // we can't skip unconditionally because of the large buffer case in read.
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.inner.prepare_uninitialized_buffer(buf)
+ let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
+ let amt = std::cmp::min(rem.len(), buf.remaining());
+ buf.put_slice(&rem[..amt]);
+ self.consume(amt);
+ Poll::Ready(Ok(()))
}
}
@@ -142,7 +127,9 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
// to tell the compiler that the pos..cap slice is always valid.
if *me.pos >= *me.cap {
debug_assert!(*me.pos == *me.cap);
- *me.cap = ready!(me.inner.poll_read(cx, me.buf))?;
+ let mut buf = ReadBuf::new(me.buf);
+ ready!(me.inner.poll_read(cx, &mut buf))?;
+ *me.cap = buf.filled().len();
*me.pos = 0;
}
Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
@@ -163,14 +150,6 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
self.get_pin_mut().poll_write(cx, buf)
}
- fn poll_write_buf<B: Buf>(
- self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut B,
- ) -> Poll<io::Result<usize>> {
- self.get_pin_mut().poll_write_buf(cx, buf)
- }
-
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_flush(cx)
}
diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs
index a56a451..cc857e2 100644
--- a/src/io/util/buf_stream.rs
+++ b/src/io/util/buf_stream.rs
@@ -1,9 +1,8 @@
use crate::io::util::{BufReader, BufWriter};
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use pin_project_lite::pin_project;
use std::io;
-use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -137,15 +136,10 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
-
- // we can't skip unconditionally because of the large buffer case in read.
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.inner.prepare_uninitialized_buffer(buf)
- }
}
impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> {
diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs
index efd053e..5e3d4b7 100644
--- a/src/io/util/buf_writer.rs
+++ b/src/io/util/buf_writer.rs
@@ -1,10 +1,9 @@
use crate::io::util::DEFAULT_BUF_SIZE;
-use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite};
+use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use pin_project_lite::pin_project;
use std::fmt;
use std::io::{self, Write};
-use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -147,15 +146,10 @@ impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_read(cx, buf)
}
-
- // we can't skip unconditionally because of the large buffer case in read.
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.get_ref().prepare_uninitialized_buffer(buf)
- }
}
impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
diff --git a/src/io/util/chain.rs b/src/io/util/chain.rs
index 8ba9194..84f37fc 100644
--- a/src/io/util/chain.rs
+++ b/src/io/util/chain.rs
@@ -1,4 +1,4 @@
-use crate::io::{AsyncBufRead, AsyncRead};
+use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use pin_project_lite::pin_project;
use std::fmt;
@@ -84,26 +84,20 @@ where
T: AsyncRead,
U: AsyncRead,
{
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- if self.first.prepare_uninitialized_buffer(buf) {
- return true;
- }
- if self.second.prepare_uninitialized_buffer(buf) {
- return true;
- }
- false
- }
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
let me = self.project();
if !*me.done_first {
- match ready!(me.first.poll_read(cx, buf)?) {
- 0 if !buf.is_empty() => *me.done_first = true,
- n => return Poll::Ready(Ok(n)),
+ let rem = buf.remaining();
+ ready!(me.first.poll_read(cx, buf))?;
+ if buf.remaining() == rem {
+ *me.done_first = true;
+ } else {
+ return Poll::Ready(Ok(()));
}
}
me.second.poll_read(cx, buf)
diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs
index 7bfe296..c5981cf 100644
--- a/src/io/util/copy.rs
+++ b/src/io/util/copy.rs
@@ -1,30 +1,25 @@
-use crate::io::{AsyncRead, AsyncWrite};
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
- /// A future that asynchronously copies the entire contents of a reader into a
- /// writer.
- ///
- /// This struct is generally created by calling [`copy`][copy]. Please
- /// see the documentation of `copy()` for more details.
- ///
- /// [copy]: copy()
- #[derive(Debug)]
- #[must_use = "futures do nothing unless you `.await` or poll them"]
- pub struct Copy<'a, R: ?Sized, W: ?Sized> {
- reader: &'a mut R,
- read_done: bool,
- writer: &'a mut W,
- pos: usize,
- cap: usize,
- amt: u64,
- buf: Box<[u8]>,
- }
+/// A future that asynchronously copies the entire contents of a reader into a
+/// writer.
+#[derive(Debug)]
+#[must_use = "futures do nothing unless you `.await` or poll them"]
+struct Copy<'a, R: ?Sized, W: ?Sized> {
+ reader: &'a mut R,
+ read_done: bool,
+ writer: &'a mut W,
+ pos: usize,
+ cap: usize,
+ amt: u64,
+ buf: Box<[u8]>,
+}
+cfg_io_util! {
/// Asynchronously copies the entire contents of a reader into a writer.
///
/// This function returns a future that will continuously read data from
@@ -58,7 +53,7 @@ cfg_io_util! {
/// # Ok(())
/// # }
/// ```
- pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W>
+ pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
@@ -71,7 +66,7 @@ cfg_io_util! {
pos: 0,
cap: 0,
buf: vec![0; 2048].into_boxed_slice(),
- }
+ }.await
}
}
@@ -88,7 +83,9 @@ where
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
- let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?;
+ let mut buf = ReadBuf::new(&mut me.buf);
+ ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?;
+ let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
@@ -122,14 +119,3 @@ where
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Copy<'_, PhantomPinned, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/copy_buf.rs b/src/io/util/copy_buf.rs
new file mode 100644
index 0000000..6831580
--- /dev/null
+++ b/src/io/util/copy_buf.rs
@@ -0,0 +1,102 @@
+use crate::io::{AsyncBufRead, AsyncWrite};
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+cfg_io_util! {
+ /// A future that asynchronously copies the entire contents of a reader into a
+ /// writer.
+ ///
+ /// This struct is generally created by calling [`copy_buf`][copy_buf]. Please
+ /// see the documentation of `copy_buf()` for more details.
+ ///
+ /// [copy_buf]: copy_buf()
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ struct CopyBuf<'a, R: ?Sized, W: ?Sized> {
+ reader: &'a mut R,
+ writer: &'a mut W,
+ amt: u64,
+ }
+
+ /// Asynchronously copies the entire contents of a reader into a writer.
+ ///
+ /// This function returns a future that will continuously read data from
+ /// `reader` and then write it into `writer` in a streaming fashion until
+ /// `reader` returns EOF.
+ ///
+ /// On success, the total number of bytes that were copied from `reader` to
+ /// `writer` is returned.
+ ///
+ ///
+ /// # Errors
+ ///
+ /// The returned future will finish with an error will return an error
+ /// immediately if any call to `poll_fill_buf` or `poll_write` returns an
+ /// error.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio::io;
+ ///
+ /// # async fn dox() -> std::io::Result<()> {
+ /// let mut reader: &[u8] = b"hello";
+ /// let mut writer: Vec<u8> = vec![];
+ ///
+ /// io::copy_buf(&mut reader, &mut writer).await?;
+ ///
+ /// assert_eq!(b"hello", &writer[..]);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub async fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
+ where
+ R: AsyncBufRead + Unpin + ?Sized,
+ W: AsyncWrite + Unpin + ?Sized,
+ {
+ CopyBuf {
+ reader,
+ writer,
+ amt: 0,
+ }.await
+ }
+}
+
+impl<R, W> Future for CopyBuf<'_, R, W>
+where
+ R: AsyncBufRead + Unpin + ?Sized,
+ W: AsyncWrite + Unpin + ?Sized,
+{
+ type Output = io::Result<u64>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ loop {
+ let me = &mut *self;
+ let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?;
+ if buffer.is_empty() {
+ ready!(Pin::new(&mut self.writer).poll_flush(cx))?;
+ return Poll::Ready(Ok(self.amt));
+ }
+
+ let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?;
+ if i == 0 {
+ return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
+ }
+ self.amt += i as u64;
+ Pin::new(&mut *self.reader).consume(i);
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn assert_unpin() {
+ use std::marker::PhantomPinned;
+ crate::is_unpin::<CopyBuf<'_, PhantomPinned, PhantomPinned>>();
+ }
+}
diff --git a/src/io/util/empty.rs b/src/io/util/empty.rs
index 576058d..f964d18 100644
--- a/src/io/util/empty.rs
+++ b/src/io/util/empty.rs
@@ -1,4 +1,4 @@
-use crate::io::{AsyncBufRead, AsyncRead};
+use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use std::fmt;
use std::io;
@@ -47,16 +47,13 @@ cfg_io_util! {
}
impl AsyncRead for Empty {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- false
- }
#[inline]
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
- _: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- Poll::Ready(Ok(0))
+ _: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/util/flush.rs b/src/io/util/flush.rs
index 534a516..88d60b8 100644
--- a/src/io/util/flush.rs
+++ b/src/io/util/flush.rs
@@ -1,18 +1,24 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future used to fully flush an I/O object.
///
/// Created by the [`AsyncWriteExt::flush`][flush] function.
/// [flush]: crate::io::AsyncWriteExt::flush
#[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Flush<'a, A: ?Sized> {
a: &'a mut A,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -21,7 +27,10 @@ pub(super) fn flush<A>(a: &mut A) -> Flush<'_, A>
where
A: AsyncWrite + Unpin + ?Sized,
{
- Flush { a }
+ Flush {
+ a,
+ _pin: PhantomPinned,
+ }
}
impl<A> Future for Flush<'_, A>
@@ -30,19 +39,8 @@ where
{
type Output = io::Result<()>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
Pin::new(&mut *me.a).poll_flush(cx)
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Flush<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs
index ee27400..b41f04a 100644
--- a/src/io/util/lines.rs
+++ b/src/io/util/lines.rs
@@ -83,8 +83,7 @@ impl<R> Lines<R>
where
R: AsyncBufRead,
{
- #[doc(hidden)]
- pub fn poll_next_line(
+ fn poll_next_line(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<String>>> {
diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs
new file mode 100644
index 0000000..e91a932
--- /dev/null
+++ b/src/io/util/mem.rs
@@ -0,0 +1,223 @@
+//! In-process memory IO types.
+
+use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
+use crate::loom::sync::Mutex;
+
+use bytes::{Buf, BytesMut};
+use std::{
+ pin::Pin,
+ sync::Arc,
+ task::{self, Poll, Waker},
+};
+
+/// A bidirectional pipe to read and write bytes in memory.
+///
+/// A pair of `DuplexStream`s are created together, and they act as a "channel"
+/// that can be used as in-memory IO types. Writing to one of the pairs will
+/// allow that data to be read from the other, and vice versa.
+///
+/// # Example
+///
+/// ```
+/// # async fn ex() -> std::io::Result<()> {
+/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
+/// let (mut client, mut server) = tokio::io::duplex(64);
+///
+/// client.write_all(b"ping").await?;
+///
+/// let mut buf = [0u8; 4];
+/// server.read_exact(&mut buf).await?;
+/// assert_eq!(&buf, b"ping");
+///
+/// server.write_all(b"pong").await?;
+///
+/// client.read_exact(&mut buf).await?;
+/// assert_eq!(&buf, b"pong");
+/// # Ok(())
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct DuplexStream {
+ read: Arc<Mutex<Pipe>>,
+ write: Arc<Mutex<Pipe>>,
+}
+
+/// A unidirectional IO over a piece of memory.
+///
+/// Data can be written to the pipe, and reading will return that data.
+#[derive(Debug)]
+struct Pipe {
+ /// The buffer storing the bytes written, also read from.
+ ///
+ /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
+ /// functionality already. Additionally, it can try to copy data in the
+ /// same buffer if there read index has advanced far enough.
+ buffer: BytesMut,
+ /// Determines if the write side has been closed.
+ is_closed: bool,
+ /// The maximum amount of bytes that can be written before returning
+ /// `Poll::Pending`.
+ max_buf_size: usize,
+ /// If the `read` side has been polled and is pending, this is the waker
+ /// for that parked task.
+ read_waker: Option<Waker>,
+ /// If the `write` side has filled the `max_buf_size` and returned
+ /// `Poll::Pending`, this is the waker for that parked task.
+ write_waker: Option<Waker>,
+}
+
+// ===== impl DuplexStream =====
+
+/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
+///
+/// The `max_buf_size` argument is the maximum amount of bytes that can be
+/// written to a side before the write returns `Poll::Pending`.
+pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
+ let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
+ let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
+
+ (
+ DuplexStream {
+ read: one.clone(),
+ write: two.clone(),
+ },
+ DuplexStream {
+ read: two,
+ write: one,
+ },
+ )
+}
+
+impl AsyncRead for DuplexStream {
+ // Previous rustc required this `self` to be `mut`, even though newer
+ // versions recognize it isn't needed to call `lock()`. So for
+ // compatibility, we include the `mut` and `allow` the lint.
+ //
+ // See https://github.com/rust-lang/rust/issues/73592
+ #[allow(unused_mut)]
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
+ }
+}
+
+impl AsyncWrite for DuplexStream {
+ #[allow(unused_mut)]
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
+ }
+
+ #[allow(unused_mut)]
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.write.lock()).poll_flush(cx)
+ }
+
+ #[allow(unused_mut)]
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
+ }
+}
+
+impl Drop for DuplexStream {
+ fn drop(&mut self) {
+ // notify the other side of the closure
+ self.write.lock().close();
+ }
+}
+
+// ===== impl Pipe =====
+
+impl Pipe {
+ fn new(max_buf_size: usize) -> Self {
+ Pipe {
+ buffer: BytesMut::new(),
+ is_closed: false,
+ max_buf_size,
+ read_waker: None,
+ write_waker: None,
+ }
+ }
+
+ fn close(&mut self) {
+ self.is_closed = true;
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ }
+}
+
+impl AsyncRead for Pipe {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ if self.buffer.has_remaining() {
+ let max = self.buffer.remaining().min(buf.remaining());
+ buf.put_slice(&self.buffer[..max]);
+ self.buffer.advance(max);
+ if max > 0 {
+ // The passed `buf` might have been empty, don't wake up if
+ // no bytes have been moved.
+ if let Some(waker) = self.write_waker.take() {
+ waker.wake();
+ }
+ }
+ Poll::Ready(Ok(()))
+ } else if self.is_closed {
+ Poll::Ready(Ok(()))
+ } else {
+ self.read_waker = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+}
+
+impl AsyncWrite for Pipe {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ if self.is_closed {
+ return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
+ }
+ let avail = self.max_buf_size - self.buffer.len();
+ if avail == 0 {
+ self.write_waker = Some(cx.waker().clone());
+ return Poll::Pending;
+ }
+
+ let len = buf.len().min(avail);
+ self.buffer.extend_from_slice(&buf[..len]);
+ if let Some(waker) = self.read_waker.take() {
+ waker.wake();
+ }
+ Poll::Ready(Ok(len))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ _: &mut task::Context<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ self.close();
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs
index c4754ab..e75ea03 100644
--- a/src/io/util/mod.rs
+++ b/src/io/util/mod.rs
@@ -25,7 +25,10 @@ cfg_io_util! {
mod chain;
mod copy;
- pub use copy::{copy, Copy};
+ pub use copy::copy;
+
+ mod copy_buf;
+ pub use copy_buf::copy_buf;
mod empty;
pub use empty::{empty, Empty};
@@ -35,6 +38,9 @@ cfg_io_util! {
mod lines;
pub use lines::Lines;
+ mod mem;
+ pub use mem::{duplex, DuplexStream};
+
mod read;
mod read_buf;
mod read_exact;
@@ -60,11 +66,6 @@ cfg_io_util! {
mod split;
pub use split::Split;
- cfg_stream! {
- mod stream_reader;
- pub use stream_reader::{stream_reader, StreamReader};
- }
-
mod take;
pub use take::Take;
diff --git a/src/io/util/read.rs b/src/io/util/read.rs
index a8ca370..edc9d5a 100644
--- a/src/io/util/read.rs
+++ b/src/io/util/read.rs
@@ -1,7 +1,9 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -15,10 +17,14 @@ pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R>
where
R: AsyncRead + Unpin + ?Sized,
{
- Read { reader, buf }
+ Read {
+ reader,
+ buf,
+ _pin: PhantomPinned,
+ }
}
-cfg_io_util! {
+pin_project! {
/// A future which can be used to easily read available number of bytes to fill
/// a buffer.
///
@@ -28,6 +34,9 @@ cfg_io_util! {
pub struct Read<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut [u8],
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -37,19 +46,10 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
- Pin::new(&mut *me.reader).poll_read(cx, me.buf)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Read<'_, PhantomPinned>>();
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let me = self.project();
+ let mut buf = ReadBuf::new(*me.buf);
+ ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?;
+ Poll::Ready(Ok(buf.filled().len()))
}
}
diff --git a/src/io/util/read_buf.rs b/src/io/util/read_buf.rs
index 6ee3d24..696deef 100644
--- a/src/io/util/read_buf.rs
+++ b/src/io/util/read_buf.rs
@@ -1,8 +1,10 @@
use crate::io::AsyncRead;
use bytes::BufMut;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -11,16 +13,22 @@ where
R: AsyncRead + Unpin,
B: BufMut,
{
- ReadBuf { reader, buf }
+ ReadBuf {
+ reader,
+ buf,
+ _pin: PhantomPinned,
+ }
}
-cfg_io_util! {
+pin_project! {
/// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf).
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadBuf<'a, R, B> {
reader: &'a mut R,
buf: &'a mut B,
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -31,8 +39,34 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
- Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf)
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ use crate::io::ReadBuf;
+ use std::mem::MaybeUninit;
+
+ let me = self.project();
+
+ if !me.buf.has_remaining_mut() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = {
+ let dst = me.buf.bytes_mut();
+ let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
+ let mut buf = ReadBuf::uninit(dst);
+ let ptr = buf.filled().as_ptr();
+ ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?);
+
+ // Ensure the pointer does not change from under us
+ assert_eq!(ptr, buf.filled().as_ptr());
+ buf.filled().len()
+ };
+
+ // Safety: This is guaranteed to be the number of initialized (and read)
+ // bytes due to the invariants provided by `ReadBuf::filled`.
+ unsafe {
+ me.buf.advance_mut(n);
+ }
+
+ Poll::Ready(Ok(n))
}
}
diff --git a/src/io/util/read_exact.rs b/src/io/util/read_exact.rs
index 86b8412..1e8150e 100644
--- a/src/io/util/read_exact.rs
+++ b/src/io/util/read_exact.rs
@@ -1,7 +1,9 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::marker::Unpin;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -17,12 +19,12 @@ where
{
ReadExact {
reader,
- buf,
- pos: 0,
+ buf: ReadBuf::new(buf),
+ _pin: PhantomPinned,
}
}
-cfg_io_util! {
+pin_project! {
/// Creates a future which will read exactly enough bytes to fill `buf`,
/// returning an error if EOF is hit sooner.
///
@@ -31,8 +33,10 @@ cfg_io_util! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadExact<'a, A: ?Sized> {
reader: &'a mut A,
- buf: &'a mut [u8],
- pos: usize,
+ buf: ReadBuf<'a>,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -46,32 +50,20 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let mut me = self.project();
+
loop {
// if our buffer is empty, then we need to read some data to continue.
- if self.pos < self.buf.len() {
- let me = &mut *self;
- let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf[me.pos..]))?;
- me.pos += n;
- if n == 0 {
+ let rem = me.buf.remaining();
+ if rem != 0 {
+ ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?;
+ if me.buf.remaining() == rem {
return Err(eof()).into();
}
- }
-
- if self.pos >= self.buf.len() {
- return Poll::Ready(Ok(self.pos));
+ } else {
+ return Poll::Ready(Ok(me.buf.capacity()));
}
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadExact<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs
index 9d37dc7..5b9fb7b 100644
--- a/src/io/util/read_int.rs
+++ b/src/io/util/read_int.rs
@@ -1,10 +1,11 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
use bytes::Buf;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::io::ErrorKind::UnexpectedEof;
+use std::marker::PhantomPinned;
use std::mem::size_of;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -16,11 +17,15 @@ macro_rules! reader {
($name:ident, $ty:ty, $reader:ident, $bytes:expr) => {
pin_project! {
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<R> {
#[pin]
src: R,
buf: [u8; $bytes],
read: u8,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -30,6 +35,7 @@ macro_rules! reader {
src,
buf: [0; $bytes],
read: 0,
+ _pin: PhantomPinned,
}
}
}
@@ -48,17 +54,19 @@ macro_rules! reader {
}
while *me.read < $bytes as u8 {
- *me.read += match me
- .src
- .as_mut()
- .poll_read(cx, &mut me.buf[*me.read as usize..])
- {
+ let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]);
+
+ *me.read += match me.src.as_mut().poll_read(cx, &mut buf) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())),
- Poll::Ready(Ok(0)) => {
- return Poll::Ready(Err(UnexpectedEof.into()));
+ Poll::Ready(Ok(())) => {
+ let n = buf.filled().len();
+ if n == 0 {
+ return Poll::Ready(Err(UnexpectedEof.into()));
+ }
+
+ n as u8
}
- Poll::Ready(Ok(n)) => n as u8,
};
}
@@ -75,15 +83,22 @@ macro_rules! reader8 {
pin_project! {
/// Future returned from `read_u8`
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<R> {
#[pin]
reader: R,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
impl<R> $name<R> {
pub(crate) fn new(reader: R) -> $name<R> {
- $name { reader }
+ $name {
+ reader,
+ _pin: PhantomPinned,
+ }
}
}
@@ -97,12 +112,17 @@ macro_rules! reader8 {
let me = self.project();
let mut buf = [0; 1];
- match me.reader.poll_read(cx, &mut buf[..]) {
+ let mut buf = ReadBuf::new(&mut buf);
+ match me.reader.poll_read(cx, &mut buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
- Poll::Ready(Ok(0)) => Poll::Ready(Err(UnexpectedEof.into())),
- Poll::Ready(Ok(1)) => Poll::Ready(Ok(buf[0] as $ty)),
- Poll::Ready(Ok(_)) => unreachable!(),
+ Poll::Ready(Ok(())) => {
+ if buf.filled().len() == 0 {
+ return Poll::Ready(Err(UnexpectedEof.into()));
+ }
+
+ Poll::Ready(Ok(buf.filled()[0] as $ty))
+ }
}
}
}
diff --git a/src/io/util/read_line.rs b/src/io/util/read_line.rs
index d625a76..d38ffaf 100644
--- a/src/io/util/read_line.rs
+++ b/src/io/util/read_line.rs
@@ -1,26 +1,32 @@
use crate::io::util::read_until::read_until_internal;
use crate::io::AsyncBufRead;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
+use std::string::FromUtf8Error;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// Future for the [`read_line`](crate::io::AsyncBufReadExt::read_line) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadLine<'a, R: ?Sized> {
reader: &'a mut R,
- /// This is the buffer we were provided. It will be replaced with an empty string
- /// while reading to postpone utf-8 handling until after reading.
+ // This is the buffer we were provided. It will be replaced with an empty string
+ // while reading to postpone utf-8 handling until after reading.
output: &'a mut String,
- /// The actual allocation of the string is moved into a vector instead.
+ // The actual allocation of the string is moved into this vector instead.
buf: Vec<u8>,
- /// The number of bytes appended to buf. This can be less than buf.len() if
- /// the buffer was not empty when the operation was started.
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -33,6 +39,7 @@ where
buf: mem::replace(string, String::new()).into_bytes(),
output: string,
read: 0,
+ _pin: PhantomPinned,
}
}
@@ -42,31 +49,33 @@ fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_re
*output = String::from_utf8(vector).expect("The original data must be valid utf-8.");
}
-pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
- reader: Pin<&mut R>,
- cx: &mut Context<'_>,
+/// This handles the various failure cases and puts the string back into `output`.
+///
+/// The `truncate_on_io_error` bool is necessary because `read_to_string` and `read_line`
+/// disagree on what should happen when an IO error occurs.
+pub(super) fn finish_string_read(
+ io_res: io::Result<usize>,
+ utf8_res: Result<String, FromUtf8Error>,
+ read: usize,
output: &mut String,
- buf: &mut Vec<u8>,
- read: &mut usize,
+ truncate_on_io_error: bool,
) -> Poll<io::Result<usize>> {
- let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
- let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
-
- // At this point both buf and output are empty. The allocation is in utf8_res.
-
- debug_assert!(buf.is_empty());
match (io_res, utf8_res) {
(Ok(num_bytes), Ok(string)) => {
- debug_assert_eq!(*read, 0);
+ debug_assert_eq!(read, 0);
*output = string;
Poll::Ready(Ok(num_bytes))
}
(Err(io_err), Ok(string)) => {
*output = string;
+ if truncate_on_io_error {
+ let original_len = output.len() - read;
+ output.truncate(original_len);
+ }
Poll::Ready(Err(io_err))
}
(Ok(num_bytes), Err(utf8_err)) => {
- debug_assert_eq!(*read, 0);
+ debug_assert_eq!(read, 0);
put_back_original_data(output, utf8_err.into_bytes(), num_bytes);
Poll::Ready(Err(io::Error::new(
@@ -75,35 +84,36 @@ pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
)))
}
(Err(io_err), Err(utf8_err)) => {
- put_back_original_data(output, utf8_err.into_bytes(), *read);
+ put_back_original_data(output, utf8_err.into_bytes(), read);
Poll::Ready(Err(io_err))
}
}
}
-impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
- type Output = io::Result<usize>;
+pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
+ reader: Pin<&mut R>,
+ cx: &mut Context<'_>,
+ output: &mut String,
+ buf: &mut Vec<u8>,
+ read: &mut usize,
+) -> Poll<io::Result<usize>> {
+ let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
+ let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let Self {
- reader,
- output,
- buf,
- read,
- } = &mut *self;
+ // At this point both buf and output are empty. The allocation is in utf8_res.
- read_line_internal(Pin::new(reader), cx, output, buf, read)
- }
+ debug_assert!(buf.is_empty());
+ debug_assert!(output.is_empty());
+ finish_string_read(io_res, utf8_res, *read, output, false)
}
-#[cfg(test)]
-mod tests {
- use super::*;
+impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
+ type Output = io::Result<usize>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadLine<'_, PhantomPinned>>();
+ read_line_internal(Pin::new(*me.reader), cx, me.output, me.buf, me.read)
}
}
diff --git a/src/io/util/read_to_end.rs b/src/io/util/read_to_end.rs
index a2cd99b..a974625 100644
--- a/src/io/util/read_to_end.rs
+++ b/src/io/util/read_to_end.rs
@@ -1,92 +1,105 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
-use std::mem::MaybeUninit;
+use std::marker::PhantomPinned;
+use std::mem::{self, MaybeUninit};
use std::pin::Pin;
use std::task::{Context, Poll};
-#[derive(Debug)]
-#[must_use = "futures do nothing unless you `.await` or poll them"]
-#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
-pub struct ReadToEnd<'a, R: ?Sized> {
- reader: &'a mut R,
- buf: &'a mut Vec<u8>,
- start_len: usize,
+pin_project! {
+ #[derive(Debug)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
+ pub struct ReadToEnd<'a, R: ?Sized> {
+ reader: &'a mut R,
+ buf: &'a mut Vec<u8>,
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
+ read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
+ }
}
-pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, R>
+pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) -> ReadToEnd<'a, R>
where
R: AsyncRead + Unpin + ?Sized,
{
- let start_len = buf.len();
ReadToEnd {
reader,
- buf,
- start_len,
+ buf: buffer,
+ read: 0,
+ _pin: PhantomPinned,
}
}
-struct Guard<'a> {
- buf: &'a mut Vec<u8>,
- len: usize,
-}
-
-impl Drop for Guard<'_> {
- fn drop(&mut self) {
- unsafe {
- self.buf.set_len(self.len);
+pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>(
+ buf: &mut Vec<u8>,
+ mut reader: Pin<&mut R>,
+ num_read: &mut usize,
+ cx: &mut Context<'_>,
+) -> Poll<io::Result<usize>> {
+ loop {
+ // safety: The caller promised to prepare the buffer.
+ let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx));
+ match ret {
+ Err(err) => return Poll::Ready(Err(err)),
+ Ok(0) => return Poll::Ready(Ok(mem::replace(num_read, 0))),
+ Ok(num) => {
+ *num_read += num;
+ }
}
}
}
-// This uses an adaptive system to extend the vector when it fills. We want to
-// avoid paying to allocate and zero a huge chunk of memory if the reader only
-// has 4 bytes while still making large reads if the reader does have a ton
-// of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every
-// time is 4,500 times (!) slower than this if the reader has a very small
-// amount of data to return.
-//
-// Because we're extending the buffer with uninitialized data for trusted
-// readers, we need to make sure to truncate that if any of this panics.
-pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>(
- mut rd: Pin<&mut R>,
- cx: &mut Context<'_>,
+/// Tries to read from the provided AsyncRead.
+///
+/// The length of the buffer is increased by the number of bytes read.
+fn poll_read_to_end<R: AsyncRead + ?Sized>(
buf: &mut Vec<u8>,
- start_len: usize,
+ read: Pin<&mut R>,
+ cx: &mut Context<'_>,
) -> Poll<io::Result<usize>> {
- let mut g = Guard {
- len: buf.len(),
- buf,
- };
- let ret;
- loop {
- if g.len == g.buf.len() {
- unsafe {
- g.buf.reserve(32);
- let capacity = g.buf.capacity();
- g.buf.set_len(capacity);
+ // This uses an adaptive system to extend the vector when it fills. We want to
+ // avoid paying to allocate and zero a huge chunk of memory if the reader only
+ // has 4 bytes while still making large reads if the reader does have a ton
+ // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every
+ // time is 4,500 times (!) slower than this if the reader has a very small
+ // amount of data to return.
+ reserve(buf, 32);
- let b = &mut *(&mut g.buf[g.len..] as *mut [u8] as *mut [MaybeUninit<u8>]);
+ let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf));
- rd.prepare_uninitialized_buffer(b);
- }
- }
+ ready!(read.poll_read(cx, &mut unused_capacity))?;
- match ready!(rd.as_mut().poll_read(cx, &mut g.buf[g.len..])) {
- Ok(0) => {
- ret = Poll::Ready(Ok(g.len - start_len));
- break;
- }
- Ok(n) => g.len += n,
- Err(e) => {
- ret = Poll::Ready(Err(e));
- break;
- }
- }
+ let n = unused_capacity.filled().len();
+ let new_len = buf.len() + n;
+
+ // This should no longer even be possible in safe Rust. An implementor
+ // would need to have unsafely *replaced* the buffer inside `ReadBuf`,
+ // which... yolo?
+ assert!(new_len <= buf.capacity());
+ unsafe {
+ buf.set_len(new_len);
}
+ Poll::Ready(Ok(n))
+}
- ret
+/// Allocates more memory and ensures that the unused capacity is prepared for use
+/// with the `AsyncRead`.
+fn reserve(buf: &mut Vec<u8>, bytes: usize) {
+ if buf.capacity() - buf.len() >= bytes {
+ return;
+ }
+ buf.reserve(bytes);
+}
+
+/// Returns the unused capacity of the provided vector.
+fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] {
+ let uninit = bytes::BufMut::bytes_mut(buf);
+ unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) }
}
impl<A> Future for ReadToEnd<'_, A>
@@ -95,19 +108,9 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let this = &mut *self;
- read_to_end_internal(Pin::new(&mut this.reader), cx, this.buf, this.start_len)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadToEnd<'_, PhantomPinned>>();
+ read_to_end_internal(me.buf, Pin::new(*me.reader), me.read, cx)
}
}
diff --git a/src/io/util/read_to_string.rs b/src/io/util/read_to_string.rs
index cab0505..e463203 100644
--- a/src/io/util/read_to_string.rs
+++ b/src/io/util/read_to_string.rs
@@ -1,58 +1,71 @@
+use crate::io::util::read_line::finish_string_read;
use crate::io::util::read_to_end::read_to_end_internal;
use crate::io::AsyncRead;
+use pin_project_lite::pin_project;
use std::future::Future;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};
-cfg_io_util! {
+pin_project! {
/// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToString<'a, R: ?Sized> {
reader: &'a mut R,
- buf: &'a mut String,
- bytes: Vec<u8>,
- start_len: usize,
+ // This is the buffer we were provided. It will be replaced with an empty string
+ // while reading to postpone utf-8 handling until after reading.
+ output: &'a mut String,
+ // The actual allocation of the string is moved into this vector instead.
+ buf: Vec<u8>,
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
+ read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
-pub(crate) fn read_to_string<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadToString<'a, R>
+pub(crate) fn read_to_string<'a, R>(
+ reader: &'a mut R,
+ string: &'a mut String,
+) -> ReadToString<'a, R>
where
R: AsyncRead + ?Sized + Unpin,
{
- let start_len = buf.len();
+ let buf = mem::replace(string, String::new()).into_bytes();
ReadToString {
reader,
- bytes: mem::replace(buf, String::new()).into_bytes(),
buf,
- start_len,
+ output: string,
+ read: 0,
+ _pin: PhantomPinned,
}
}
-fn read_to_string_internal<R: AsyncRead + ?Sized>(
+/// # Safety
+///
+/// Before first calling this method, the unused capacity must have been
+/// prepared for use with the provided AsyncRead. This can be done using the
+/// `prepare_buffer` function in `read_to_end.rs`.
+unsafe fn read_to_string_internal<R: AsyncRead + ?Sized>(
reader: Pin<&mut R>,
+ output: &mut String,
+ buf: &mut Vec<u8>,
+ read: &mut usize,
cx: &mut Context<'_>,
- buf: &mut String,
- bytes: &mut Vec<u8>,
- start_len: usize,
) -> Poll<io::Result<usize>> {
- let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?;
- match String::from_utf8(mem::replace(bytes, Vec::new())) {
- Ok(string) => {
- debug_assert!(buf.is_empty());
- *buf = string;
- Poll::Ready(Ok(ret))
- }
- Err(e) => {
- *bytes = e.into_bytes();
- Poll::Ready(Err(io::Error::new(
- io::ErrorKind::InvalidData,
- "stream did not contain valid UTF-8",
- )))
- }
- }
+ let io_res = ready!(read_to_end_internal(buf, reader, read, cx));
+ let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
+
+ // At this point both buf and output are empty. The allocation is in utf8_res.
+
+ debug_assert!(buf.is_empty());
+ debug_assert!(output.is_empty());
+ finish_string_read(io_res, utf8_res, *read, output, true)
}
impl<A> Future for ReadToString<'_, A>
@@ -61,31 +74,10 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let Self {
- reader,
- buf,
- bytes,
- start_len,
- } = &mut *self;
- let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len);
- if let Poll::Ready(Err(_)) = ret {
- // Put back the original string.
- bytes.truncate(*start_len);
- **buf = String::from_utf8(mem::replace(bytes, Vec::new()))
- .expect("original string no longer utf-8");
- }
- ret
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadToString<'_, PhantomPinned>>();
+ // safety: The constructor of ReadToString called `prepare_buffer`.
+ unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) }
}
}
diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs
index 78dac8c..3599cff 100644
--- a/src/io/util/read_until.rs
+++ b/src/io/util/read_until.rs
@@ -1,12 +1,14 @@
use crate::io::AsyncBufRead;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method.
/// The delimeter is included in the resulting vector.
#[derive(Debug)]
@@ -15,9 +17,12 @@ cfg_io_util! {
reader: &'a mut R,
delimeter: u8,
buf: &'a mut Vec<u8>,
- /// The number of bytes appended to buf. This can be less than buf.len() if
- /// the buffer was not empty when the operation was started.
+ // The number of bytes appended to buf. This can be less than buf.len() if
+ // the buffer was not empty when the operation was started.
read: usize,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -34,6 +39,7 @@ where
delimeter,
buf,
read: 0,
+ _pin: PhantomPinned,
}
}
@@ -66,24 +72,8 @@ pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>(
impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> {
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let Self {
- reader,
- delimeter,
- buf,
- read,
- } = &mut *self;
- read_until_internal(Pin::new(reader), cx, *delimeter, buf, read)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<ReadUntil<'_, PhantomPinned>>();
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read)
}
}
diff --git a/src/io/util/repeat.rs b/src/io/util/repeat.rs
index eeef7cc..1142765 100644
--- a/src/io/util/repeat.rs
+++ b/src/io/util/repeat.rs
@@ -1,4 +1,4 @@
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
use std::io;
use std::pin::Pin;
@@ -47,19 +47,17 @@ cfg_io_util! {
}
impl AsyncRead for Repeat {
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
- false
- }
#[inline]
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- for byte in &mut *buf {
- *byte = self.byte;
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ // TODO: could be faster, but should we unsafe it?
+ while buf.remaining() != 0 {
+ buf.put_slice(&[self.byte]);
}
- Poll::Ready(Ok(buf.len()))
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/util/shutdown.rs b/src/io/util/shutdown.rs
index 33ac0ac..6d30b00 100644
--- a/src/io/util/shutdown.rs
+++ b/src/io/util/shutdown.rs
@@ -1,18 +1,24 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future used to shutdown an I/O object.
///
/// Created by the [`AsyncWriteExt::shutdown`][shutdown] function.
/// [shutdown]: crate::io::AsyncWriteExt::shutdown
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Shutdown<'a, A: ?Sized> {
a: &'a mut A,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -21,7 +27,10 @@ pub(super) fn shutdown<A>(a: &mut A) -> Shutdown<'_, A>
where
A: AsyncWrite + Unpin + ?Sized,
{
- Shutdown { a }
+ Shutdown {
+ a,
+ _pin: PhantomPinned,
+ }
}
impl<A> Future for Shutdown<'_, A>
@@ -30,19 +39,8 @@ where
{
type Output = io::Result<()>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- let me = &mut *self;
- Pin::new(&mut *me.a).poll_shutdown(cx)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<Shutdown<'_, PhantomPinned>>();
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let me = self.project();
+ Pin::new(me.a).poll_shutdown(cx)
}
}
diff --git a/src/io/util/split.rs b/src/io/util/split.rs
index f552ed5..492e26a 100644
--- a/src/io/util/split.rs
+++ b/src/io/util/split.rs
@@ -65,8 +65,7 @@ impl<R> Split<R>
where
R: AsyncBufRead,
{
- #[doc(hidden)]
- pub fn poll_next_segment(
+ fn poll_next_segment(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<Vec<u8>>>> {
diff --git a/src/io/util/stream_reader.rs b/src/io/util/stream_reader.rs
deleted file mode 100644
index b98f8bd..0000000
--- a/src/io/util/stream_reader.rs
+++ /dev/null
@@ -1,184 +0,0 @@
-use crate::io::{AsyncBufRead, AsyncRead};
-use crate::stream::Stream;
-use bytes::{Buf, BufMut};
-use pin_project_lite::pin_project;
-use std::io;
-use std::mem::MaybeUninit;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-pin_project! {
- /// Convert a stream of byte chunks into an [`AsyncRead`].
- ///
- /// This type is usually created using the [`stream_reader`] function.
- ///
- /// [`AsyncRead`]: crate::io::AsyncRead
- /// [`stream_reader`]: crate::io::stream_reader
- #[derive(Debug)]
- #[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
- #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
- pub struct StreamReader<S, B> {
- #[pin]
- inner: S,
- chunk: Option<B>,
- }
-}
-
-/// Convert a stream of byte chunks into an [`AsyncRead`](crate::io::AsyncRead).
-///
-/// # Example
-///
-/// ```
-/// use bytes::Bytes;
-/// use tokio::io::{stream_reader, AsyncReadExt};
-/// # #[tokio::main]
-/// # async fn main() -> std::io::Result<()> {
-///
-/// // Create a stream from an iterator.
-/// let stream = tokio::stream::iter(vec![
-/// Ok(Bytes::from_static(&[0, 1, 2, 3])),
-/// Ok(Bytes::from_static(&[4, 5, 6, 7])),
-/// Ok(Bytes::from_static(&[8, 9, 10, 11])),
-/// ]);
-///
-/// // Convert it to an AsyncRead.
-/// let mut read = stream_reader(stream);
-///
-/// // Read five bytes from the stream.
-/// let mut buf = [0; 5];
-/// read.read_exact(&mut buf).await?;
-/// assert_eq!(buf, [0, 1, 2, 3, 4]);
-///
-/// // Read the rest of the current chunk.
-/// assert_eq!(read.read(&mut buf).await?, 3);
-/// assert_eq!(&buf[..3], [5, 6, 7]);
-///
-/// // Read the next chunk.
-/// assert_eq!(read.read(&mut buf).await?, 4);
-/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
-///
-/// // We have now reached the end.
-/// assert_eq!(read.read(&mut buf).await?, 0);
-///
-/// # Ok(())
-/// # }
-/// ```
-#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
-#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
-pub fn stream_reader<S, B>(stream: S) -> StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- StreamReader::new(stream)
-}
-
-impl<S, B> StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- /// Convert the provided stream into an `AsyncRead`.
- fn new(stream: S) -> Self {
- Self {
- inner: stream,
- chunk: None,
- }
- }
- /// Do we have a chunk and is it non-empty?
- fn has_chunk(self: Pin<&mut Self>) -> bool {
- if let Some(chunk) = self.project().chunk {
- chunk.remaining() > 0
- } else {
- false
- }
- }
-}
-
-impl<S, B> AsyncRead for StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- fn poll_read(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<io::Result<usize>> {
- if buf.is_empty() {
- return Poll::Ready(Ok(0));
- }
-
- let inner_buf = match self.as_mut().poll_fill_buf(cx) {
- Poll::Ready(Ok(buf)) => buf,
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => return Poll::Pending,
- };
- let len = std::cmp::min(inner_buf.len(), buf.len());
- (&mut buf[..len]).copy_from_slice(&inner_buf[..len]);
-
- self.consume(len);
- Poll::Ready(Ok(len))
- }
- fn poll_read_buf<BM: BufMut>(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut BM,
- ) -> Poll<io::Result<usize>>
- where
- Self: Sized,
- {
- if !buf.has_remaining_mut() {
- return Poll::Ready(Ok(0));
- }
-
- let inner_buf = match self.as_mut().poll_fill_buf(cx) {
- Poll::Ready(Ok(buf)) => buf,
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => return Poll::Pending,
- };
- let len = std::cmp::min(inner_buf.len(), buf.remaining_mut());
- buf.put_slice(&inner_buf[..len]);
-
- self.consume(len);
- Poll::Ready(Ok(len))
- }
- unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool {
- false
- }
-}
-
-impl<S, B> AsyncBufRead for StreamReader<S, B>
-where
- S: Stream<Item = Result<B, io::Error>>,
- B: Buf,
-{
- fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
- loop {
- if self.as_mut().has_chunk() {
- // This unwrap is very sad, but it can't be avoided.
- let buf = self.project().chunk.as_ref().unwrap().bytes();
- return Poll::Ready(Ok(buf));
- } else {
- match self.as_mut().project().inner.poll_next(cx) {
- Poll::Ready(Some(Ok(chunk))) => {
- // Go around the loop in case the chunk is empty.
- *self.as_mut().project().chunk = Some(chunk);
- }
- Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)),
- Poll::Ready(None) => return Poll::Ready(Ok(&[])),
- Poll::Pending => return Poll::Pending,
- }
- }
- }
- }
- fn consume(self: Pin<&mut Self>, amt: usize) {
- if amt > 0 {
- self.project()
- .chunk
- .as_mut()
- .expect("No chunk present")
- .advance(amt);
- }
- }
-}
diff --git a/src/io/util/take.rs b/src/io/util/take.rs
index 5d6bd90..b5e90c9 100644
--- a/src/io/util/take.rs
+++ b/src/io/util/take.rs
@@ -1,7 +1,6 @@
-use crate::io::{AsyncBufRead, AsyncRead};
+use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use pin_project_lite::pin_project;
-use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{cmp, io};
@@ -76,24 +75,27 @@ impl<R: AsyncRead> Take<R> {
}
impl<R: AsyncRead> AsyncRead for Take<R> {
- unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
- self.inner.prepare_uninitialized_buffer(buf)
- }
-
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
- buf: &mut [u8],
- ) -> Poll<Result<usize, io::Error>> {
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<Result<(), io::Error>> {
if self.limit_ == 0 {
- return Poll::Ready(Ok(0));
+ return Poll::Ready(Ok(()));
}
let me = self.project();
- let max = std::cmp::min(buf.len() as u64, *me.limit_) as usize;
- let n = ready!(me.inner.poll_read(cx, &mut buf[..max]))?;
+ let mut b = buf.take(*me.limit_ as usize);
+ ready!(me.inner.poll_read(cx, &mut b))?;
+ let n = b.filled().len();
+
+ // We need to update the original ReadBuf
+ unsafe {
+ buf.assume_init(n);
+ }
+ buf.advance(n);
*me.limit_ -= n as u64;
- Poll::Ready(Ok(n))
+ Poll::Ready(Ok(()))
}
}
diff --git a/src/io/util/write.rs b/src/io/util/write.rs
index 433a421..92169eb 100644
--- a/src/io/util/write.rs
+++ b/src/io/util/write.rs
@@ -1,17 +1,22 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future to write some of the buffer to an `AsyncWrite`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Write<'a, W: ?Sized> {
writer: &'a mut W,
buf: &'a [u8],
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -21,7 +26,11 @@ pub(crate) fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> Write<'a, W>
where
W: AsyncWrite + Unpin + ?Sized,
{
- Write { writer, buf }
+ Write {
+ writer,
+ buf,
+ _pin: PhantomPinned,
+ }
}
impl<W> Future for Write<'_, W>
@@ -30,8 +39,8 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let me = self.project();
Pin::new(&mut *me.writer).poll_write(cx, me.buf)
}
}
diff --git a/src/io/util/write_all.rs b/src/io/util/write_all.rs
index 898006c..e59d41e 100644
--- a/src/io/util/write_all.rs
+++ b/src/io/util/write_all.rs
@@ -1,17 +1,22 @@
use crate::io::AsyncWrite;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteAll<'a, W: ?Sized> {
writer: &'a mut W,
buf: &'a [u8],
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -19,7 +24,11 @@ pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a,
where
W: AsyncWrite + Unpin + ?Sized,
{
- WriteAll { writer, buf }
+ WriteAll {
+ writer,
+ buf,
+ _pin: PhantomPinned,
+ }
}
impl<W> Future for WriteAll<'_, W>
@@ -28,13 +37,13 @@ where
{
type Output = io::Result<()>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
- let me = &mut *self;
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ let me = self.project();
while !me.buf.is_empty() {
- let n = ready!(Pin::new(&mut me.writer).poll_write(cx, me.buf))?;
+ let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?;
{
- let (_, rest) = mem::replace(&mut me.buf, &[]).split_at(n);
- me.buf = rest;
+ let (_, rest) = mem::replace(&mut *me.buf, &[]).split_at(n);
+ *me.buf = rest;
}
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
@@ -44,14 +53,3 @@ where
Poll::Ready(Ok(()))
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn assert_unpin() {
- use std::marker::PhantomPinned;
- crate::is_unpin::<WriteAll<'_, PhantomPinned>>();
- }
-}
diff --git a/src/io/util/write_buf.rs b/src/io/util/write_buf.rs
index cedfde6..1310e5c 100644
--- a/src/io/util/write_buf.rs
+++ b/src/io/util/write_buf.rs
@@ -1,18 +1,22 @@
use crate::io::AsyncWrite;
use bytes::Buf;
+use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
-cfg_io_util! {
+pin_project! {
/// A future to write some of the buffer to an `AsyncWrite`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteBuf<'a, W, B> {
writer: &'a mut W,
buf: &'a mut B,
+ #[pin]
+ _pin: PhantomPinned,
}
}
@@ -23,7 +27,11 @@ where
W: AsyncWrite + Unpin,
B: Buf,
{
- WriteBuf { writer, buf }
+ WriteBuf {
+ writer,
+ buf,
+ _pin: PhantomPinned,
+ }
}
impl<W, B> Future for WriteBuf<'_, W, B>
@@ -33,8 +41,15 @@ where
{
type Output = io::Result<usize>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
- let me = &mut *self;
- Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf)
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
+ let me = self.project();
+
+ if !me.buf.has_remaining() {
+ return Poll::Ready(Ok(0));
+ }
+
+ let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?;
+ me.buf.advance(n);
+ Poll::Ready(Ok(n))
}
}
diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs
index ee992de..13bc191 100644
--- a/src/io/util/write_int.rs
+++ b/src/io/util/write_int.rs
@@ -4,6 +4,7 @@ use bytes::BufMut;
use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
+use std::marker::PhantomPinned;
use std::mem::size_of;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -15,20 +16,25 @@ macro_rules! writer {
($name:ident, $ty:ty, $writer:ident, $bytes:expr) => {
pin_project! {
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<W> {
#[pin]
dst: W,
buf: [u8; $bytes],
written: u8,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
impl<W> $name<W> {
pub(crate) fn new(w: W, value: $ty) -> Self {
- let mut writer = $name {
+ let mut writer = Self {
buf: [0; $bytes],
written: 0,
dst: w,
+ _pin: PhantomPinned,
};
BufMut::$writer(&mut &mut writer.buf[..], value);
writer
@@ -72,16 +78,24 @@ macro_rules! writer8 {
($name:ident, $ty:ty) => {
pin_project! {
#[doc(hidden)]
+ #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct $name<W> {
#[pin]
dst: W,
byte: $ty,
+ // Make this future `!Unpin` for compatibility with async trait methods.
+ #[pin]
+ _pin: PhantomPinned,
}
}
impl<W> $name<W> {
pub(crate) fn new(dst: W, byte: $ty) -> Self {
- Self { dst, byte }
+ Self {
+ dst,
+ byte,
+ _pin: PhantomPinned,
+ }
}
}