diff options
Diffstat (limited to 'src/io')
49 files changed, 2374 insertions, 1728 deletions
diff --git a/src/io/async_read.rs b/src/io/async_read.rs index 1aef415..d075443 100644 --- a/src/io/async_read.rs +++ b/src/io/async_read.rs @@ -1,6 +1,5 @@ -use bytes::BufMut; +use super::ReadBuf; use std::io; -use std::mem::MaybeUninit; use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -16,9 +15,10 @@ use std::task::{Context, Poll}; /// Specifically, this means that the `poll_read` function will return one of /// the following: /// -/// * `Poll::Ready(Ok(n))` means that `n` bytes of data was immediately read -/// and placed into the output buffer, where `n` == 0 implies that EOF has -/// been reached. +/// * `Poll::Ready(Ok(()))` means that data was immediately read and placed into +/// the output buffer. The amount of data read can be determined by the +/// increase in the length of the slice returned by `ReadBuf::filled`. If the +/// difference is 0, EOF has been reached. /// /// * `Poll::Pending` means that no data was read into the buffer /// provided. The I/O object is not currently readable but may become readable @@ -41,110 +41,29 @@ use std::task::{Context, Poll}; /// [`Read::read`]: std::io::Read::read /// [`AsyncReadExt`]: crate::io::AsyncReadExt pub trait AsyncRead { - /// Prepares an uninitialized buffer to be safe to pass to `read`. Returns - /// `true` if the supplied buffer was zeroed out. - /// - /// While it would be highly unusual, implementations of [`io::Read`] are - /// able to read data from the buffer passed as an argument. Because of - /// this, the buffer passed to [`io::Read`] must be initialized memory. In - /// situations where large numbers of buffers are used, constantly having to - /// zero out buffers can be expensive. - /// - /// This function does any necessary work to prepare an uninitialized buffer - /// to be safe to pass to `read`. If `read` guarantees to never attempt to - /// read data out of the supplied buffer, then `prepare_uninitialized_buffer` - /// doesn't need to do any work. - /// - /// If this function returns `true`, then the memory has been zeroed out. - /// This allows implementations of `AsyncRead` which are composed of - /// multiple subimplementations to efficiently implement - /// `prepare_uninitialized_buffer`. - /// - /// This function isn't actually `unsafe` to call but `unsafe` to implement. - /// The implementer must ensure that either the whole `buf` has been zeroed - /// or `poll_read_buf()` overwrites the buffer without reading it and returns - /// correct value. - /// - /// This function is called from [`poll_read_buf`]. - /// - /// # Safety - /// - /// Implementations that return `false` must never read from data slices - /// that they did not write to. - /// - /// [`io::Read`]: std::io::Read - /// [`poll_read_buf`]: method@Self::poll_read_buf - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - for x in buf { - *x = MaybeUninit::new(0); - } - - true - } - /// Attempts to read from the `AsyncRead` into `buf`. /// - /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. + /// On success, returns `Poll::Ready(Ok(()))` and fills `buf` with data + /// read. If no data was read (`buf.filled().is_empty()`) it implies that + /// EOF has been reached. /// - /// If no data is available for reading, the method returns - /// `Poll::Pending` and arranges for the current task (via - /// `cx.waker()`) to receive a notification when the object becomes - /// readable or is closed. + /// If no data is available for reading, the method returns `Poll::Pending` + /// and arranges for the current task (via `cx.waker()`) to receive a + /// notification when the object becomes readable or is closed. fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>>; - - /// Pulls some bytes from this source into the specified `BufMut`, returning - /// how many bytes were read. - /// - /// The `buf` provided will have bytes read into it and the internal cursor - /// will be advanced if any bytes were read. Note that this method typically - /// will not reallocate the buffer provided. - fn poll_read_buf<B: BufMut>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> - where - Self: Sized, - { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - unsafe { - let n = { - let b = buf.bytes_mut(); - - self.prepare_uninitialized_buffer(b); - - // Convert to `&mut [u8]` - let b = &mut *(b as *mut [MaybeUninit<u8>] as *mut [u8]); - - let n = ready!(self.poll_read(cx, b))?; - assert!(n <= b.len(), "Bad AsyncRead implementation, more bytes were reported as read than the buffer can hold"); - n - }; - - buf.advance_mut(n); - Poll::Ready(Ok(n)) - } - } + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>>; } macro_rules! deref_async_read { () => { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - (**self).prepare_uninitialized_buffer(buf) - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { Pin::new(&mut **self).poll_read(cx, buf) } }; @@ -163,43 +82,50 @@ where P: DerefMut + Unpin, P::Target: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - (**self).prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.get_mut().as_mut().poll_read(cx, buf) } } impl AsyncRead for &[u8] { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(io::Read::read(self.get_mut(), buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let amt = std::cmp::min(self.len(), buf.remaining()); + let (a, b) = self.split_at(amt); + buf.put_slice(a); + *self = b; + Poll::Ready(Ok(())) } } impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T> { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } - fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(io::Read::read(self.get_mut(), buf)) + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let pos = self.position(); + let slice: &[u8] = (*self).get_ref().as_ref(); + + // The position could technically be out of bounds, so don't panic... + if pos > slice.len() as u64 { + return Poll::Ready(Ok(())); + } + + let start = pos as usize; + let amt = std::cmp::min(slice.len() - start, buf.remaining()); + // Add won't overflow because of pos check above. + let end = start + amt; + buf.put_slice(&slice[start..end]); + self.set_position(end as u64); + + Poll::Ready(Ok(())) } } diff --git a/src/io/async_seek.rs b/src/io/async_seek.rs index 32ed0a2..bd7a992 100644 --- a/src/io/async_seek.rs +++ b/src/io/async_seek.rs @@ -23,36 +23,33 @@ pub trait AsyncSeek { /// /// If this function returns successfully, then the job has been submitted. /// To find out when it completes, call `poll_complete`. - fn start_seek( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - position: SeekFrom, - ) -> Poll<io::Result<()>>; + /// + /// # Errors + /// + /// This function can return [`io::ErrorKind::Other`] in case there is + /// another seek in progress. To avoid this, it is advisable that any call + /// to `start_seek` is preceded by a call to `poll_complete` to ensure all + /// pending seeks have completed. + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()>; /// Waits for a seek operation to complete. /// /// If the seek operation completed successfully, /// this method returns the new position from the start of the stream. - /// That position can be used later with [`SeekFrom::Start`]. + /// That position can be used later with [`SeekFrom::Start`]. Repeatedly + /// calling this function without calling `start_seek` might return the + /// same result. /// /// # Errors /// /// Seeking to a negative offset is considered an error. - /// - /// # Panics - /// - /// Calling this method without calling `start_seek` first is an error. fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>; } macro_rules! deref_async_seek { () => { - fn start_seek( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - pos: SeekFrom, - ) -> Poll<io::Result<()>> { - Pin::new(&mut **self).start_seek(cx, pos) + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + Pin::new(&mut **self).start_seek(pos) } fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { @@ -74,12 +71,8 @@ where P: DerefMut + Unpin, P::Target: AsyncSeek, { - fn start_seek( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - pos: SeekFrom, - ) -> Poll<io::Result<()>> { - self.get_mut().as_mut().start_seek(cx, pos) + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.get_mut().as_mut().start_seek(pos) } fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { @@ -88,12 +81,8 @@ where } impl<T: AsRef<[u8]> + Unpin> AsyncSeek for io::Cursor<T> { - fn start_seek( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - pos: SeekFrom, - ) -> Poll<io::Result<()>> { - Poll::Ready(io::Seek::seek(&mut *self, pos).map(drop)) + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + io::Seek::seek(&mut *self, pos).map(drop) } fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> { Poll::Ready(Ok(self.get_mut().position())) diff --git a/src/io/async_write.rs b/src/io/async_write.rs index ecf7575..66ba4bf 100644 --- a/src/io/async_write.rs +++ b/src/io/async_write.rs @@ -1,4 +1,3 @@ -use bytes::Buf; use std::io; use std::ops::DerefMut; use std::pin::Pin; @@ -128,27 +127,6 @@ pub trait AsyncWrite { /// This function will panic if not called within the context of a future's /// task. fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>; - - /// Writes a `Buf` into this value, returning how many bytes were written. - /// - /// Note that this method will advance the `buf` provided automatically by - /// the number of bytes written. - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<Result<usize, io::Error>> - where - Self: Sized, - { - if !buf.has_remaining() { - return Poll::Ready(Ok(0)); - } - - let n = ready!(self.poll_write(cx, buf.bytes()))?; - buf.advance(n); - Poll::Ready(Ok(n)) - } } macro_rules! deref_async_write { diff --git a/src/io/blocking.rs b/src/io/blocking.rs index 2491039..430801e 100644 --- a/src/io/blocking.rs +++ b/src/io/blocking.rs @@ -1,5 +1,5 @@ use crate::io::sys; -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::cmp; use std::future::Future; @@ -53,17 +53,17 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - dst: &mut [u8], - ) -> Poll<io::Result<usize>> { + dst: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { loop { match self.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); if !buf.is_empty() { - let n = buf.copy_to(dst); + buf.copy_to(dst); *buf_cell = Some(buf); - return Ready(Ok(n)); + return Ready(Ok(())); } buf.ensure_capacity_for(dst); @@ -80,9 +80,9 @@ where match res { Ok(_) => { - let n = buf.copy_to(dst); + buf.copy_to(dst); self.state = Idle(Some(buf)); - return Ready(Ok(n)); + return Ready(Ok(())); } Err(e) => { assert!(buf.is_empty()); @@ -203,9 +203,9 @@ impl Buf { self.buf.len() - self.pos } - pub(crate) fn copy_to(&mut self, dst: &mut [u8]) -> usize { - let n = cmp::min(self.len(), dst.len()); - dst[..n].copy_from_slice(&self.bytes()[..n]); + pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize { + let n = cmp::min(self.len(), dst.remaining()); + dst.put_slice(&self.bytes()[..n]); self.pos += n; if self.pos == self.buf.len() { @@ -229,10 +229,10 @@ impl Buf { &self.buf[self.pos..] } - pub(crate) fn ensure_capacity_for(&mut self, bytes: &[u8]) { + pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) { assert!(self.is_empty()); - let len = cmp::min(bytes.len(), MAX_BUF); + let len = cmp::min(bytes.remaining(), MAX_BUF); if self.buf.len() < len { self.buf.reserve(len - self.buf.len()); diff --git a/src/io/driver/mod.rs b/src/io/driver/mod.rs index dbfb6e1..cd82b26 100644 --- a/src/io/driver/mod.rs +++ b/src/io/driver/mod.rs @@ -1,30 +1,38 @@ -pub(crate) mod platform; +#![cfg_attr(not(feature = "rt"), allow(dead_code))] + +mod ready; +use ready::Ready; mod scheduled_io; pub(crate) use scheduled_io::ScheduledIo; // pub(crate) for tests -use crate::loom::sync::atomic::AtomicUsize; use crate::park::{Park, Unpark}; -use crate::runtime::context; -use crate::util::slab::{Address, Slab}; +use crate::util::bit; +use crate::util::slab::{self, Slab}; -use mio::event::Evented; use std::fmt; use std::io; -use std::sync::atomic::Ordering::SeqCst; use std::sync::{Arc, Weak}; -use std::task::Waker; use std::time::Duration; /// I/O driver, backed by Mio pub(crate) struct Driver { + /// Tracks the number of times `turn` is called. It is safe for this to wrap + /// as it is mostly used to determine when to call `compact()` + tick: u8, + /// Reuse the `mio::Events` value across calls to poll. - events: mio::Events, + events: Option<mio::Events>, + + /// Primary slab handle containing the state for each resource registered + /// with this driver. + resources: Slab<ScheduledIo>, + + /// The system event queue + poll: mio::Poll, /// State shared between the reactor and the handles. inner: Arc<Inner>, - - _wakeup_registration: mio::Registration, } /// A reference to an I/O driver @@ -33,18 +41,20 @@ pub(crate) struct Handle { inner: Weak<Inner>, } -pub(super) struct Inner { - /// The underlying system event queue. - io: mio::Poll, +pub(crate) struct ReadyEvent { + tick: u8, + ready: Ready, +} - /// Dispatch slabs for I/O and futures events - pub(super) io_dispatch: Slab<ScheduledIo>, +pub(super) struct Inner { + /// Registers I/O resources + registry: mio::Registry, - /// The number of sources in `io_dispatch`. - n_sources: AtomicUsize, + /// Allocates `ScheduledIo` handles when creating new resources. + pub(super) io_dispatch: slab::Allocator<ScheduledIo>, /// Used to wake up the reactor from a call to `turn` - wakeup: mio::SetReadiness, + waker: mio::Waker, } #[derive(Debug, Eq, PartialEq, Clone, Copy)] @@ -53,7 +63,24 @@ pub(super) enum Direction { Write, } -const TOKEN_WAKEUP: mio::Token = mio::Token(Address::NULL); +enum Tick { + Set(u8), + Clear(u8), +} + +// TODO: Don't use a fake token. Instead, reserve a slot entry for the wakeup +// token. +const TOKEN_WAKEUP: mio::Token = mio::Token(1 << 31); + +const ADDRESS: bit::Pack = bit::Pack::least_significant(24); + +// Packs the generation value in the `readiness` field. +// +// The generation prevents a race condition where a slab slot is reused for a +// new socket while the I/O driver is about to apply a readiness event. The +// generaton value is checked when setting new readiness. If the generation do +// not match, then the readiness event is discarded. +const GENERATION: bit::Pack = ADDRESS.then(7); fn _assert_kinds() { fn _assert<T: Send + Sync>() {} @@ -67,24 +94,22 @@ impl Driver { /// Creates a new event loop, returning any error that happened during the /// creation. pub(crate) fn new() -> io::Result<Driver> { - let io = mio::Poll::new()?; - let wakeup_pair = mio::Registration::new2(); + let poll = mio::Poll::new()?; + let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?; + let registry = poll.registry().try_clone()?; - io.register( - &wakeup_pair.0, - TOKEN_WAKEUP, - mio::Ready::readable(), - mio::PollOpt::level(), - )?; + let slab = Slab::new(); + let allocator = slab.allocator(); Ok(Driver { - events: mio::Events::with_capacity(1024), - _wakeup_registration: wakeup_pair.0, + tick: 0, + events: Some(mio::Events::with_capacity(1024)), + resources: slab, + poll, inner: Arc::new(Inner { - io, - io_dispatch: Slab::new(), - n_sources: AtomicUsize::new(0), - wakeup: wakeup_pair.1, + registry, + io_dispatch: allocator, + waker, }), }) } @@ -102,65 +127,66 @@ impl Driver { } fn turn(&mut self, max_wait: Option<Duration>) -> io::Result<()> { + // How often to call `compact()` on the resource slab + const COMPACT_INTERVAL: u8 = 255; + + self.tick = self.tick.wrapping_add(1); + + if self.tick == COMPACT_INTERVAL { + self.resources.compact(); + } + + let mut events = self.events.take().expect("i/o driver event store missing"); + // Block waiting for an event to happen, peeling out how many events // happened. - match self.inner.io.poll(&mut self.events, max_wait) { + match self.poll.poll(&mut events, max_wait) { Ok(_) => {} + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} Err(e) => return Err(e), } // Process all the events that came in, dispatching appropriately - - for event in self.events.iter() { + for event in events.iter() { let token = event.token(); - if token == TOKEN_WAKEUP { - self.inner - .wakeup - .set_readiness(mio::Ready::empty()) - .unwrap(); - } else { - self.dispatch(token, event.readiness()); + if token != TOKEN_WAKEUP { + self.dispatch(token, Ready::from_mio(event)); } } + self.events = Some(events); + Ok(()) } - fn dispatch(&self, token: mio::Token, ready: mio::Ready) { - let mut rd = None; - let mut wr = None; + fn dispatch(&mut self, token: mio::Token, ready: Ready) { + let addr = slab::Address::from_usize(ADDRESS.unpack(token.0)); - let address = Address::from_usize(token.0); - - let io = match self.inner.io_dispatch.get(address) { + let io = match self.resources.get(addr) { Some(io) => io, None => return, }; - if io - .set_readiness(address, |curr| curr | ready.as_usize()) - .is_err() - { + let res = io.set_readiness(Some(token.0), Tick::Set(self.tick), |curr| curr | ready); + + if res.is_err() { // token no longer valid! return; } - if ready.is_writable() || platform::is_hup(ready) || platform::is_error(ready) { - wr = io.writer.take_waker(); - } - - if !(ready & (!mio::Ready::writable())).is_empty() { - rd = io.reader.take_waker(); - } - - if let Some(w) = rd { - w.wake(); - } + io.wake(ready); + } +} - if let Some(w) = wr { - w.wake(); - } +impl Drop for Driver { + fn drop(&mut self) { + self.resources.for_each(|io| { + // If a task is waiting on the I/O resource, notify it. The task + // will then attempt to use the I/O resource and fail due to the + // driver being shutdown. + io.wake(Ready::ALL); + }) } } @@ -181,6 +207,8 @@ impl Park for Driver { self.turn(Some(duration))?; Ok(()) } + + fn shutdown(&mut self) {} } impl fmt::Debug for Driver { @@ -191,17 +219,36 @@ impl fmt::Debug for Driver { // ===== impl Handle ===== -impl Handle { - /// Returns a handle to the current reactor - /// - /// # Panics - /// - /// This function panics if there is no current reactor set. - pub(super) fn current() -> Self { - context::io_handle() - .expect("there is no reactor running, must be called from the context of Tokio runtime") +cfg_rt! { + impl Handle { + /// Returns a handle to the current reactor + /// + /// # Panics + /// + /// This function panics if there is no current reactor set and `rt` feature + /// flag is not enabled. + pub(super) fn current() -> Self { + crate::runtime::context::io_handle() + .expect("there is no reactor running, must be called from the context of Tokio runtime") + } + } +} + +cfg_not_rt! { + impl Handle { + /// Returns a handle to the current reactor + /// + /// # Panics + /// + /// This function panics if there is no current reactor set, or if the `rt` + /// feature flag is not enabled. + pub(super) fn current() -> Self { + panic!("there is no reactor running, must be called from the context of Tokio runtime with `rt` enabled.") + } } +} +impl Handle { /// Forces a reactor blocked in a call to `turn` to wakeup, or otherwise /// makes the next call to `turn` return immediately. /// @@ -213,7 +260,7 @@ impl Handle { /// return immediately. fn wakeup(&self) { if let Some(inner) = self.inner() { - inner.wakeup.set_readiness(mio::Ready::readable()).unwrap(); + inner.waker.wake().expect("failed to wake I/O driver"); } } @@ -242,159 +289,35 @@ impl Inner { /// The registration token is returned. pub(super) fn add_source( &self, - source: &dyn Evented, - ready: mio::Ready, - ) -> io::Result<Address> { - let address = self.io_dispatch.alloc().ok_or_else(|| { + source: &mut impl mio::event::Source, + interest: mio::Interest, + ) -> io::Result<slab::Ref<ScheduledIo>> { + let (address, shared) = self.io_dispatch.allocate().ok_or_else(|| { io::Error::new( io::ErrorKind::Other, "reactor at max registered I/O resources", ) })?; - self.n_sources.fetch_add(1, SeqCst); + let token = GENERATION.pack(shared.generation(), ADDRESS.pack(address.as_usize(), 0)); - self.io.register( - source, - mio::Token(address.to_usize()), - ready, - mio::PollOpt::edge(), - )?; + self.registry + .register(source, mio::Token(token), interest)?; - Ok(address) + Ok(shared) } /// Deregisters an I/O resource from the reactor. - pub(super) fn deregister_source(&self, source: &dyn Evented) -> io::Result<()> { - self.io.deregister(source) - } - - pub(super) fn drop_source(&self, address: Address) { - self.io_dispatch.remove(address); - self.n_sources.fetch_sub(1, SeqCst); - } - - /// Registers interest in the I/O resource associated with `token`. - pub(super) fn register(&self, token: Address, dir: Direction, w: Waker) { - let sched = self - .io_dispatch - .get(token) - .unwrap_or_else(|| panic!("IO resource for token {:?} does not exist!", token)); - - let waker = match dir { - Direction::Read => &sched.reader, - Direction::Write => &sched.writer, - }; - - waker.register(w); + pub(super) fn deregister_source(&self, source: &mut impl mio::event::Source) -> io::Result<()> { + self.registry.deregister(source) } } impl Direction { - pub(super) fn mask(self) -> mio::Ready { + pub(super) fn mask(self) -> Ready { match self { - Direction::Read => { - // Everything except writable is signaled through read. - mio::Ready::all() - mio::Ready::writable() - } - Direction::Write => mio::Ready::writable() | platform::hup() | platform::error(), - } - } -} - -#[cfg(all(test, loom))] -mod tests { - use super::*; - use loom::thread; - - // No-op `Evented` impl just so we can have something to pass to `add_source`. - struct NotEvented; - - impl Evented for NotEvented { - fn register( - &self, - _: &mio::Poll, - _: mio::Token, - _: mio::Ready, - _: mio::PollOpt, - ) -> io::Result<()> { - Ok(()) - } - - fn reregister( - &self, - _: &mio::Poll, - _: mio::Token, - _: mio::Ready, - _: mio::PollOpt, - ) -> io::Result<()> { - Ok(()) - } - - fn deregister(&self, _: &mio::Poll) -> io::Result<()> { - Ok(()) + Direction::Read => Ready::READABLE | Ready::READ_CLOSED, + Direction::Write => Ready::WRITABLE | Ready::WRITE_CLOSED, } } - - #[test] - fn tokens_unique_when_dropped() { - loom::model(|| { - let reactor = Driver::new().unwrap(); - let inner = reactor.inner; - let inner2 = inner.clone(); - - let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - let thread = thread::spawn(move || { - inner2.drop_source(token_1); - }); - - let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - thread.join().unwrap(); - - assert!(token_1 != token_2); - }) - } - - #[test] - fn tokens_unique_when_dropped_on_full_page() { - loom::model(|| { - let reactor = Driver::new().unwrap(); - let inner = reactor.inner; - let inner2 = inner.clone(); - // add sources to fill up the first page so that the dropped index - // may be reused. - for _ in 0..31 { - inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - } - - let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - let thread = thread::spawn(move || { - inner2.drop_source(token_1); - }); - - let token_2 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - thread.join().unwrap(); - - assert!(token_1 != token_2); - }) - } - - #[test] - fn tokens_unique_concurrent_add() { - loom::model(|| { - let reactor = Driver::new().unwrap(); - let inner = reactor.inner; - let inner2 = inner.clone(); - - let thread = thread::spawn(move || { - let token_2 = inner2.add_source(&NotEvented, mio::Ready::all()).unwrap(); - token_2 - }); - - let token_1 = inner.add_source(&NotEvented, mio::Ready::all()).unwrap(); - let token_2 = thread.join().unwrap(); - - assert!(token_1 != token_2); - }) - } } diff --git a/src/io/driver/ready.rs b/src/io/driver/ready.rs new file mode 100644 index 0000000..8b556e9 --- /dev/null +++ b/src/io/driver/ready.rs @@ -0,0 +1,187 @@ +use std::fmt; +use std::ops; + +const READABLE: usize = 0b0_01; +const WRITABLE: usize = 0b0_10; +const READ_CLOSED: usize = 0b0_0100; +const WRITE_CLOSED: usize = 0b0_1000; + +/// A set of readiness event kinds. +/// +/// `Ready` is set of operation descriptors indicating which kind of an +/// operation is ready to be performed. +/// +/// This struct only represents portable event kinds. Portable events are +/// events that can be raised on any platform while guaranteeing no false +/// positives. +#[derive(Clone, Copy, PartialEq, PartialOrd)] +pub(crate) struct Ready(usize); + +impl Ready { + /// Returns the empty `Ready` set. + pub(crate) const EMPTY: Ready = Ready(0); + + /// Returns a `Ready` representing readable readiness. + pub(crate) const READABLE: Ready = Ready(READABLE); + + /// Returns a `Ready` representing writable readiness. + pub(crate) const WRITABLE: Ready = Ready(WRITABLE); + + /// Returns a `Ready` representing read closed readiness. + pub(crate) const READ_CLOSED: Ready = Ready(READ_CLOSED); + + /// Returns a `Ready` representing write closed readiness. + pub(crate) const WRITE_CLOSED: Ready = Ready(WRITE_CLOSED); + + /// Returns a `Ready` representing readiness for all operations. + pub(crate) const ALL: Ready = Ready(READABLE | WRITABLE | READ_CLOSED | WRITE_CLOSED); + + pub(crate) fn from_mio(event: &mio::event::Event) -> Ready { + let mut ready = Ready::EMPTY; + + if event.is_readable() { + ready |= Ready::READABLE; + } + + if event.is_writable() { + ready |= Ready::WRITABLE; + } + + if event.is_read_closed() { + ready |= Ready::READ_CLOSED; + } + + if event.is_write_closed() { + ready |= Ready::WRITE_CLOSED; + } + + ready + } + + /// Returns true if `Ready` is the empty set + pub(crate) fn is_empty(self) -> bool { + self == Ready::EMPTY + } + + /// Returns true if the value includes readable readiness + pub(crate) fn is_readable(self) -> bool { + self.contains(Ready::READABLE) || self.is_read_closed() + } + + /// Returns true if the value includes writable readiness + pub(crate) fn is_writable(self) -> bool { + self.contains(Ready::WRITABLE) || self.is_write_closed() + } + + /// Returns true if the value includes read closed readiness + pub(crate) fn is_read_closed(self) -> bool { + self.contains(Ready::READ_CLOSED) + } + + /// Returns true if the value includes write closed readiness + pub(crate) fn is_write_closed(self) -> bool { + self.contains(Ready::WRITE_CLOSED) + } + + /// Returns true if `self` is a superset of `other`. + /// + /// `other` may represent more than one readiness operations, in which case + /// the function only returns true if `self` contains all readiness + /// specified in `other`. + pub(crate) fn contains<T: Into<Self>>(self, other: T) -> bool { + let other = other.into(); + (self & other) == other + } + + /// Create a `Ready` instance using the given `usize` representation. + /// + /// The `usize` representation must have been obtained from a call to + /// `Readiness::as_usize`. + /// + /// This function is mainly provided to allow the caller to get a + /// readiness value from an `AtomicUsize`. + pub(crate) fn from_usize(val: usize) -> Ready { + Ready(val & Ready::ALL.as_usize()) + } + + /// Returns a `usize` representation of the `Ready` value. + /// + /// This function is mainly provided to allow the caller to store a + /// readiness value in an `AtomicUsize`. + pub(crate) fn as_usize(self) -> usize { + self.0 + } +} + +cfg_io_readiness! { + impl Ready { + pub(crate) fn from_interest(interest: mio::Interest) -> Ready { + let mut ready = Ready::EMPTY; + + if interest.is_readable() { + ready |= Ready::READABLE; + ready |= Ready::READ_CLOSED; + } + + if interest.is_writable() { + ready |= Ready::WRITABLE; + ready |= Ready::WRITE_CLOSED; + } + + ready + } + + pub(crate) fn intersection(self, interest: mio::Interest) -> Ready { + Ready(self.0 & Ready::from_interest(interest).0) + } + + pub(crate) fn satisfies(self, interest: mio::Interest) -> bool { + self.0 & Ready::from_interest(interest).0 != 0 + } + } +} + +impl<T: Into<Ready>> ops::BitOr<T> for Ready { + type Output = Ready; + + #[inline] + fn bitor(self, other: T) -> Ready { + Ready(self.0 | other.into().0) + } +} + +impl<T: Into<Ready>> ops::BitOrAssign<T> for Ready { + #[inline] + fn bitor_assign(&mut self, other: T) { + self.0 |= other.into().0; + } +} + +impl<T: Into<Ready>> ops::BitAnd<T> for Ready { + type Output = Ready; + + #[inline] + fn bitand(self, other: T) -> Ready { + Ready(self.0 & other.into().0) + } +} + +impl<T: Into<Ready>> ops::Sub<T> for Ready { + type Output = Ready; + + #[inline] + fn sub(self, other: T) -> Ready { + Ready(self.0 & !other.into().0) + } +} + +impl fmt::Debug for Ready { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Ready") + .field("is_readable", &self.is_readable()) + .field("is_writable", &self.is_writable()) + .field("is_read_closed", &self.is_read_closed()) + .field("is_write_closed", &self.is_write_closed()) + .finish() + } +} diff --git a/src/io/driver/scheduled_io.rs b/src/io/driver/scheduled_io.rs index 7f6446e..b1354a0 100644 --- a/src/io/driver/scheduled_io.rs +++ b/src/io/driver/scheduled_io.rs @@ -1,47 +1,109 @@ -use crate::loom::future::AtomicWaker; +use super::{Direction, Ready, ReadyEvent, Tick}; use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; use crate::util::bit; -use crate::util::slab::{Address, Entry, Generation}; +use crate::util::slab::Entry; -use std::sync::atomic::Ordering::{AcqRel, Acquire, SeqCst}; +use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::task::{Context, Poll, Waker}; +cfg_io_readiness! { + use crate::util::linked_list::{self, LinkedList}; + + use std::cell::UnsafeCell; + use std::future::Future; + use std::marker::PhantomPinned; + use std::pin::Pin; + use std::ptr::NonNull; +} + +/// Stored in the I/O driver resource slab. #[derive(Debug)] pub(crate) struct ScheduledIo { + /// Packs the resource's readiness with the resource's generation. readiness: AtomicUsize, - pub(crate) reader: AtomicWaker, - pub(crate) writer: AtomicWaker, + + waiters: Mutex<Waiters>, } -const PACK: bit::Pack = bit::Pack::most_significant(Generation::WIDTH); +cfg_io_readiness! { + type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; +} -impl Entry for ScheduledIo { - fn generation(&self) -> Generation { - unpack_generation(self.readiness.load(SeqCst)) +#[derive(Debug, Default)] +struct Waiters { + #[cfg(feature = "net")] + /// List of all current waiters + list: WaitList, + + /// Waker used for AsyncRead + reader: Option<Waker>, + + /// Waker used for AsyncWrite + writer: Option<Waker>, +} + +cfg_io_readiness! { + #[derive(Debug)] + struct Waiter { + pointers: linked_list::Pointers<Waiter>, + + /// The waker for this task + waker: Option<Waker>, + + /// The interest this waiter is waiting on + interest: mio::Interest, + + is_ready: bool, + + /// Should never be `!Unpin` + _p: PhantomPinned, } - fn reset(&self, generation: Generation) -> bool { - let mut current = self.readiness.load(Acquire); + /// Future returned by `readiness()` + struct Readiness<'a> { + scheduled_io: &'a ScheduledIo, - loop { - if unpack_generation(current) != generation { - return false; - } + state: State, - let next = PACK.pack(generation.next().to_usize(), 0); + /// Entry in the waiter `LinkedList`. + waiter: UnsafeCell<Waiter>, + } - match self - .readiness - .compare_exchange(current, next, AcqRel, Acquire) - { - Ok(_) => break, - Err(actual) => current = actual, - } - } + enum State { + Init, + Waiting, + Done, + } +} + +// The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness. +// +// | reserved | generation | driver tick | readinesss | +// |----------+------------+--------------+------------| +// | 1 bit | 7 bits + 8 bits + 16 bits | + +const READINESS: bit::Pack = bit::Pack::least_significant(16); + +const TICK: bit::Pack = READINESS.then(8); - drop(self.reader.take_waker()); - drop(self.writer.take_waker()); +const GENERATION: bit::Pack = TICK.then(7); - true +#[test] +fn test_generations_assert_same() { + assert_eq!(super::GENERATION, GENERATION); +} + +// ===== impl ScheduledIo ===== + +impl Entry for ScheduledIo { + fn reset(&self) { + let state = self.readiness.load(Acquire); + + let generation = GENERATION.unpack(state); + let next = GENERATION.pack_lossy(generation + 1, 0); + + self.readiness.store(next, Release); } } @@ -49,31 +111,14 @@ impl Default for ScheduledIo { fn default() -> ScheduledIo { ScheduledIo { readiness: AtomicUsize::new(0), - reader: AtomicWaker::new(), - writer: AtomicWaker::new(), + waiters: Mutex::new(Default::default()), } } } impl ScheduledIo { - #[cfg(all(test, loom))] - /// Returns the current readiness value of this `ScheduledIo`, if the - /// provided `token` is still a valid access. - /// - /// # Returns - /// - /// If the given token's generation no longer matches the `ScheduledIo`'s - /// generation, then the corresponding IO resource has been removed and - /// replaced with a new resource. In that case, this method returns `None`. - /// Otherwise, this returns the current readiness. - pub(crate) fn get_readiness(&self, address: Address) -> Option<usize> { - let ready = self.readiness.load(Acquire); - - if unpack_generation(ready) != address.generation() { - return None; - } - - Some(ready & !PACK.mask()) + pub(crate) fn generation(&self) -> usize { + GENERATION.unpack(self.readiness.load(Acquire)) } /// Sets the readiness on this `ScheduledIo` by invoking the given closure on @@ -81,6 +126,8 @@ impl ScheduledIo { /// /// # Arguments /// - `token`: the token for this `ScheduledIo`. + /// - `tick`: whether setting the tick or trying to clear readiness for a + /// specific tick. /// - `f`: a closure returning a new readiness value given the previous /// readiness. /// @@ -90,52 +137,354 @@ impl ScheduledIo { /// generation, then the corresponding IO resource has been removed and /// replaced with a new resource. In that case, this method returns `Err`. /// Otherwise, this returns the previous readiness. - pub(crate) fn set_readiness( + pub(super) fn set_readiness( &self, - address: Address, - f: impl Fn(usize) -> usize, - ) -> Result<usize, ()> { - let generation = address.generation(); - + token: Option<usize>, + tick: Tick, + f: impl Fn(Ready) -> Ready, + ) -> Result<(), ()> { let mut current = self.readiness.load(Acquire); loop { - // Check that the generation for this access is still the current - // one. - if unpack_generation(current) != generation { - return Err(()); + let current_generation = GENERATION.unpack(current); + + if let Some(token) = token { + // Check that the generation for this access is still the + // current one. + if GENERATION.unpack(token) != current_generation { + return Err(()); + } } - // Mask out the generation bits so that the modifying function - // doesn't see them. - let current_readiness = current & mio::Ready::all().as_usize(); + + // Mask out the tick/generation bits so that the modifying + // function doesn't see them. + let current_readiness = Ready::from_usize(current); let new = f(current_readiness); - debug_assert!( - new <= !PACK.max_value(), - "new readiness value would overwrite generation bits!" - ); - - match self.readiness.compare_exchange( - current, - PACK.pack(generation.to_usize(), new), - AcqRel, - Acquire, - ) { - Ok(_) => return Ok(current), + let packed = match tick { + Tick::Set(t) => TICK.pack(t as usize, new.as_usize()), + Tick::Clear(t) => { + if TICK.unpack(current) as u8 != t { + // Trying to clear readiness with an old event! + return Err(()); + } + + TICK.pack(t as usize, new.as_usize()) + } + }; + + let next = GENERATION.pack(current_generation, packed); + + match self + .readiness + .compare_exchange(current, next, AcqRel, Acquire) + { + Ok(_) => return Ok(()), // we lost the race, retry! Err(actual) => current = actual, } } } + + /// Notifies all pending waiters that have registered interest in `ready`. + /// + /// There may be many waiters to notify. Waking the pending task **must** be + /// done from outside of the lock otherwise there is a potential for a + /// deadlock. + /// + /// A stack array of wakers is created and filled with wakers to notify, the + /// lock is released, and the wakers are notified. Because there may be more + /// than 32 wakers to notify, if the stack array fills up, the lock is + /// released, the array is cleared, and the iteration continues. + pub(super) fn wake(&self, ready: Ready) { + const NUM_WAKERS: usize = 32; + + let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default(); + let mut curr = 0; + + let mut waiters = self.waiters.lock(); + + // check for AsyncRead slot + if ready.is_readable() { + if let Some(waker) = waiters.reader.take() { + wakers[curr] = Some(waker); + curr += 1; + } + } + + // check for AsyncWrite slot + if ready.is_writable() { + if let Some(waker) = waiters.writer.take() { + wakers[curr] = Some(waker); + curr += 1; + } + } + + #[cfg(feature = "net")] + 'outer: loop { + let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest)); + + while curr < NUM_WAKERS { + match iter.next() { + Some(waiter) => { + let waiter = unsafe { &mut *waiter.as_ptr() }; + + if let Some(waker) = waiter.waker.take() { + waiter.is_ready = true; + wakers[curr] = Some(waker); + curr += 1; + } + } + None => { + break 'outer; + } + } + } + + drop(waiters); + + for waker in wakers.iter_mut().take(curr) { + waker.take().unwrap().wake(); + } + + curr = 0; + + // Acquire the lock again. + waiters = self.waiters.lock(); + } + + // Release the lock before notifying + drop(waiters); + + for waker in wakers.iter_mut().take(curr) { + waker.take().unwrap().wake(); + } + } + + /// Poll version of checking readiness for a certain direction. + /// + /// These are to support `AsyncRead` and `AsyncWrite` polling methods, + /// which cannot use the `async fn` version. This uses reserved reader + /// and writer slots. + pub(in crate::io) fn poll_readiness( + &self, + cx: &mut Context<'_>, + direction: Direction, + ) -> Poll<ReadyEvent> { + let curr = self.readiness.load(Acquire); + + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + + if ready.is_empty() { + // Update the task info + let mut waiters = self.waiters.lock(); + let slot = match direction { + Direction::Read => &mut waiters.reader, + Direction::Write => &mut waiters.writer, + }; + *slot = Some(cx.waker().clone()); + + // Try again, in case the readiness was changed while we were + // taking the waiters lock + let curr = self.readiness.load(Acquire); + let ready = direction.mask() & Ready::from_usize(READINESS.unpack(curr)); + if ready.is_empty() { + Poll::Pending + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } + } else { + Poll::Ready(ReadyEvent { + tick: TICK.unpack(curr) as u8, + ready, + }) + } + } + + pub(crate) fn clear_readiness(&self, event: ReadyEvent) { + // This consumes the current readiness state **except** for closed + // states. Closed states are excluded because they are final states. + let mask_no_closed = event.ready - Ready::READ_CLOSED - Ready::WRITE_CLOSED; + + // result isn't important + let _ = self.set_readiness(None, Tick::Clear(event.tick), |curr| curr - mask_no_closed); + } } impl Drop for ScheduledIo { fn drop(&mut self) { - self.writer.wake(); - self.reader.wake(); + self.wake(Ready::ALL); } } -fn unpack_generation(src: usize) -> Generation { - Generation::new(PACK.unpack(src)) +unsafe impl Send for ScheduledIo {} +unsafe impl Sync for ScheduledIo {} + +cfg_io_readiness! { + impl ScheduledIo { + /// An async version of `poll_readiness` which uses a linked list of wakers + pub(crate) async fn readiness(&self, interest: mio::Interest) -> ReadyEvent { + self.readiness_fut(interest).await + } + + // This is in a separate function so that the borrow checker doesn't think + // we are borrowing the `UnsafeCell` possibly over await boundaries. + // + // Go figure. + fn readiness_fut(&self, interest: mio::Interest) -> Readiness<'_> { + Readiness { + scheduled_io: self, + state: State::Init, + waiter: UnsafeCell::new(Waiter { + pointers: linked_list::Pointers::new(), + waker: None, + is_ready: false, + interest, + _p: PhantomPinned, + }), + } + } + } + + unsafe impl linked_list::Link for Waiter { + type Handle = NonNull<Waiter>; + type Target = Waiter; + + fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> { + *handle + } + + unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> { + ptr + } + + unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> { + NonNull::from(&mut target.as_mut().pointers) + } + } + + // ===== impl Readiness ===== + + impl Future for Readiness<'_> { + type Output = ReadyEvent; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + use std::sync::atomic::Ordering::SeqCst; + + let (scheduled_io, state, waiter) = unsafe { + let me = self.get_unchecked_mut(); + (&me.scheduled_io, &mut me.state, &me.waiter) + }; + + loop { + match *state { + State::Init => { + // Optimistically check existing readiness + let curr = scheduled_io.readiness.load(SeqCst); + let ready = Ready::from_usize(READINESS.unpack(curr)); + + // Safety: `waiter.interest` never changes + let interest = unsafe { (*waiter.get()).interest }; + let ready = ready.intersection(interest); + + if !ready.is_empty() { + // Currently ready! + let tick = TICK.unpack(curr) as u8; + *state = State::Done; + return Poll::Ready(ReadyEvent { ready, tick }); + } + + // Wasn't ready, take the lock (and check again while locked). + let mut waiters = scheduled_io.waiters.lock(); + + let curr = scheduled_io.readiness.load(SeqCst); + let ready = Ready::from_usize(READINESS.unpack(curr)); + let ready = ready.intersection(interest); + + if !ready.is_empty() { + // Currently ready! + let tick = TICK.unpack(curr) as u8; + *state = State::Done; + return Poll::Ready(ReadyEvent { ready, tick }); + } + + // Not ready even after locked, insert into list... + + // Safety: called while locked + unsafe { + (*waiter.get()).waker = Some(cx.waker().clone()); + } + + // Insert the waiter into the linked list + // + // safety: pointers from `UnsafeCell` are never null. + waiters + .list + .push_front(unsafe { NonNull::new_unchecked(waiter.get()) }); + *state = State::Waiting; + } + State::Waiting => { + // Currently in the "Waiting" state, implying the caller has + // a waiter stored in the waiter list (guarded by + // `notify.waiters`). In order to access the waker fields, + // we must hold the lock. + + let waiters = scheduled_io.waiters.lock(); + + // Safety: called while locked + let w = unsafe { &mut *waiter.get() }; + + if w.is_ready { + // Our waker has been notified. + *state = State::Done; + } else { + // Update the waker, if necessary. + if !w.waker.as_ref().unwrap().will_wake(cx.waker()) { + w.waker = Some(cx.waker().clone()); + } + + return Poll::Pending; + } + + // Explicit drop of the lock to indicate the scope that the + // lock is held. Because holding the lock is required to + // ensure safe access to fields not held within the lock, it + // is helpful to visualize the scope of the critical + // section. + drop(waiters); + } + State::Done => { + let tick = TICK.unpack(scheduled_io.readiness.load(Acquire)) as u8; + + // Safety: State::Done means it is no longer shared + let w = unsafe { &mut *waiter.get() }; + + return Poll::Ready(ReadyEvent { + tick, + ready: Ready::from_interest(w.interest), + }); + } + } + } + } + } + + impl Drop for Readiness<'_> { + fn drop(&mut self) { + let mut waiters = self.scheduled_io.waiters.lock(); + + // Safety: `waiter` is only ever stored in `waiters` + unsafe { + waiters + .list + .remove(NonNull::new_unchecked(self.waiter.get())) + }; + } + } + + unsafe impl Send for Readiness<'_> {} + unsafe impl Sync for Readiness<'_> {} } diff --git a/src/io/mod.rs b/src/io/mod.rs index 7b00556..9191bbc 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -162,8 +162,8 @@ //! //! # `std` re-exports //! -//! Additionally, [`Error`], [`ErrorKind`], and [`Result`] are re-exported -//! from `std::io` for ease of use. +//! Additionally, [`Error`], [`ErrorKind`], [`Result`], and [`SeekFrom`] are +//! re-exported from `std::io` for ease of use. //! //! [`AsyncRead`]: trait@AsyncRead //! [`AsyncWrite`]: trait@AsyncWrite @@ -176,6 +176,7 @@ //! [`ErrorKind`]: enum@ErrorKind //! [`Result`]: type@Result //! [`Read`]: std::io::Read +//! [`SeekFrom`]: enum@SeekFrom //! [`Sink`]: https://docs.rs/futures/0.3/futures/sink/trait.Sink.html //! [`Stream`]: crate::stream::Stream //! [`Write`]: std::io::Write @@ -187,7 +188,6 @@ mod async_buf_read; pub use self::async_buf_read::AsyncBufRead; mod async_read; - pub use self::async_read::AsyncRead; mod async_seek; @@ -196,17 +196,27 @@ pub use self::async_seek::AsyncSeek; mod async_write; pub use self::async_write::AsyncWrite; +mod read_buf; +pub use self::read_buf::ReadBuf; + +// Re-export some types from `std::io` so that users don't have to deal +// with conflicts when `use`ing `tokio::io` and `std::io`. +#[doc(no_inline)] +pub use std::io::{Error, ErrorKind, Result, SeekFrom}; + cfg_io_driver! { pub(crate) mod driver; mod poll_evented; - pub use poll_evented::PollEvented; + #[cfg(not(loom))] + pub(crate) use poll_evented::PollEvented; mod registration; - pub use registration::Registration; } cfg_io_std! { + mod stdio_common; + mod stderr; pub use stderr::{stderr, Stderr}; @@ -222,21 +232,11 @@ cfg_io_util! { pub use split::{split, ReadHalf, WriteHalf}; pub(crate) mod seek; - pub use self::seek::Seek; - pub(crate) mod util; pub use util::{ - copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, - BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take, + copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take, }; - - cfg_stream! { - pub use util::{stream_reader, StreamReader}; - } - - // Re-export io::Error so that users don't have to deal with conflicts when - // `use`ing `tokio::io` and `std::io`. - pub use std::io::{Error, ErrorKind, Result}; } cfg_not_io_util! { @@ -249,7 +249,7 @@ cfg_io_blocking! { /// Types in this module can be mocked out in tests. mod sys { // TODO: don't rename - pub(crate) use crate::runtime::spawn_blocking as run; - pub(crate) use crate::task::JoinHandle as Blocking; + pub(crate) use crate::blocking::spawn_blocking as run; + pub(crate) use crate::blocking::JoinHandle as Blocking; } } diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs index 5295bd7..66a2634 100644 --- a/src/io/poll_evented.rs +++ b/src/io/poll_evented.rs @@ -1,13 +1,12 @@ -use crate::io::driver::platform; -use crate::io::{AsyncRead, AsyncWrite, Registration}; +use crate::io::driver::{Direction, Handle, ReadyEvent}; +use crate::io::registration::Registration; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; -use mio::event::Evented; +use mio::event::Source; use std::fmt; use std::io::{self, Read, Write}; use std::marker::Unpin; use std::pin::Pin; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::Relaxed; use std::task::{Context, Poll}; cfg_io_driver! { @@ -53,37 +52,6 @@ cfg_io_driver! { /// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and /// [`clear_read_ready`]. /// - /// ```rust - /// use tokio::io::PollEvented; - /// - /// use futures::ready; - /// use mio::Ready; - /// use mio::net::{TcpStream, TcpListener}; - /// use std::io; - /// use std::task::{Context, Poll}; - /// - /// struct MyListener { - /// poll_evented: PollEvented<TcpListener>, - /// } - /// - /// impl MyListener { - /// pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<TcpStream, io::Error>> { - /// let ready = Ready::readable(); - /// - /// ready!(self.poll_evented.poll_read_ready(cx, ready))?; - /// - /// match self.poll_evented.get_ref().accept() { - /// Ok((socket, _)) => Poll::Ready(Ok(socket)), - /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - /// self.poll_evented.clear_read_ready(cx, ready)?; - /// Poll::Pending - /// } - /// Err(e) => Poll::Ready(Err(e)), - /// } - /// } - /// } - /// ``` - /// /// ## Platform-specific events /// /// `PollEvented` also allows receiving platform-specific `mio::Ready` events. @@ -101,70 +69,15 @@ cfg_io_driver! { /// [`clear_write_ready`]: method@Self::clear_write_ready /// [`poll_read_ready`]: method@Self::poll_read_ready /// [`poll_write_ready`]: method@Self::poll_write_ready - pub struct PollEvented<E: Evented> { + pub(crate) struct PollEvented<E: Source> { io: Option<E>, - inner: Inner, + registration: Registration, } } -struct Inner { - registration: Registration, - - /// Currently visible read readiness - read_readiness: AtomicUsize, - - /// Currently visible write readiness - write_readiness: AtomicUsize, -} - // ===== impl PollEvented ===== -macro_rules! poll_ready { - ($me:expr, $mask:expr, $cache:ident, $take:ident, $poll:expr) => {{ - // Load cached & encoded readiness. - let mut cached = $me.inner.$cache.load(Relaxed); - let mask = $mask | platform::hup() | platform::error(); - - // See if the current readiness matches any bits. - let mut ret = mio::Ready::from_usize(cached) & $mask; - - if ret.is_empty() { - // Readiness does not match, consume the registration's readiness - // stream. This happens in a loop to ensure that the stream gets - // drained. - loop { - let ready = match $poll? { - Poll::Ready(v) => v, - Poll::Pending => return Poll::Pending, - }; - cached |= ready.as_usize(); - - // Update the cache store - $me.inner.$cache.store(cached, Relaxed); - - ret |= ready & mask; - - if !ret.is_empty() { - return Poll::Ready(Ok(ret)); - } - } - } else { - // Check what's new with the registration stream. This will not - // request to be notified - if let Some(ready) = $me.inner.registration.$take()? { - cached |= ready.as_usize(); - $me.inner.$cache.store(cached, Relaxed); - } - - Poll::Ready(Ok(mio::Ready::from_usize(cached))) - } - }}; -} - -impl<E> PollEvented<E> -where - E: Evented, -{ +impl<E: Source> PollEvented<E> { /// Creates a new `PollEvented` associated with the default reactor. /// /// # Panics @@ -173,71 +86,57 @@ where /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new(io: E) -> io::Result<Self> { - PollEvented::new_with_ready(io, mio::Ready::all()) + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + #[cfg_attr(feature = "signal", allow(unused))] + pub(crate) fn new(io: E) -> io::Result<Self> { + PollEvented::new_with_interest(io, mio::Interest::READABLE | mio::Interest::WRITABLE) } - /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Ready` - /// state. `new_with_ready` should be used over `new` when you need control over the readiness + /// Creates a new `PollEvented` associated with the default reactor, for specific `mio::Interest` + /// state. `new_with_interest` should be used over `new` when you need control over the readiness /// state, such as when a file descriptor only allows reads. This does not add `hup` or `error` /// so if you are interested in those states, you will need to add them to the readiness state /// passed to this function. /// - /// An example to listen to read only - /// - /// ```rust - /// ##[cfg(unix)] - /// mio::Ready::from_usize( - /// mio::Ready::readable().as_usize() - /// | mio::unix::UnixReady::error().as_usize() - /// | mio::unix::UnixReady::hup().as_usize() - /// ); - /// ``` - /// /// # Panics /// /// This function panics if thread-local runtime is not set. /// /// The runtime is usually set implicitly when this function is called /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new_with_ready(io: E, ready: mio::Ready) -> io::Result<Self> { - let registration = Registration::new_with_ready(&io, ready)?; + /// explicitly with [`Runtime::enter`](crate::runtime::Runtime::enter) function. + #[cfg_attr(feature = "signal", allow(unused))] + pub(crate) fn new_with_interest(io: E, interest: mio::Interest) -> io::Result<Self> { + Self::new_with_interest_and_handle(io, interest, Handle::current()) + } + + pub(crate) fn new_with_interest_and_handle( + mut io: E, + interest: mio::Interest, + handle: Handle, + ) -> io::Result<Self> { + let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?; Ok(Self { io: Some(io), - inner: Inner { - registration, - read_readiness: AtomicUsize::new(0), - write_readiness: AtomicUsize::new(0), - }, + registration, }) } /// Returns a shared reference to the underlying I/O object this readiness /// stream is wrapping. - pub fn get_ref(&self) -> &E { + #[cfg(any(feature = "net", feature = "process", feature = "signal"))] + pub(crate) fn get_ref(&self) -> &E { self.io.as_ref().unwrap() } /// Returns a mutable reference to the underlying I/O object this readiness /// stream is wrapping. - pub fn get_mut(&mut self) -> &mut E { + pub(crate) fn get_mut(&mut self) -> &mut E { self.io.as_mut().unwrap() } - /// Consumes self, returning the inner I/O object - /// - /// This function will deregister the I/O resource from the reactor before - /// returning. If the deregistration operation fails, an error is returned. - /// - /// Note that deregistering does not guarantee that the I/O resource can be - /// registered with a different reactor. Some I/O resource types can only be - /// associated with a single reactor instance for their lifetime. - pub fn into_inner(mut self) -> io::Result<E> { - let io = self.io.take().unwrap(); - self.inner.registration.deregister(&io)?; - Ok(io) + pub(crate) fn clear_readiness(&self, event: ReadyEvent) { + self.registration.clear_readiness(event); } /// Checks the I/O resource's read readiness state. @@ -266,51 +165,8 @@ where /// /// This method may not be called concurrently. It takes `&self` to allow /// calling it concurrently with `poll_write_ready`. - pub fn poll_read_ready( - &self, - cx: &mut Context<'_>, - mask: mio::Ready, - ) -> Poll<io::Result<mio::Ready>> { - assert!(!mask.is_writable(), "cannot poll for write readiness"); - poll_ready!( - self, - mask, - read_readiness, - take_read_ready, - self.inner.registration.poll_read_ready(cx) - ) - } - - /// Clears the I/O resource's read readiness state and registers the current - /// task to be notified once a read readiness event is received. - /// - /// After calling this function, `poll_read_ready` will return - /// `Poll::Pending` until a new read readiness event has been received. - /// - /// The `mask` argument specifies the readiness bits to clear. This may not - /// include `writable` or `hup`. - /// - /// # Panics - /// - /// This function panics if: - /// - /// * `ready` includes writable or HUP - /// * called from outside of a task context. - pub fn clear_read_ready(&self, cx: &mut Context<'_>, ready: mio::Ready) -> io::Result<()> { - // Cannot clear write readiness - assert!(!ready.is_writable(), "cannot clear write readiness"); - assert!(!platform::is_hup(ready), "cannot clear HUP readiness"); - - self.inner - .read_readiness - .fetch_and(!ready.as_usize(), Relaxed); - - if self.poll_read_ready(cx, ready)?.is_ready() { - // Notify the current task - cx.waker().wake_by_ref(); - } - - Ok(()) + pub(crate) fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> { + self.registration.poll_readiness(cx, Direction::Read) } /// Checks the I/O resource's write readiness state. @@ -337,100 +193,95 @@ where /// /// This method may not be called concurrently. It takes `&self` to allow /// calling it concurrently with `poll_read_ready`. - pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> { - poll_ready!( - self, - mio::Ready::writable(), - write_readiness, - take_write_ready, - self.inner.registration.poll_write_ready(cx) - ) + pub(crate) fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<ReadyEvent>> { + self.registration.poll_readiness(cx, Direction::Write) } +} - /// Resets the I/O resource's write readiness state and registers the current - /// task to be notified once a write readiness event is received. - /// - /// This only clears writable readiness. HUP (on platforms that support HUP) - /// cannot be cleared as it is a final state. - /// - /// After calling this function, `poll_write_ready(Ready::writable())` will - /// return `NotReady` until a new write readiness event has been received. - /// - /// # Panics - /// - /// This function will panic if called from outside of a task context. - pub fn clear_write_ready(&self, cx: &mut Context<'_>) -> io::Result<()> { - let ready = mio::Ready::writable(); +cfg_io_readiness! { + impl<E: Source> PollEvented<E> { + pub(crate) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> { + self.registration.readiness(interest).await + } - self.inner - .write_readiness - .fetch_and(!ready.as_usize(), Relaxed); + pub(crate) async fn async_io<F, R>(&self, interest: mio::Interest, mut op: F) -> io::Result<R> + where + F: FnMut(&E) -> io::Result<R>, + { + loop { + let event = self.readiness(interest).await?; - if self.poll_write_ready(cx)?.is_ready() { - // Notify the current task - cx.waker().wake_by_ref(); + match op(self.get_ref()) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.clear_readiness(event); + } + x => return x, + } + } } - - Ok(()) } } // ===== Read / Write impls ===== -impl<E> AsyncRead for PollEvented<E> -where - E: Evented + Read + Unpin, -{ +impl<E: Source + Read + Unpin> AsyncRead for PollEvented<E> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - ready!(self.poll_read_ready(cx, mio::Ready::readable()))?; - - let r = (*self).get_mut().read(buf); + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + loop { + let ev = ready!(self.poll_read_ready(cx))?; + + // We can't assume the `Read` won't look at the read buffer, + // so we have to force initialization here. + let r = (*self).get_mut().read(buf.initialize_unfilled()); + + if is_wouldblock(&r) { + self.clear_readiness(ev); + continue; + } - if is_wouldblock(&r) { - self.clear_read_ready(cx, mio::Ready::readable())?; - return Poll::Pending; + return Poll::Ready(r.map(|n| { + buf.advance(n); + })); } - - Poll::Ready(r) } } -impl<E> AsyncWrite for PollEvented<E> -where - E: Evented + Write + Unpin, -{ +impl<E: Source + Write + Unpin> AsyncWrite for PollEvented<E> { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>> { - ready!(self.poll_write_ready(cx))?; + loop { + let ev = ready!(self.poll_write_ready(cx))?; - let r = (*self).get_mut().write(buf); + let r = (*self).get_mut().write(buf); - if is_wouldblock(&r) { - self.clear_write_ready(cx)?; - return Poll::Pending; - } + if is_wouldblock(&r) { + self.clear_readiness(ev); + continue; + } - Poll::Ready(r) + return Poll::Ready(r); + } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - ready!(self.poll_write_ready(cx))?; + loop { + let ev = ready!(self.poll_write_ready(cx))?; - let r = (*self).get_mut().flush(); + let r = (*self).get_mut().flush(); - if is_wouldblock(&r) { - self.clear_write_ready(cx)?; - return Poll::Pending; - } + if is_wouldblock(&r) { + self.clear_readiness(ev); + continue; + } - Poll::Ready(r) + return Poll::Ready(r); + } } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { @@ -445,17 +296,17 @@ fn is_wouldblock<T>(r: &io::Result<T>) -> bool { } } -impl<E: Evented + fmt::Debug> fmt::Debug for PollEvented<E> { +impl<E: Source + fmt::Debug> fmt::Debug for PollEvented<E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PollEvented").field("io", &self.io).finish() } } -impl<E: Evented> Drop for PollEvented<E> { +impl<E: Source> Drop for PollEvented<E> { fn drop(&mut self) { - if let Some(io) = self.io.take() { + if let Some(mut io) = self.io.take() { // Ignore errors - let _ = self.inner.registration.deregister(&io); + let _ = self.registration.deregister(&mut io); } } } diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs new file mode 100644 index 0000000..b64d95c --- /dev/null +++ b/src/io/read_buf.rs @@ -0,0 +1,261 @@ +// This lint claims ugly casting is somehow safer than transmute, but there's +// no evidence that is the case. Shush. +#![allow(clippy::transmute_ptr_to_ptr)] + +use std::fmt; +use std::mem::{self, MaybeUninit}; + +/// A wrapper around a byte buffer that is incrementally filled and initialized. +/// +/// This type is a sort of "double cursor". It tracks three regions in the +/// buffer: a region at the beginning of the buffer that has been logically +/// filled with data, a region that has been initialized at some point but not +/// yet logically filled, and a region at the end that is fully uninitialized. +/// The filled region is guaranteed to be a subset of the initialized region. +/// +/// In summary, the contents of the buffer can be visualized as: +/// +/// ```not_rust +/// [ capacity ] +/// [ filled | unfilled ] +/// [ initialized | uninitialized ] +/// ``` +pub struct ReadBuf<'a> { + buf: &'a mut [MaybeUninit<u8>], + filled: usize, + initialized: usize, +} + +impl<'a> ReadBuf<'a> { + /// Creates a new `ReadBuf` from a fully initialized buffer. + #[inline] + pub fn new(buf: &'a mut [u8]) -> ReadBuf<'a> { + let initialized = buf.len(); + let buf = unsafe { mem::transmute::<&mut [u8], &mut [MaybeUninit<u8>]>(buf) }; + ReadBuf { + buf, + filled: 0, + initialized, + } + } + + /// Creates a new `ReadBuf` from a fully uninitialized buffer. + /// + /// Use `assume_init` if part of the buffer is known to be already inintialized. + #[inline] + pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> { + ReadBuf { + buf, + filled: 0, + initialized: 0, + } + } + + /// Returns the total capacity of the buffer. + #[inline] + pub fn capacity(&self) -> usize { + self.buf.len() + } + + /// Returns a shared reference to the filled portion of the buffer. + #[inline] + pub fn filled(&self) -> &[u8] { + let slice = &self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable. + unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) } + } + + /// Returns a mutable reference to the filled portion of the buffer. + #[inline] + pub fn filled_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.filled]; + // safety: filled describes how far into the buffer that the + // user has filled with bytes, so it's been initialized. + // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable. + unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) } + } + + /// Returns a new `ReadBuf` comprised of the unfilled section up to `n`. + #[inline] + pub fn take(&mut self, n: usize) -> ReadBuf<'_> { + let max = std::cmp::min(self.remaining(), n); + // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. + unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) } + } + + /// Returns a shared reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized(&self) -> &[u8] { + let slice = &self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + // TODO: This could use `MaybeUninit::slice_get_ref` when it is stable. + unsafe { mem::transmute::<&[MaybeUninit<u8>], &[u8]>(slice) } + } + + /// Returns a mutable reference to the initialized portion of the buffer. + /// + /// This includes the filled portion. + #[inline] + pub fn initialized_mut(&mut self) -> &mut [u8] { + let slice = &mut self.buf[..self.initialized]; + // safety: initialized describes how far into the buffer that the + // user has at some point initialized with bytes. + // TODO: This could use `MaybeUninit::slice_get_mut` when it is stable. + unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) } + } + + /// Returns a mutable reference to the unfilled part of the buffer without ensuring that it has been fully + /// initialized. + /// + /// # Safety + /// + /// The caller must not de-initialize portions of the buffer that have already been initialized. + #[inline] + pub unsafe fn unfilled_mut(&mut self) -> &mut [MaybeUninit<u8>] { + &mut self.buf[self.filled..] + } + + /// Returns a mutable reference to the unfilled part of the buffer, ensuring it is fully initialized. + /// + /// Since `ReadBuf` tracks the region of the buffer that has been initialized, this is effectively "free" after + /// the first use. + #[inline] + pub fn initialize_unfilled(&mut self) -> &mut [u8] { + self.initialize_unfilled_to(self.remaining()) + } + + /// Returns a mutable reference to the first `n` bytes of the unfilled part of the buffer, ensuring it is + /// fully initialized. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `n`. + #[inline] + pub fn initialize_unfilled_to(&mut self, n: usize) -> &mut [u8] { + assert!(self.remaining() >= n, "n overflows remaining"); + + // This can't overflow, otherwise the assert above would have failed. + let end = self.filled + n; + + if self.initialized < end { + unsafe { + self.buf[self.initialized..end] + .as_mut_ptr() + .write_bytes(0, end - self.initialized); + } + self.initialized = end; + } + + let slice = &mut self.buf[self.filled..end]; + // safety: just above, we checked that the end of the buf has + // been initialized to some value. + unsafe { mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(slice) } + } + + /// Returns the number of bytes at the end of the slice that have not yet been filled. + #[inline] + pub fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + /// Clears the buffer, resetting the filled region to empty. + /// + /// The number of initialized bytes is not changed, and the contents of the buffer are not modified. + #[inline] + pub fn clear(&mut self) { + self.filled = 0; + } + + /// Advances the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the initialized region. + #[inline] + pub fn advance(&mut self, n: usize) { + let new = self.filled.checked_add(n).expect("filled overflow"); + self.set_filled(new); + } + + /// Sets the size of the filled region of the buffer. + /// + /// The number of initialized bytes is not changed. + /// + /// Note that this can be used to *shrink* the filled region of the buffer in addition to growing it (for + /// example, by a `AsyncRead` implementation that compresses data in-place). + /// + /// # Panics + /// + /// Panics if the filled region of the buffer would become larger than the intialized region. + #[inline] + pub fn set_filled(&mut self, n: usize) { + assert!( + n <= self.initialized, + "filled must not become larger than initialized" + ); + self.filled = n; + } + + /// Asserts that the first `n` unfilled bytes of the buffer are initialized. + /// + /// `ReadBuf` assumes that bytes are never de-initialized, so this method does nothing when called with fewer + /// bytes than are already known to be initialized. + /// + /// # Safety + /// + /// The caller must ensure that `n` unfilled bytes of the buffer have already been initialized. + #[inline] + pub unsafe fn assume_init(&mut self, n: usize) { + let new = self.filled + n; + if new > self.initialized { + self.initialized = new; + } + } + + /// Appends data to the buffer, advancing the written position and possibly also the initialized position. + /// + /// # Panics + /// + /// Panics if `self.remaining()` is less than `buf.len()`. + #[inline] + pub fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf[self.filled..end] + .as_mut_ptr() + .cast::<u8>() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.initialized < end { + self.initialized = end; + } + self.filled = end; + } +} + +impl fmt::Debug for ReadBuf<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("initialized", &self.initialized) + .field("capacity", &self.capacity()) + .finish() + } +} diff --git a/src/io/registration.rs b/src/io/registration.rs index 77fe6db..ce6cffd 100644 --- a/src/io/registration.rs +++ b/src/io/registration.rs @@ -1,7 +1,7 @@ -use crate::io::driver::{platform, Direction, Handle}; -use crate::util::slab::Address; +use crate::io::driver::{Direction, Handle, ReadyEvent, ScheduledIo}; +use crate::util::slab; -use mio::{self, Evented}; +use mio::event::Source; use std::io; use std::task::{Context, Poll}; @@ -38,74 +38,38 @@ cfg_io_driver! { /// [`poll_read_ready`]: method@Self::poll_read_ready` /// [`poll_write_ready`]: method@Self::poll_write_ready` #[derive(Debug)] - pub struct Registration { + pub(crate) struct Registration { + /// Handle to the associated driver. handle: Handle, - address: Address, + + /// Reference to state stored by the driver. + shared: slab::Ref<ScheduledIo>, } } +unsafe impl Send for Registration {} +unsafe impl Sync for Registration {} + // ===== impl Registration ===== impl Registration { - /// Registers the I/O resource with the default reactor. - /// - /// # Return - /// - /// - `Ok` if the registration happened successfully - /// - `Err` if an error was encountered during registration - /// - /// - /// # Panics - /// - /// This function panics if thread-local runtime is not set. - /// - /// The runtime is usually set implicitly when this function is called - /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new<T>(io: &T) -> io::Result<Registration> - where - T: Evented, - { - Registration::new_with_ready(io, mio::Ready::all()) - } - - /// Registers the I/O resource with the default reactor, for a specific `mio::Ready` state. - /// `new_with_ready` should be used over `new` when you need control over the readiness state, + /// Registers the I/O resource with the default reactor, for a specific `mio::Interest`. + /// `new_with_interest` should be used over `new` when you need control over the readiness state, /// such as when a file descriptor only allows reads. This does not add `hup` or `error` so if /// you are interested in those states, you will need to add them to the readiness state passed /// to this function. /// - /// An example to listen to read only - /// - /// ```rust - /// ##[cfg(unix)] - /// mio::Ready::from_usize( - /// mio::Ready::readable().as_usize() - /// | mio::unix::UnixReady::error().as_usize() - /// | mio::unix::UnixReady::hup().as_usize() - /// ); - /// ``` - /// /// # Return /// /// - `Ok` if the registration happened successfully /// - `Err` if an error was encountered during registration - /// - /// - /// # Panics - /// - /// This function panics if thread-local runtime is not set. - /// - /// The runtime is usually set implicitly when this function is called - /// from a future driven by a tokio runtime, otherwise runtime can be set - /// explicitly with [`Handle::enter`](crate::runtime::Handle::enter) function. - pub fn new_with_ready<T>(io: &T, ready: mio::Ready) -> io::Result<Registration> - where - T: Evented, - { - let handle = Handle::current(); - let address = if let Some(inner) = handle.inner() { - inner.add_source(io, ready)? + pub(crate) fn new_with_interest_and_handle( + io: &mut impl Source, + interest: mio::Interest, + handle: Handle, + ) -> io::Result<Registration> { + let shared = if let Some(inner) = handle.inner() { + inner.add_source(io, interest)? } else { return Err(io::Error::new( io::ErrorKind::Other, @@ -113,7 +77,7 @@ impl Registration { )); }; - Ok(Registration { handle, address }) + Ok(Registration { handle, shared }) } /// Deregisters the I/O resource from the reactor it is associated with. @@ -132,10 +96,7 @@ impl Registration { /// no longer result in notifications getting sent for this registration. /// /// `Err` is returned if an error is encountered. - pub fn deregister<T>(&mut self, io: &T) -> io::Result<()> - where - T: Evented, - { + pub(super) fn deregister(&mut self, io: &mut impl Source) -> io::Result<()> { let inner = match self.handle.inner() { Some(inner) => inner, None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")), @@ -143,198 +104,47 @@ impl Registration { inner.deregister_source(io) } - /// Polls for events on the I/O resource's read readiness stream. - /// - /// If the I/O resource receives a new read readiness event since the last - /// call to `poll_read_ready`, it is returned. If it has not, the current - /// task is notified once a new event is received. - /// - /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned, - /// the function will always return `Ready(HUP)`. This should be treated as - /// the end of the readiness stream. - /// - /// # Return value - /// - /// There are several possible return values: - /// - /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received - /// a new readiness event. The readiness value is included. - /// - /// * `Poll::Pending` means that no new readiness events have been received - /// since the last call to `poll_read_ready`. - /// - /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This could represent a permanent internal error for example. - /// - /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered - /// - /// # Panics - /// - /// This function will panic if called from outside of a task context. - pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - let v = self.poll_ready(Direction::Read, Some(cx)).map_err(|e| { - coop.made_progress(); - e - })?; - match v { - Some(v) => { - coop.made_progress(); - Poll::Ready(Ok(v)) - } - None => Poll::Pending, - } - } - - /// Consume any pending read readiness event. - /// - /// This function is identical to [`poll_read_ready`] **except** that it - /// will not notify the current task when a new event is received. As such, - /// it is safe to call this function from outside of a task context. - /// - /// [`poll_read_ready`]: method@Self::poll_read_ready - pub fn take_read_ready(&self) -> io::Result<Option<mio::Ready>> { - self.poll_ready(Direction::Read, None) - } - - /// Polls for events on the I/O resource's write readiness stream. - /// - /// If the I/O resource receives a new write readiness event since the last - /// call to `poll_write_ready`, it is returned. If it has not, the current - /// task is notified once a new event is received. - /// - /// All events except `HUP` are [edge-triggered]. Once `HUP` is returned, - /// the function will always return `Ready(HUP)`. This should be treated as - /// the end of the readiness stream. - /// - /// # Return value - /// - /// There are several possible return values: - /// - /// * `Poll::Ready(Ok(readiness))` means that the I/O resource has received - /// a new readiness event. The readiness value is included. - /// - /// * `Poll::Pending` means that no new readiness events have been received - /// since the last call to `poll_write_ready`. - /// - /// * `Poll::Ready(Err(err))` means that the registration has encountered an - /// error. This could represent a permanent internal error for example. - /// - /// [edge-triggered]: struct@mio::Poll#edge-triggered-and-level-triggered - /// - /// # Panics - /// - /// This function will panic if called from outside of a task context. - pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<mio::Ready>> { - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); - - let v = self.poll_ready(Direction::Write, Some(cx)).map_err(|e| { - coop.made_progress(); - e - })?; - match v { - Some(v) => { - coop.made_progress(); - Poll::Ready(Ok(v)) - } - None => Poll::Pending, - } - } - - /// Consumes any pending write readiness event. - /// - /// This function is identical to [`poll_write_ready`] **except** that it - /// will not notify the current task when a new event is received. As such, - /// it is safe to call this function from outside of a task context. - /// - /// [`poll_write_ready`]: method@Self::poll_write_ready - pub fn take_write_ready(&self) -> io::Result<Option<mio::Ready>> { - self.poll_ready(Direction::Write, None) + pub(super) fn clear_readiness(&self, event: ReadyEvent) { + self.shared.clear_readiness(event); } /// Polls for events on the I/O resource's `direction` readiness stream. /// /// If called with a task context, notify the task when a new event is /// received. - fn poll_ready( + pub(super) fn poll_readiness( &self, + cx: &mut Context<'_>, direction: Direction, - cx: Option<&mut Context<'_>>, - ) -> io::Result<Option<mio::Ready>> { - let inner = match self.handle.inner() { - Some(inner) => inner, - None => return Err(io::Error::new(io::ErrorKind::Other, "reactor gone")), - }; - - // If the task should be notified about new events, ensure that it has - // been registered - if let Some(ref cx) = cx { - inner.register(self.address, direction, cx.waker().clone()) + ) -> Poll<io::Result<ReadyEvent>> { + if self.handle.inner().is_none() { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone"))); } - let mask = direction.mask(); - let mask_no_hup = (mask - platform::hup() - platform::error()).as_usize(); - - let sched = inner.io_dispatch.get(self.address).unwrap(); + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); + let ev = ready!(self.shared.poll_readiness(cx, direction)); + coop.made_progress(); + Poll::Ready(Ok(ev)) + } +} - // This consumes the current readiness state **except** for HUP and - // error. HUP and error are excluded because a) they are final states - // and never transitition out and b) both the read AND the write - // directions need to be able to obvserve these states. - // - // # Platform-specific behavior - // - // HUP and error readiness are platform-specific. On epoll platforms, - // HUP has specific conditions that must be met by both peers of a - // connection in order to be triggered. - // - // On epoll platforms, `EPOLLERR` is signaled through - // `UnixReady::error()` and is important to be observable by both read - // AND write. A specific case that `EPOLLERR` occurs is when the read - // end of a pipe is closed. When this occurs, a peer blocked by - // writing to the pipe should be notified. - let curr_ready = sched - .set_readiness(self.address, |curr| curr & (!mask_no_hup)) - .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address)); +cfg_io_readiness! { + impl Registration { + pub(super) async fn readiness(&self, interest: mio::Interest) -> io::Result<ReadyEvent> { + use std::future::Future; + use std::pin::Pin; - let mut ready = mask & mio::Ready::from_usize(curr_ready); + let fut = self.shared.readiness(interest); + pin!(fut); - if ready.is_empty() { - if let Some(cx) = cx { - // Update the task info - match direction { - Direction::Read => sched.reader.register_by_ref(cx.waker()), - Direction::Write => sched.writer.register_by_ref(cx.waker()), + crate::future::poll_fn(|cx| { + if self.handle.inner().is_none() { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "reactor gone"))); } - // Try again - let curr_ready = sched - .set_readiness(self.address, |curr| curr & (!mask_no_hup)) - .unwrap_or_else(|_| panic!("address {:?} no longer valid!", self.address)); - ready = mask & mio::Ready::from_usize(curr_ready); - } - } - - if ready.is_empty() { - Ok(None) - } else { - Ok(Some(ready)) + Pin::new(&mut fut).poll(cx).map(Ok) + }).await } } } - -unsafe impl Send for Registration {} -unsafe impl Sync for Registration {} - -impl Drop for Registration { - fn drop(&mut self) { - let inner = match self.handle.inner() { - Some(inner) => inner, - None => return, - }; - inner.drop_source(self.address); - } -} diff --git a/src/io/seek.rs b/src/io/seek.rs index e3b5bf6..e64205d 100644 --- a/src/io/seek.rs +++ b/src/io/seek.rs @@ -1,15 +1,23 @@ use crate::io::AsyncSeek; + +use pin_project_lite::pin_project; use std::future::Future; use std::io::{self, SeekFrom}; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -/// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Seek<'a, S: ?Sized> { - seek: &'a mut S, - pos: Option<SeekFrom>, +pin_project! { + /// Future for the [`seek`](crate::io::AsyncSeekExt::seek) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct Seek<'a, S: ?Sized> { + seek: &'a mut S, + pos: Option<SeekFrom>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } pub(crate) fn seek<S>(seek: &mut S, pos: SeekFrom) -> Seek<'_, S> @@ -19,6 +27,7 @@ where Seek { seek, pos: Some(pos), + _pin: PhantomPinned, } } @@ -28,29 +37,21 @@ where { type Output = io::Result<u64>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); match me.pos { - Some(pos) => match Pin::new(&mut me.seek).start_seek(cx, pos) { - Poll::Ready(Ok(())) => { - me.pos = None; - Pin::new(&mut me.seek).poll_complete(cx) + Some(pos) => { + // ensure no seek in progress + ready!(Pin::new(&mut *me.seek).poll_complete(cx))?; + match Pin::new(&mut *me.seek).start_seek(*pos) { + Ok(()) => { + *me.pos = None; + Pin::new(&mut *me.seek).poll_complete(cx) + } + Err(e) => Poll::Ready(Err(e)), } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - }, - None => Pin::new(&mut me.seek).poll_complete(cx), + } + None => Pin::new(&mut *me.seek).poll_complete(cx), } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Seek<'_, PhantomPinned>>(); - } -} diff --git a/src/io/split.rs b/src/io/split.rs index 134b937..fd3273e 100644 --- a/src/io/split.rs +++ b/src/io/split.rs @@ -4,9 +4,8 @@ //! To restore this read/write object from its `split::ReadHalf` and //! `split::WriteHalf` use `unsplit`. -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; -use bytes::{Buf, BufMut}; use std::cell::UnsafeCell; use std::fmt; use std::io; @@ -102,20 +101,11 @@ impl<T: AsyncRead> AsyncRead for ReadHalf<T> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { let mut inner = ready!(self.inner.poll_lock(cx)); inner.stream_pin().poll_read(cx, buf) } - - fn poll_read_buf<B: BufMut>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_read_buf(cx, buf) - } } impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { @@ -137,15 +127,6 @@ impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { let mut inner = ready!(self.inner.poll_lock(cx)); inner.stream_pin().poll_shutdown(cx) } - - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<Result<usize, io::Error>> { - let mut inner = ready!(self.inner.poll_lock(cx)); - inner.stream_pin().poll_write_buf(cx, buf) - } } impl<T> Inner<T> { diff --git a/src/io/stderr.rs b/src/io/stderr.rs index 99607dc..2f624fb 100644 --- a/src/io/stderr.rs +++ b/src/io/stderr.rs @@ -1,4 +1,5 @@ use crate::io::blocking::Blocking; +use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows; use crate::io::AsyncWrite; use std::io; @@ -35,7 +36,7 @@ cfg_io_std! { /// ``` #[derive(Debug)] pub struct Stderr { - std: Blocking<std::io::Stderr>, + std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stderr>>, } /// Constructs a new handle to the standard error of the current process. @@ -59,7 +60,7 @@ cfg_io_std! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut stderr = io::stdout(); + /// let mut stderr = io::stderr(); /// stderr.write_all(b"Print some error here.").await?; /// Ok(()) /// } @@ -67,7 +68,7 @@ cfg_io_std! { pub fn stderr() -> Stderr { let std = io::stderr(); Stderr { - std: Blocking::new(std), + std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), } } } diff --git a/src/io/stdin.rs b/src/io/stdin.rs index 325b875..c9578f1 100644 --- a/src/io/stdin.rs +++ b/src/io/stdin.rs @@ -1,5 +1,5 @@ use crate::io::blocking::Blocking; -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -63,16 +63,11 @@ impl std::os::windows::io::AsRawHandle for Stdin { } impl AsyncRead for Stdin { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - // https://github.com/rust-lang/rust/blob/09c817eeb29e764cfc12d0a8d94841e3ffe34023/src/libstd/io/stdio.rs#L97 - false - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { Pin::new(&mut self.std).poll_read(cx, buf) } } diff --git a/src/io/stdio_common.rs b/src/io/stdio_common.rs new file mode 100644 index 0000000..d21c842 --- /dev/null +++ b/src/io/stdio_common.rs @@ -0,0 +1,220 @@ +//! Contains utilities for stdout and stderr. +use crate::io::AsyncWrite; +use std::pin::Pin; +use std::task::{Context, Poll}; +/// # Windows +/// AsyncWrite adapter that finds last char boundary in given buffer and does not write the rest, +/// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF. +/// That's why, wrapped writer will always receive well-formed utf-8 bytes. +/// # Other platforms +/// passes data to `inner` as is +#[derive(Debug)] +pub(crate) struct SplitByUtf8BoundaryIfWindows<W> { + inner: W, +} + +impl<W> SplitByUtf8BoundaryIfWindows<W> { + pub(crate) fn new(inner: W) -> Self { + Self { inner } + } +} + +// this constant is defined by Unicode standard. +const MAX_BYTES_PER_CHAR: usize = 4; + +// Subject for tweaking here +const MAGIC_CONST: usize = 8; + +impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W> +where + W: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll<Result<usize, std::io::Error>> { + // just a closure to avoid repetitive code + let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf); + + // 1. Only windows stdio can suffer from non-utf8. + // We also check for `test` so that we can write some tests + // for further code. Since `AsyncWrite` can always shrink + // buffer at its discretion, excessive (i.e. in tests) shrinking + // does not break correctness. + // 2. If buffer is small, it will not be shrinked. + // That's why, it's "textness" will not change, so we don't have + // to fixup it. + if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF + { + return call_inner(buf); + } + + buf = &buf[..crate::io::blocking::MAX_BUF]; + + // Now there are two possibilites. + // If caller gave is binary buffer, we **should not** shrink it + // anymore, because excessive shrinking hits performance. + // If caller gave as binary buffer, we **must** additionaly + // shrink it to strip incomplete char at the end of buffer. + // that's why check we will perform now is allowed to have + // false-positive. + + // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes. + // if they are (possibly incomplete) utf8, then we can be quite sure + // that input buffer was utf8. + + let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) { + Ok(_) => true, + Err(err) => { + let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to(); + incomplete_bytes < MAX_BYTES_PER_CHAR + } + }; + + if have_to_fix_up { + // We must pop several bytes at the end which form incomplete + // character. To achieve it, we exploit UTF8 encoding: + // for any code point, all bytes except first start with 0b10 prefix. + // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details + let trailing_incomplete_char_size = buf + .iter() + .rev() + .take(MAX_BYTES_PER_CHAR) + .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000) + .unwrap_or(0) + + 1; + buf = &buf[..buf.len() - trailing_incomplete_char_size]; + } + + call_inner(buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Result<(), std::io::Error>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[cfg(test)] +#[cfg(not(loom))] +mod tests { + use crate::io::AsyncWriteExt; + use std::io; + use std::pin::Pin; + use std::task::Context; + use std::task::Poll; + + const MAX_BUF: usize = 16 * 1024; + + struct TextMockWriter; + + impl crate::io::AsyncWrite for TextMockWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + assert!(buf.len() <= MAX_BUF); + assert!(std::str::from_utf8(buf).is_ok()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + } + + struct LoggingMockWriter { + write_history: Vec<usize>, + } + + impl LoggingMockWriter { + fn new() -> Self { + LoggingMockWriter { + write_history: Vec::new(), + } + } + } + + impl crate::io::AsyncWrite for LoggingMockWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + assert!(buf.len() <= MAX_BUF); + self.write_history.push(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + } + + #[test] + fn test_splitter() { + let data = str::repeat("â–ˆ", MAX_BUF); + let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter); + let fut = async move { + wr.write_all(data.as_bytes()).await.unwrap(); + }; + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(fut); + } + + #[test] + fn test_pseudo_text() { + // In this test we write a piece of binary data, whose beginning is + // text though. We then validate that even in this corner case buffer + // was not shrinked too much. + let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR; + let mut data: Vec<u8> = str::repeat("a", checked_count).into(); + data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1)); + let mut writer = LoggingMockWriter::new(); + let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer); + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(async { + splitter.write_all(&data).await.unwrap(); + }); + // Check that at most two writes were performed + assert!(writer.write_history.len() <= 2); + // Check that all has been written + assert_eq!( + writer.write_history.iter().copied().sum::<usize>(), + data.len() + ); + // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked + // from the buffer: one because it was outside of MAX_BUF boundary, and + // up to one "utf8 code point". + assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1); + } +} diff --git a/src/io/stdout.rs b/src/io/stdout.rs index 5377993..a08ed01 100644 --- a/src/io/stdout.rs +++ b/src/io/stdout.rs @@ -1,6 +1,6 @@ use crate::io::blocking::Blocking; +use crate::io::stdio_common::SplitByUtf8BoundaryIfWindows; use crate::io::AsyncWrite; - use std::io; use std::pin::Pin; use std::task::Context; @@ -35,7 +35,7 @@ cfg_io_std! { /// ``` #[derive(Debug)] pub struct Stdout { - std: Blocking<std::io::Stdout>, + std: SplitByUtf8BoundaryIfWindows<Blocking<std::io::Stdout>>, } /// Constructs a new handle to the standard output of the current process. @@ -67,7 +67,7 @@ cfg_io_std! { pub fn stdout() -> Stdout { let std = io::stdout(); Stdout { - std: Blocking::new(std), + std: SplitByUtf8BoundaryIfWindows::new(Blocking::new(std)), } } } diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs index 1bfab90..9e87f2f 100644 --- a/src/io/util/async_buf_read_ext.rs +++ b/src/io/util/async_buf_read_ext.rs @@ -14,7 +14,7 @@ cfg_io_util! { /// Equivalent to: /// /// ```ignore - /// async fn read_until(&mut self, buf: &mut Vec<u8>) -> io::Result<usize>; + /// async fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize>; /// ``` /// /// This function will read bytes from the underlying stream until the diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index e848a5d..0ab66c2 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -986,10 +986,12 @@ cfg_io_util! { /// /// All bytes read from this source will be appended to the specified /// buffer `buf`. This function will continuously call [`read()`] to - /// append more data to `buf` until [`read()`][read] returns `Ok(0)`. + /// append more data to `buf` until [`read()`] returns `Ok(0)`. /// /// If successful, the total number of bytes read is returned. /// + /// [`read()`]: AsyncReadExt::read + /// /// # Errors /// /// If a read error is encountered then the `read_to_end` operation @@ -1018,7 +1020,7 @@ cfg_io_util! { /// (See also the [`tokio::fs::read`] convenience function for reading from a /// file.) /// - /// [`tokio::fs::read`]: crate::fs::read::read + /// [`tokio::fs::read`]: fn@crate::fs::read fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self> where Self: Unpin, @@ -1065,7 +1067,7 @@ cfg_io_util! { /// (See also the [`crate::fs::read_to_string`] convenience function for /// reading from a file.) /// - /// [`crate::fs::read_to_string`]: crate::fs::read_to_string::read_to_string + /// [`crate::fs::read_to_string`]: fn@crate::fs::read_to_string fn read_to_string<'a>(&'a mut self, dst: &'a mut String) -> ReadToString<'a, Self> where Self: Unpin, @@ -1078,7 +1080,11 @@ cfg_io_util! { /// This function returns a new instance of `AsyncRead` which will read /// at most `limit` bytes, after which it will always return EOF /// (`Ok(0)`). Any read errors will not count towards the number of - /// bytes read and future calls to [`read()`][read] may succeed. + /// bytes read and future calls to [`read()`] may succeed. + /// + /// [`read()`]: fn@crate::io::AsyncReadExt::read + /// + /// [read]: AsyncReadExt::read /// /// # Examples /// diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs index c7a0f72..351900b 100644 --- a/src/io/util/async_seek_ext.rs +++ b/src/io/util/async_seek_ext.rs @@ -2,65 +2,73 @@ use crate::io::seek::{seek, Seek}; use crate::io::AsyncSeek; use std::io::SeekFrom; -/// An extension trait which adds utility methods to [`AsyncSeek`] types. -/// -/// As a convenience, this trait may be imported using the [`prelude`]: -/// -/// # Examples -/// -/// ``` -/// use std::io::{Cursor, SeekFrom}; -/// use tokio::prelude::*; -/// -/// #[tokio::main] -/// async fn main() -> io::Result<()> { -/// let mut cursor = Cursor::new(b"abcdefg"); -/// -/// // the `seek` method is defined by this trait -/// cursor.seek(SeekFrom::Start(3)).await?; -/// -/// let mut buf = [0; 1]; -/// let n = cursor.read(&mut buf).await?; -/// assert_eq!(n, 1); -/// assert_eq!(buf, [b'd']); -/// -/// Ok(()) -/// } -/// ``` -/// -/// See [module][crate::io] documentation for more details. -/// -/// [`AsyncSeek`]: AsyncSeek -/// [`prelude`]: crate::prelude -pub trait AsyncSeekExt: AsyncSeek { - /// Creates a future which will seek an IO object, and then yield the - /// new position in the object and the object itself. +cfg_io_util! { + /// An extension trait which adds utility methods to [`AsyncSeek`] types. /// - /// In the case of an error the buffer and the object will be discarded, with - /// the error yielded. + /// As a convenience, this trait may be imported using the [`prelude`]: /// /// # Examples /// - /// ```no_run - /// use tokio::fs::File; + /// ``` + /// use std::io::{Cursor, SeekFrom}; /// use tokio::prelude::*; /// - /// use std::io::SeekFrom; + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut cursor = Cursor::new(b"abcdefg"); + /// + /// // the `seek` method is defined by this trait + /// cursor.seek(SeekFrom::Start(3)).await?; /// - /// # async fn dox() -> std::io::Result<()> { - /// let mut file = File::open("foo.txt").await?; - /// file.seek(SeekFrom::Start(6)).await?; + /// let mut buf = [0; 1]; + /// let n = cursor.read(&mut buf).await?; + /// assert_eq!(n, 1); + /// assert_eq!(buf, [b'd']); /// - /// let mut contents = vec![0u8; 10]; - /// file.read_exact(&mut contents).await?; - /// # Ok(()) - /// # } + /// Ok(()) + /// } /// ``` - fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> - where - Self: Unpin, - { - seek(self, pos) + /// + /// See [module][crate::io] documentation for more details. + /// + /// [`AsyncSeek`]: AsyncSeek + /// [`prelude`]: crate::prelude + pub trait AsyncSeekExt: AsyncSeek { + /// Creates a future which will seek an IO object, and then yield the + /// new position in the object and the object itself. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64>; + /// ``` + /// + /// In the case of an error the buffer and the object will be discarded, with + /// the error yielded. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::prelude::*; + /// + /// use std::io::SeekFrom; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::open("foo.txt").await?; + /// file.seek(SeekFrom::Start(6)).await?; + /// + /// let mut contents = vec![0u8; 10]; + /// file.read_exact(&mut contents).await?; + /// # Ok(()) + /// # } + /// ``` + fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> + where + Self: Unpin, + { + seek(self, pos) + } } } diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index fa41097..e6ef5b2 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -119,6 +119,7 @@ cfg_io_util! { write(self, src) } + /// Writes a buffer into this writer, advancing the buffer's internal /// cursor. /// @@ -134,7 +135,7 @@ cfg_io_util! { /// internal cursor is advanced by the number of bytes written. A /// subsequent call to `write_buf` using the **same** `buf` value will /// resume from the point that the first call to `write_buf` completed. - /// A call to `write` represents *at most one* attempt to write to any + /// A call to `write_buf` represents *at most one* attempt to write to any /// wrapped object. /// /// # Return @@ -976,6 +977,8 @@ cfg_io_util! { /// no longer attempt to write to the stream. For example, the /// `TcpStream` implementation will issue a `shutdown(Write)` sys call. /// + /// [`flush`]: fn@crate::io::AsyncWriteExt::flush + /// /// # Examples /// /// ```no_run diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index a1c5990..271f61b 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -1,10 +1,8 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; -use bytes::Buf; use pin_project_lite::pin_project; -use std::io::{self, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, fmt}; @@ -44,21 +42,12 @@ impl<R: AsyncRead> BufReader<R> { /// Creates a new `BufReader` with the specified buffer capacity. pub fn with_capacity(capacity: usize, inner: R) -> Self { - unsafe { - let mut buffer = Vec::with_capacity(capacity); - buffer.set_len(capacity); - - { - // Convert to MaybeUninit - let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit<u8>]); - inner.prepare_uninitialized_buffer(b); - } - Self { - inner, - buf: buffer.into_boxed_slice(), - pos: 0, - cap: 0, - } + let buffer = vec![0; capacity]; + Self { + inner, + buf: buffer.into_boxed_slice(), + pos: 0, + cap: 0, } } @@ -110,25 +99,21 @@ impl<R: AsyncRead> AsyncRead for BufReader<R> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - if self.pos == self.cap && buf.len() >= self.buf.len() { + if self.pos == self.cap && buf.remaining() >= self.buf.len() { let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); self.discard_buffer(); return Poll::Ready(res); } - let mut rem = ready!(self.as_mut().poll_fill_buf(cx))?; - let nread = rem.read(buf)?; - self.consume(nread); - Poll::Ready(Ok(nread)) - } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[..amt]); + self.consume(amt); + Poll::Ready(Ok(())) } } @@ -142,7 +127,9 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> { // to tell the compiler that the pos..cap slice is always valid. if *me.pos >= *me.cap { debug_assert!(*me.pos == *me.cap); - *me.cap = ready!(me.inner.poll_read(cx, me.buf))?; + let mut buf = ReadBuf::new(me.buf); + ready!(me.inner.poll_read(cx, &mut buf))?; + *me.cap = buf.filled().len(); *me.pos = 0; } Poll::Ready(Ok(&me.buf[*me.pos..*me.cap])) @@ -163,14 +150,6 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { self.get_pin_mut().poll_write(cx, buf) } - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - self.get_pin_mut().poll_write_buf(cx, buf) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.get_pin_mut().poll_flush(cx) } diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs index a56a451..cc857e2 100644 --- a/src/io/util/buf_stream.rs +++ b/src/io/util/buf_stream.rs @@ -1,9 +1,8 @@ use crate::io::util::{BufReader, BufWriter}; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -137,15 +136,10 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.project().inner.poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } } impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> { diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs index efd053e..5e3d4b7 100644 --- a/src/io/util/buf_writer.rs +++ b/src/io/util/buf_writer.rs @@ -1,10 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; use std::io::{self, Write}; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -147,15 +146,10 @@ impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.get_pin_mut().poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.get_ref().prepare_uninitialized_buffer(buf) - } } impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { diff --git a/src/io/util/chain.rs b/src/io/util/chain.rs index 8ba9194..84f37fc 100644 --- a/src/io/util/chain.rs +++ b/src/io/util/chain.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; @@ -84,26 +84,20 @@ where T: AsyncRead, U: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - if self.first.prepare_uninitialized_buffer(buf) { - return true; - } - if self.second.prepare_uninitialized_buffer(buf) { - return true; - } - false - } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { let me = self.project(); if !*me.done_first { - match ready!(me.first.poll_read(cx, buf)?) { - 0 if !buf.is_empty() => *me.done_first = true, - n => return Poll::Ready(Ok(n)), + let rem = buf.remaining(); + ready!(me.first.poll_read(cx, buf))?; + if buf.remaining() == rem { + *me.done_first = true; + } else { + return Poll::Ready(Ok(())); } } me.second.poll_read(cx, buf) diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs index 7bfe296..c5981cf 100644 --- a/src/io/util/copy.rs +++ b/src/io/util/copy.rs @@ -1,30 +1,25 @@ -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { - /// A future that asynchronously copies the entire contents of a reader into a - /// writer. - /// - /// This struct is generally created by calling [`copy`][copy]. Please - /// see the documentation of `copy()` for more details. - /// - /// [copy]: copy() - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct Copy<'a, R: ?Sized, W: ?Sized> { - reader: &'a mut R, - read_done: bool, - writer: &'a mut W, - pos: usize, - cap: usize, - amt: u64, - buf: Box<[u8]>, - } +/// A future that asynchronously copies the entire contents of a reader into a +/// writer. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +struct Copy<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + read_done: bool, + writer: &'a mut W, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} +cfg_io_util! { /// Asynchronously copies the entire contents of a reader into a writer. /// /// This function returns a future that will continuously read data from @@ -58,7 +53,7 @@ cfg_io_util! { /// # Ok(()) /// # } /// ``` - pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W> + pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> where R: AsyncRead + Unpin + ?Sized, W: AsyncWrite + Unpin + ?Sized, @@ -71,7 +66,7 @@ cfg_io_util! { pos: 0, cap: 0, buf: vec![0; 2048].into_boxed_slice(), - } + }.await } } @@ -88,7 +83,9 @@ where // continue. if self.pos == self.cap && !self.read_done { let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + let mut buf = ReadBuf::new(&mut me.buf); + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?; + let n = buf.filled().len(); if n == 0 { self.read_done = true; } else { @@ -122,14 +119,3 @@ where } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Copy<'_, PhantomPinned, PhantomPinned>>(); - } -} diff --git a/src/io/util/copy_buf.rs b/src/io/util/copy_buf.rs new file mode 100644 index 0000000..6831580 --- /dev/null +++ b/src/io/util/copy_buf.rs @@ -0,0 +1,102 @@ +use crate::io::{AsyncBufRead, AsyncWrite}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// A future that asynchronously copies the entire contents of a reader into a + /// writer. + /// + /// This struct is generally created by calling [`copy_buf`][copy_buf]. Please + /// see the documentation of `copy_buf()` for more details. + /// + /// [copy_buf]: copy_buf() + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + struct CopyBuf<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + writer: &'a mut W, + amt: u64, + } + + /// Asynchronously copies the entire contents of a reader into a writer. + /// + /// This function returns a future that will continuously read data from + /// `reader` and then write it into `writer` in a streaming fashion until + /// `reader` returns EOF. + /// + /// On success, the total number of bytes that were copied from `reader` to + /// `writer` is returned. + /// + /// + /// # Errors + /// + /// The returned future will finish with an error will return an error + /// immediately if any call to `poll_fill_buf` or `poll_write` returns an + /// error. + /// + /// # Examples + /// + /// ``` + /// use tokio::io; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut reader: &[u8] = b"hello"; + /// let mut writer: Vec<u8> = vec![]; + /// + /// io::copy_buf(&mut reader, &mut writer).await?; + /// + /// assert_eq!(b"hello", &writer[..]); + /// # Ok(()) + /// # } + /// ``` + pub async fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> + where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, + { + CopyBuf { + reader, + writer, + amt: 0, + }.await + } +} + +impl<R, W> Future for CopyBuf<'_, R, W> +where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<u64>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let me = &mut *self; + let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?; + if buffer.is_empty() { + ready!(Pin::new(&mut self.writer).poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + + let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?; + if i == 0 { + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); + } + self.amt += i as u64; + Pin::new(&mut *self.reader).consume(i); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + use std::marker::PhantomPinned; + crate::is_unpin::<CopyBuf<'_, PhantomPinned, PhantomPinned>>(); + } +} diff --git a/src/io/util/empty.rs b/src/io/util/empty.rs index 576058d..f964d18 100644 --- a/src/io/util/empty.rs +++ b/src/io/util/empty.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use std::fmt; use std::io; @@ -47,16 +47,13 @@ cfg_io_util! { } impl AsyncRead for Empty { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - _: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(Ok(0)) + _: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) } } diff --git a/src/io/util/flush.rs b/src/io/util/flush.rs index 534a516..88d60b8 100644 --- a/src/io/util/flush.rs +++ b/src/io/util/flush.rs @@ -1,18 +1,24 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future used to fully flush an I/O object. /// /// Created by the [`AsyncWriteExt::flush`][flush] function. /// [flush]: crate::io::AsyncWriteExt::flush #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Flush<'a, A: ?Sized> { a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +27,10 @@ pub(super) fn flush<A>(a: &mut A) -> Flush<'_, A> where A: AsyncWrite + Unpin + ?Sized, { - Flush { a } + Flush { + a, + _pin: PhantomPinned, + } } impl<A> Future for Flush<'_, A> @@ -30,19 +39,8 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); Pin::new(&mut *me.a).poll_flush(cx) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Flush<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs index ee27400..b41f04a 100644 --- a/src/io/util/lines.rs +++ b/src/io/util/lines.rs @@ -83,8 +83,7 @@ impl<R> Lines<R> where R: AsyncBufRead, { - #[doc(hidden)] - pub fn poll_next_line( + fn poll_next_line( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<String>>> { diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs new file mode 100644 index 0000000..e91a932 --- /dev/null +++ b/src/io/util/mem.rs @@ -0,0 +1,223 @@ +//! In-process memory IO types. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::loom::sync::Mutex; + +use bytes::{Buf, BytesMut}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll, Waker}, +}; + +/// A bidirectional pipe to read and write bytes in memory. +/// +/// A pair of `DuplexStream`s are created together, and they act as a "channel" +/// that can be used as in-memory IO types. Writing to one of the pairs will +/// allow that data to be read from the other, and vice versa. +/// +/// # Example +/// +/// ``` +/// # async fn ex() -> std::io::Result<()> { +/// # use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// let (mut client, mut server) = tokio::io::duplex(64); +/// +/// client.write_all(b"ping").await?; +/// +/// let mut buf = [0u8; 4]; +/// server.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"ping"); +/// +/// server.write_all(b"pong").await?; +/// +/// client.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"pong"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct DuplexStream { + read: Arc<Mutex<Pipe>>, + write: Arc<Mutex<Pipe>>, +} + +/// A unidirectional IO over a piece of memory. +/// +/// Data can be written to the pipe, and reading will return that data. +#[derive(Debug)] +struct Pipe { + /// The buffer storing the bytes written, also read from. + /// + /// Using a `BytesMut` because it has efficient `Buf` and `BufMut` + /// functionality already. Additionally, it can try to copy data in the + /// same buffer if there read index has advanced far enough. + buffer: BytesMut, + /// Determines if the write side has been closed. + is_closed: bool, + /// The maximum amount of bytes that can be written before returning + /// `Poll::Pending`. + max_buf_size: usize, + /// If the `read` side has been polled and is pending, this is the waker + /// for that parked task. + read_waker: Option<Waker>, + /// If the `write` side has filled the `max_buf_size` and returned + /// `Poll::Pending`, this is the waker for that parked task. + write_waker: Option<Waker>, +} + +// ===== impl DuplexStream ===== + +/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets. +/// +/// The `max_buf_size` argument is the maximum amount of bytes that can be +/// written to a side before the write returns `Poll::Pending`. +pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { + let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + + ( + DuplexStream { + read: one.clone(), + write: two.clone(), + }, + DuplexStream { + read: two, + write: one, + }, + ) +} + +impl AsyncRead for DuplexStream { + // Previous rustc required this `self` to be `mut`, even though newer + // versions recognize it isn't needed to call `lock()`. So for + // compatibility, we include the `mut` and `allow` the lint. + // + // See https://github.com/rust-lang/rust/issues/73592 + #[allow(unused_mut)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.read.lock()).poll_read(cx, buf) + } +} + +impl AsyncWrite for DuplexStream { + #[allow(unused_mut)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Pin::new(&mut *self.write.lock()).poll_write(cx, buf) + } + + #[allow(unused_mut)] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_flush(cx) + } + + #[allow(unused_mut)] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_shutdown(cx) + } +} + +impl Drop for DuplexStream { + fn drop(&mut self) { + // notify the other side of the closure + self.write.lock().close(); + } +} + +// ===== impl Pipe ===== + +impl Pipe { + fn new(max_buf_size: usize) -> Self { + Pipe { + buffer: BytesMut::new(), + is_closed: false, + max_buf_size, + read_waker: None, + write_waker: None, + } + } + + fn close(&mut self) { + self.is_closed = true; + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if self.buffer.has_remaining() { + let max = self.buffer.remaining().min(buf.remaining()); + buf.put_slice(&self.buffer[..max]); + self.buffer.advance(max); + if max > 0 { + // The passed `buf` might have been empty, don't wake up if + // no bytes have been moved. + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + Poll::Ready(Ok(())) + } else if self.is_closed { + Poll::Ready(Ok(())) + } else { + self.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let len = buf.len().min(avail); + self.buffer.extend_from_slice(&buf[..len]); + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + self.close(); + Poll::Ready(Ok(())) + } +} diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs index c4754ab..e75ea03 100644 --- a/src/io/util/mod.rs +++ b/src/io/util/mod.rs @@ -25,7 +25,10 @@ cfg_io_util! { mod chain; mod copy; - pub use copy::{copy, Copy}; + pub use copy::copy; + + mod copy_buf; + pub use copy_buf::copy_buf; mod empty; pub use empty::{empty, Empty}; @@ -35,6 +38,9 @@ cfg_io_util! { mod lines; pub use lines::Lines; + mod mem; + pub use mem::{duplex, DuplexStream}; + mod read; mod read_buf; mod read_exact; @@ -60,11 +66,6 @@ cfg_io_util! { mod split; pub use split::Split; - cfg_stream! { - mod stream_reader; - pub use stream_reader::{stream_reader, StreamReader}; - } - mod take; pub use take::Take; diff --git a/src/io/util/read.rs b/src/io/util/read.rs index a8ca370..edc9d5a 100644 --- a/src/io/util/read.rs +++ b/src/io/util/read.rs @@ -1,7 +1,9 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -15,10 +17,14 @@ pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R> where R: AsyncRead + Unpin + ?Sized, { - Read { reader, buf } + Read { + reader, + buf, + _pin: PhantomPinned, + } } -cfg_io_util! { +pin_project! { /// A future which can be used to easily read available number of bytes to fill /// a buffer. /// @@ -28,6 +34,9 @@ cfg_io_util! { pub struct Read<'a, R: ?Sized> { reader: &'a mut R, buf: &'a mut [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -37,19 +46,10 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.reader).poll_read(cx, me.buf) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Read<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + let mut buf = ReadBuf::new(*me.buf); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?; + Poll::Ready(Ok(buf.filled().len())) } } diff --git a/src/io/util/read_buf.rs b/src/io/util/read_buf.rs index 6ee3d24..696deef 100644 --- a/src/io/util/read_buf.rs +++ b/src/io/util/read_buf.rs @@ -1,8 +1,10 @@ use crate::io::AsyncRead; use bytes::BufMut; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; @@ -11,16 +13,22 @@ where R: AsyncRead + Unpin, B: BufMut, { - ReadBuf { reader, buf } + ReadBuf { + reader, + buf, + _pin: PhantomPinned, + } } -cfg_io_util! { +pin_project! { /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadBuf<'a, R, B> { reader: &'a mut R, buf: &'a mut B, + #[pin] + _pin: PhantomPinned, } } @@ -31,8 +39,34 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + use crate::io::ReadBuf; + use std::mem::MaybeUninit; + + let me = self.project(); + + if !me.buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = me.buf.bytes_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + me.buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) } } diff --git a/src/io/util/read_exact.rs b/src/io/util/read_exact.rs index 86b8412..1e8150e 100644 --- a/src/io/util/read_exact.rs +++ b/src/io/util/read_exact.rs @@ -1,7 +1,9 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -17,12 +19,12 @@ where { ReadExact { reader, - buf, - pos: 0, + buf: ReadBuf::new(buf), + _pin: PhantomPinned, } } -cfg_io_util! { +pin_project! { /// Creates a future which will read exactly enough bytes to fill `buf`, /// returning an error if EOF is hit sooner. /// @@ -31,8 +33,10 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadExact<'a, A: ?Sized> { reader: &'a mut A, - buf: &'a mut [u8], - pos: usize, + buf: ReadBuf<'a>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -46,32 +50,20 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let mut me = self.project(); + loop { // if our buffer is empty, then we need to read some data to continue. - if self.pos < self.buf.len() { - let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf[me.pos..]))?; - me.pos += n; - if n == 0 { + let rem = me.buf.remaining(); + if rem != 0 { + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + if me.buf.remaining() == rem { return Err(eof()).into(); } - } - - if self.pos >= self.buf.len() { - return Poll::Ready(Ok(self.pos)); + } else { + return Poll::Ready(Ok(me.buf.capacity())); } } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadExact<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs index 9d37dc7..5b9fb7b 100644 --- a/src/io/util/read_int.rs +++ b/src/io/util/read_int.rs @@ -1,10 +1,11 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use bytes::Buf; use pin_project_lite::pin_project; use std::future::Future; use std::io; use std::io::ErrorKind::UnexpectedEof; +use std::marker::PhantomPinned; use std::mem::size_of; use std::pin::Pin; use std::task::{Context, Poll}; @@ -16,11 +17,15 @@ macro_rules! reader { ($name:ident, $ty:ty, $reader:ident, $bytes:expr) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<R> { #[pin] src: R, buf: [u8; $bytes], read: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -30,6 +35,7 @@ macro_rules! reader { src, buf: [0; $bytes], read: 0, + _pin: PhantomPinned, } } } @@ -48,17 +54,19 @@ macro_rules! reader { } while *me.read < $bytes as u8 { - *me.read += match me - .src - .as_mut() - .poll_read(cx, &mut me.buf[*me.read as usize..]) - { + let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]); + + *me.read += match me.src.as_mut().poll_read(cx, &mut buf) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => { - return Poll::Ready(Err(UnexpectedEof.into())); + Poll::Ready(Ok(())) => { + let n = buf.filled().len(); + if n == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + n as u8 } - Poll::Ready(Ok(n)) => n as u8, }; } @@ -75,15 +83,22 @@ macro_rules! reader8 { pin_project! { /// Future returned from `read_u8` #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<R> { #[pin] reader: R, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<R> $name<R> { pub(crate) fn new(reader: R) -> $name<R> { - $name { reader } + $name { + reader, + _pin: PhantomPinned, + } } } @@ -97,12 +112,17 @@ macro_rules! reader8 { let me = self.project(); let mut buf = [0; 1]; - match me.reader.poll_read(cx, &mut buf[..]) { + let mut buf = ReadBuf::new(&mut buf); + match me.reader.poll_read(cx, &mut buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => Poll::Ready(Err(UnexpectedEof.into())), - Poll::Ready(Ok(1)) => Poll::Ready(Ok(buf[0] as $ty)), - Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Ok(())) => { + if buf.filled().len() == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + Poll::Ready(Ok(buf.filled()[0] as $ty)) + } } } } diff --git a/src/io/util/read_line.rs b/src/io/util/read_line.rs index d625a76..d38ffaf 100644 --- a/src/io/util/read_line.rs +++ b/src/io/util/read_line.rs @@ -1,26 +1,32 @@ use crate::io::util::read_until::read_until_internal; use crate::io::AsyncBufRead; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; +use std::string::FromUtf8Error; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// Future for the [`read_line`](crate::io::AsyncBufReadExt::read_line) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadLine<'a, R: ?Sized> { reader: &'a mut R, - /// This is the buffer we were provided. It will be replaced with an empty string - /// while reading to postpone utf-8 handling until after reading. + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. output: &'a mut String, - /// The actual allocation of the string is moved into a vector instead. + // The actual allocation of the string is moved into this vector instead. buf: Vec<u8>, - /// The number of bytes appended to buf. This can be less than buf.len() if - /// the buffer was not empty when the operation was started. + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -33,6 +39,7 @@ where buf: mem::replace(string, String::new()).into_bytes(), output: string, read: 0, + _pin: PhantomPinned, } } @@ -42,31 +49,33 @@ fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_re *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); } -pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( - reader: Pin<&mut R>, - cx: &mut Context<'_>, +/// This handles the various failure cases and puts the string back into `output`. +/// +/// The `truncate_on_io_error` bool is necessary because `read_to_string` and `read_line` +/// disagree on what should happen when an IO error occurs. +pub(super) fn finish_string_read( + io_res: io::Result<usize>, + utf8_res: Result<String, FromUtf8Error>, + read: usize, output: &mut String, - buf: &mut Vec<u8>, - read: &mut usize, + truncate_on_io_error: bool, ) -> Poll<io::Result<usize>> { - let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); - let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); - - // At this point both buf and output are empty. The allocation is in utf8_res. - - debug_assert!(buf.is_empty()); match (io_res, utf8_res) { (Ok(num_bytes), Ok(string)) => { - debug_assert_eq!(*read, 0); + debug_assert_eq!(read, 0); *output = string; Poll::Ready(Ok(num_bytes)) } (Err(io_err), Ok(string)) => { *output = string; + if truncate_on_io_error { + let original_len = output.len() - read; + output.truncate(original_len); + } Poll::Ready(Err(io_err)) } (Ok(num_bytes), Err(utf8_err)) => { - debug_assert_eq!(*read, 0); + debug_assert_eq!(read, 0); put_back_original_data(output, utf8_err.into_bytes(), num_bytes); Poll::Ready(Err(io::Error::new( @@ -75,35 +84,36 @@ pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( ))) } (Err(io_err), Err(utf8_err)) => { - put_back_original_data(output, utf8_err.into_bytes(), *read); + put_back_original_data(output, utf8_err.into_bytes(), read); Poll::Ready(Err(io_err)) } } } -impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { - type Output = io::Result<usize>; +pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( + reader: Pin<&mut R>, + cx: &mut Context<'_>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, +) -> Poll<io::Result<usize>> { + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - output, - buf, - read, - } = &mut *self; + // At this point both buf and output are empty. The allocation is in utf8_res. - read_line_internal(Pin::new(reader), cx, output, buf, read) - } + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, false) } -#[cfg(test)] -mod tests { - use super::*; +impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadLine<'_, PhantomPinned>>(); + read_line_internal(Pin::new(*me.reader), cx, me.output, me.buf, me.read) } } diff --git a/src/io/util/read_to_end.rs b/src/io/util/read_to_end.rs index a2cd99b..a974625 100644 --- a/src/io/util/read_to_end.rs +++ b/src/io/util/read_to_end.rs @@ -1,92 +1,105 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; -use std::mem::MaybeUninit; +use std::marker::PhantomPinned; +use std::mem::{self, MaybeUninit}; use std::pin::Pin; use std::task::{Context, Poll}; -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] -pub struct ReadToEnd<'a, R: ?Sized> { - reader: &'a mut R, - buf: &'a mut Vec<u8>, - start_len: usize, +pin_project! { + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadToEnd<'a, R: ?Sized> { + reader: &'a mut R, + buf: &'a mut Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } -pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, R> +pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) -> ReadToEnd<'a, R> where R: AsyncRead + Unpin + ?Sized, { - let start_len = buf.len(); ReadToEnd { reader, - buf, - start_len, + buf: buffer, + read: 0, + _pin: PhantomPinned, } } -struct Guard<'a> { - buf: &'a mut Vec<u8>, - len: usize, -} - -impl Drop for Guard<'_> { - fn drop(&mut self) { - unsafe { - self.buf.set_len(self.len); +pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>( + buf: &mut Vec<u8>, + mut reader: Pin<&mut R>, + num_read: &mut usize, + cx: &mut Context<'_>, +) -> Poll<io::Result<usize>> { + loop { + // safety: The caller promised to prepare the buffer. + let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx)); + match ret { + Err(err) => return Poll::Ready(Err(err)), + Ok(0) => return Poll::Ready(Ok(mem::replace(num_read, 0))), + Ok(num) => { + *num_read += num; + } } } } -// This uses an adaptive system to extend the vector when it fills. We want to -// avoid paying to allocate and zero a huge chunk of memory if the reader only -// has 4 bytes while still making large reads if the reader does have a ton -// of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every -// time is 4,500 times (!) slower than this if the reader has a very small -// amount of data to return. -// -// Because we're extending the buffer with uninitialized data for trusted -// readers, we need to make sure to truncate that if any of this panics. -pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>( - mut rd: Pin<&mut R>, - cx: &mut Context<'_>, +/// Tries to read from the provided AsyncRead. +/// +/// The length of the buffer is increased by the number of bytes read. +fn poll_read_to_end<R: AsyncRead + ?Sized>( buf: &mut Vec<u8>, - start_len: usize, + read: Pin<&mut R>, + cx: &mut Context<'_>, ) -> Poll<io::Result<usize>> { - let mut g = Guard { - len: buf.len(), - buf, - }; - let ret; - loop { - if g.len == g.buf.len() { - unsafe { - g.buf.reserve(32); - let capacity = g.buf.capacity(); - g.buf.set_len(capacity); + // This uses an adaptive system to extend the vector when it fills. We want to + // avoid paying to allocate and zero a huge chunk of memory if the reader only + // has 4 bytes while still making large reads if the reader does have a ton + // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every + // time is 4,500 times (!) slower than this if the reader has a very small + // amount of data to return. + reserve(buf, 32); - let b = &mut *(&mut g.buf[g.len..] as *mut [u8] as *mut [MaybeUninit<u8>]); + let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf)); - rd.prepare_uninitialized_buffer(b); - } - } + ready!(read.poll_read(cx, &mut unused_capacity))?; - match ready!(rd.as_mut().poll_read(cx, &mut g.buf[g.len..])) { - Ok(0) => { - ret = Poll::Ready(Ok(g.len - start_len)); - break; - } - Ok(n) => g.len += n, - Err(e) => { - ret = Poll::Ready(Err(e)); - break; - } - } + let n = unused_capacity.filled().len(); + let new_len = buf.len() + n; + + // This should no longer even be possible in safe Rust. An implementor + // would need to have unsafely *replaced* the buffer inside `ReadBuf`, + // which... yolo? + assert!(new_len <= buf.capacity()); + unsafe { + buf.set_len(new_len); } + Poll::Ready(Ok(n)) +} - ret +/// Allocates more memory and ensures that the unused capacity is prepared for use +/// with the `AsyncRead`. +fn reserve(buf: &mut Vec<u8>, bytes: usize) { + if buf.capacity() - buf.len() >= bytes { + return; + } + buf.reserve(bytes); +} + +/// Returns the unused capacity of the provided vector. +fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] { + let uninit = bytes::BufMut::bytes_mut(buf); + unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) } } impl<A> Future for ReadToEnd<'_, A> @@ -95,19 +108,9 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let this = &mut *self; - read_to_end_internal(Pin::new(&mut this.reader), cx, this.buf, this.start_len) - } -} - -#[cfg(test)] -mod tests { - use super::*; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadToEnd<'_, PhantomPinned>>(); + read_to_end_internal(me.buf, Pin::new(*me.reader), me.read, cx) } } diff --git a/src/io/util/read_to_string.rs b/src/io/util/read_to_string.rs index cab0505..e463203 100644 --- a/src/io/util/read_to_string.rs +++ b/src/io/util/read_to_string.rs @@ -1,58 +1,71 @@ +use crate::io::util::read_line::finish_string_read; use crate::io::util::read_to_end::read_to_end_internal; use crate::io::AsyncRead; +use pin_project_lite::pin_project; use std::future::Future; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; -cfg_io_util! { +pin_project! { /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadToString<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut String, - bytes: Vec<u8>, - start_len: usize, + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + // The actual allocation of the string is moved into this vector instead. + buf: Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } -pub(crate) fn read_to_string<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadToString<'a, R> +pub(crate) fn read_to_string<'a, R>( + reader: &'a mut R, + string: &'a mut String, +) -> ReadToString<'a, R> where R: AsyncRead + ?Sized + Unpin, { - let start_len = buf.len(); + let buf = mem::replace(string, String::new()).into_bytes(); ReadToString { reader, - bytes: mem::replace(buf, String::new()).into_bytes(), buf, - start_len, + output: string, + read: 0, + _pin: PhantomPinned, } } -fn read_to_string_internal<R: AsyncRead + ?Sized>( +/// # Safety +/// +/// Before first calling this method, the unused capacity must have been +/// prepared for use with the provided AsyncRead. This can be done using the +/// `prepare_buffer` function in `read_to_end.rs`. +unsafe fn read_to_string_internal<R: AsyncRead + ?Sized>( reader: Pin<&mut R>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, cx: &mut Context<'_>, - buf: &mut String, - bytes: &mut Vec<u8>, - start_len: usize, ) -> Poll<io::Result<usize>> { - let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; - match String::from_utf8(mem::replace(bytes, Vec::new())) { - Ok(string) => { - debug_assert!(buf.is_empty()); - *buf = string; - Poll::Ready(Ok(ret)) - } - Err(e) => { - *bytes = e.into_bytes(); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - "stream did not contain valid UTF-8", - ))) - } - } + let io_res = ready!(read_to_end_internal(buf, reader, read, cx)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, true) } impl<A> Future for ReadToString<'_, A> @@ -61,31 +74,10 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - buf, - bytes, - start_len, - } = &mut *self; - let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len); - if let Poll::Ready(Err(_)) = ret { - // Put back the original string. - bytes.truncate(*start_len); - **buf = String::from_utf8(mem::replace(bytes, Vec::new())) - .expect("original string no longer utf-8"); - } - ret - } -} - -#[cfg(test)] -mod tests { - use super::*; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadToString<'_, PhantomPinned>>(); + // safety: The constructor of ReadToString called `prepare_buffer`. + unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) } } } diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs index 78dac8c..3599cff 100644 --- a/src/io/util/read_until.rs +++ b/src/io/util/read_until.rs @@ -1,12 +1,14 @@ use crate::io::AsyncBufRead; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. /// The delimeter is included in the resulting vector. #[derive(Debug)] @@ -15,9 +17,12 @@ cfg_io_util! { reader: &'a mut R, delimeter: u8, buf: &'a mut Vec<u8>, - /// The number of bytes appended to buf. This can be less than buf.len() if - /// the buffer was not empty when the operation was started. + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -34,6 +39,7 @@ where delimeter, buf, read: 0, + _pin: PhantomPinned, } } @@ -66,24 +72,8 @@ pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>( impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - delimeter, - buf, - read, - } = &mut *self; - read_until_internal(Pin::new(reader), cx, *delimeter, buf, read) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadUntil<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read) } } diff --git a/src/io/util/repeat.rs b/src/io/util/repeat.rs index eeef7cc..1142765 100644 --- a/src/io/util/repeat.rs +++ b/src/io/util/repeat.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -47,19 +47,17 @@ cfg_io_util! { } impl AsyncRead for Repeat { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - for byte in &mut *buf { - *byte = self.byte; + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // TODO: could be faster, but should we unsafe it? + while buf.remaining() != 0 { + buf.put_slice(&[self.byte]); } - Poll::Ready(Ok(buf.len())) + Poll::Ready(Ok(())) } } diff --git a/src/io/util/shutdown.rs b/src/io/util/shutdown.rs index 33ac0ac..6d30b00 100644 --- a/src/io/util/shutdown.rs +++ b/src/io/util/shutdown.rs @@ -1,18 +1,24 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future used to shutdown an I/O object. /// /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. /// [shutdown]: crate::io::AsyncWriteExt::shutdown + #[must_use = "futures do nothing unless you `.await` or poll them"] #[derive(Debug)] pub struct Shutdown<'a, A: ?Sized> { a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +27,10 @@ pub(super) fn shutdown<A>(a: &mut A) -> Shutdown<'_, A> where A: AsyncWrite + Unpin + ?Sized, { - Shutdown { a } + Shutdown { + a, + _pin: PhantomPinned, + } } impl<A> Future for Shutdown<'_, A> @@ -30,19 +39,8 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; - Pin::new(&mut *me.a).poll_shutdown(cx) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Shutdown<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + Pin::new(me.a).poll_shutdown(cx) } } diff --git a/src/io/util/split.rs b/src/io/util/split.rs index f552ed5..492e26a 100644 --- a/src/io/util/split.rs +++ b/src/io/util/split.rs @@ -65,8 +65,7 @@ impl<R> Split<R> where R: AsyncBufRead, { - #[doc(hidden)] - pub fn poll_next_segment( + fn poll_next_segment( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<Vec<u8>>>> { diff --git a/src/io/util/stream_reader.rs b/src/io/util/stream_reader.rs deleted file mode 100644 index b98f8bd..0000000 --- a/src/io/util/stream_reader.rs +++ /dev/null @@ -1,184 +0,0 @@ -use crate::io::{AsyncBufRead, AsyncRead}; -use crate::stream::Stream; -use bytes::{Buf, BufMut}; -use pin_project_lite::pin_project; -use std::io; -use std::mem::MaybeUninit; -use std::pin::Pin; -use std::task::{Context, Poll}; - -pin_project! { - /// Convert a stream of byte chunks into an [`AsyncRead`]. - /// - /// This type is usually created using the [`stream_reader`] function. - /// - /// [`AsyncRead`]: crate::io::AsyncRead - /// [`stream_reader`]: crate::io::stream_reader - #[derive(Debug)] - #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] - #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] - pub struct StreamReader<S, B> { - #[pin] - inner: S, - chunk: Option<B>, - } -} - -/// Convert a stream of byte chunks into an [`AsyncRead`](crate::io::AsyncRead). -/// -/// # Example -/// -/// ``` -/// use bytes::Bytes; -/// use tokio::io::{stream_reader, AsyncReadExt}; -/// # #[tokio::main] -/// # async fn main() -> std::io::Result<()> { -/// -/// // Create a stream from an iterator. -/// let stream = tokio::stream::iter(vec![ -/// Ok(Bytes::from_static(&[0, 1, 2, 3])), -/// Ok(Bytes::from_static(&[4, 5, 6, 7])), -/// Ok(Bytes::from_static(&[8, 9, 10, 11])), -/// ]); -/// -/// // Convert it to an AsyncRead. -/// let mut read = stream_reader(stream); -/// -/// // Read five bytes from the stream. -/// let mut buf = [0; 5]; -/// read.read_exact(&mut buf).await?; -/// assert_eq!(buf, [0, 1, 2, 3, 4]); -/// -/// // Read the rest of the current chunk. -/// assert_eq!(read.read(&mut buf).await?, 3); -/// assert_eq!(&buf[..3], [5, 6, 7]); -/// -/// // Read the next chunk. -/// assert_eq!(read.read(&mut buf).await?, 4); -/// assert_eq!(&buf[..4], [8, 9, 10, 11]); -/// -/// // We have now reached the end. -/// assert_eq!(read.read(&mut buf).await?, 0); -/// -/// # Ok(()) -/// # } -/// ``` -#[cfg_attr(docsrs, doc(cfg(feature = "stream")))] -#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] -pub fn stream_reader<S, B>(stream: S) -> StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - StreamReader::new(stream) -} - -impl<S, B> StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - /// Convert the provided stream into an `AsyncRead`. - fn new(stream: S) -> Self { - Self { - inner: stream, - chunk: None, - } - } - /// Do we have a chunk and is it non-empty? - fn has_chunk(self: Pin<&mut Self>) -> bool { - if let Some(chunk) = self.project().chunk { - chunk.remaining() > 0 - } else { - false - } - } -} - -impl<S, B> AsyncRead for StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - let inner_buf = match self.as_mut().poll_fill_buf(cx) { - Poll::Ready(Ok(buf)) => buf, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - let len = std::cmp::min(inner_buf.len(), buf.len()); - (&mut buf[..len]).copy_from_slice(&inner_buf[..len]); - - self.consume(len); - Poll::Ready(Ok(len)) - } - fn poll_read_buf<BM: BufMut>( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut BM, - ) -> Poll<io::Result<usize>> - where - Self: Sized, - { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let inner_buf = match self.as_mut().poll_fill_buf(cx) { - Poll::Ready(Ok(buf)) => buf, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - let len = std::cmp::min(inner_buf.len(), buf.remaining_mut()); - buf.put_slice(&inner_buf[..len]); - - self.consume(len); - Poll::Ready(Ok(len)) - } - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } -} - -impl<S, B> AsyncBufRead for StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { - loop { - if self.as_mut().has_chunk() { - // This unwrap is very sad, but it can't be avoided. - let buf = self.project().chunk.as_ref().unwrap().bytes(); - return Poll::Ready(Ok(buf)); - } else { - match self.as_mut().project().inner.poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - // Go around the loop in case the chunk is empty. - *self.as_mut().project().chunk = Some(chunk); - } - Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)), - Poll::Ready(None) => return Poll::Ready(Ok(&[])), - Poll::Pending => return Poll::Pending, - } - } - } - } - fn consume(self: Pin<&mut Self>, amt: usize) { - if amt > 0 { - self.project() - .chunk - .as_mut() - .expect("No chunk present") - .advance(amt); - } - } -} diff --git a/src/io/util/take.rs b/src/io/util/take.rs index 5d6bd90..b5e90c9 100644 --- a/src/io/util/take.rs +++ b/src/io/util/take.rs @@ -1,7 +1,6 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, io}; @@ -76,24 +75,27 @@ impl<R: AsyncRead> Take<R> { } impl<R: AsyncRead> AsyncRead for Take<R> { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<Result<usize, io::Error>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), io::Error>> { if self.limit_ == 0 { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } let me = self.project(); - let max = std::cmp::min(buf.len() as u64, *me.limit_) as usize; - let n = ready!(me.inner.poll_read(cx, &mut buf[..max]))?; + let mut b = buf.take(*me.limit_ as usize); + ready!(me.inner.poll_read(cx, &mut b))?; + let n = b.filled().len(); + + // We need to update the original ReadBuf + unsafe { + buf.assume_init(n); + } + buf.advance(n); *me.limit_ -= n as u64; - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } diff --git a/src/io/util/write.rs b/src/io/util/write.rs index 433a421..92169eb 100644 --- a/src/io/util/write.rs +++ b/src/io/util/write.rs @@ -1,17 +1,22 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future to write some of the buffer to an `AsyncWrite`. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Write<'a, W: ?Sized> { writer: &'a mut W, buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +26,11 @@ pub(crate) fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> Write<'a, W> where W: AsyncWrite + Unpin + ?Sized, { - Write { writer, buf } + Write { + writer, + buf, + _pin: PhantomPinned, + } } impl<W> Future for Write<'_, W> @@ -30,8 +39,8 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); Pin::new(&mut *me.writer).poll_write(cx, me.buf) } } diff --git a/src/io/util/write_all.rs b/src/io/util/write_all.rs index 898006c..e59d41e 100644 --- a/src/io/util/write_all.rs +++ b/src/io/util/write_all.rs @@ -1,17 +1,22 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WriteAll<'a, W: ?Sized> { writer: &'a mut W, buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -19,7 +24,11 @@ pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a, where W: AsyncWrite + Unpin + ?Sized, { - WriteAll { writer, buf } + WriteAll { + writer, + buf, + _pin: PhantomPinned, + } } impl<W> Future for WriteAll<'_, W> @@ -28,13 +37,13 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = self.project(); while !me.buf.is_empty() { - let n = ready!(Pin::new(&mut me.writer).poll_write(cx, me.buf))?; + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?; { - let (_, rest) = mem::replace(&mut me.buf, &[]).split_at(n); - me.buf = rest; + let (_, rest) = mem::replace(&mut *me.buf, &[]).split_at(n); + *me.buf = rest; } if n == 0 { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); @@ -44,14 +53,3 @@ where Poll::Ready(Ok(())) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<WriteAll<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/write_buf.rs b/src/io/util/write_buf.rs index cedfde6..1310e5c 100644 --- a/src/io/util/write_buf.rs +++ b/src/io/util/write_buf.rs @@ -1,18 +1,22 @@ use crate::io::AsyncWrite; use bytes::Buf; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future to write some of the buffer to an `AsyncWrite`. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WriteBuf<'a, W, B> { writer: &'a mut W, buf: &'a mut B, + #[pin] + _pin: PhantomPinned, } } @@ -23,7 +27,11 @@ where W: AsyncWrite + Unpin, B: Buf, { - WriteBuf { writer, buf } + WriteBuf { + writer, + buf, + _pin: PhantomPinned, + } } impl<W, B> Future for WriteBuf<'_, W, B> @@ -33,8 +41,15 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + + if !me.buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?; + me.buf.advance(n); + Poll::Ready(Ok(n)) } } diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs index ee992de..13bc191 100644 --- a/src/io/util/write_int.rs +++ b/src/io/util/write_int.rs @@ -4,6 +4,7 @@ use bytes::BufMut; use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem::size_of; use std::pin::Pin; use std::task::{Context, Poll}; @@ -15,20 +16,25 @@ macro_rules! writer { ($name:ident, $ty:ty, $writer:ident, $bytes:expr) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<W> { #[pin] dst: W, buf: [u8; $bytes], written: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<W> $name<W> { pub(crate) fn new(w: W, value: $ty) -> Self { - let mut writer = $name { + let mut writer = Self { buf: [0; $bytes], written: 0, dst: w, + _pin: PhantomPinned, }; BufMut::$writer(&mut &mut writer.buf[..], value); writer @@ -72,16 +78,24 @@ macro_rules! writer8 { ($name:ident, $ty:ty) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<W> { #[pin] dst: W, byte: $ty, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<W> $name<W> { pub(crate) fn new(dst: W, byte: $ty) -> Self { - Self { dst, byte } + Self { + dst, + byte, + _pin: PhantomPinned, + } } } |