diff options
Diffstat (limited to 'src/io/util')
32 files changed, 869 insertions, 695 deletions
diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs index 1bfab90..9e87f2f 100644 --- a/src/io/util/async_buf_read_ext.rs +++ b/src/io/util/async_buf_read_ext.rs @@ -14,7 +14,7 @@ cfg_io_util! { /// Equivalent to: /// /// ```ignore - /// async fn read_until(&mut self, buf: &mut Vec<u8>) -> io::Result<usize>; + /// async fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> io::Result<usize>; /// ``` /// /// This function will read bytes from the underlying stream until the diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index e848a5d..0ab66c2 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -986,10 +986,12 @@ cfg_io_util! { /// /// All bytes read from this source will be appended to the specified /// buffer `buf`. This function will continuously call [`read()`] to - /// append more data to `buf` until [`read()`][read] returns `Ok(0)`. + /// append more data to `buf` until [`read()`] returns `Ok(0)`. /// /// If successful, the total number of bytes read is returned. /// + /// [`read()`]: AsyncReadExt::read + /// /// # Errors /// /// If a read error is encountered then the `read_to_end` operation @@ -1018,7 +1020,7 @@ cfg_io_util! { /// (See also the [`tokio::fs::read`] convenience function for reading from a /// file.) /// - /// [`tokio::fs::read`]: crate::fs::read::read + /// [`tokio::fs::read`]: fn@crate::fs::read fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self> where Self: Unpin, @@ -1065,7 +1067,7 @@ cfg_io_util! { /// (See also the [`crate::fs::read_to_string`] convenience function for /// reading from a file.) /// - /// [`crate::fs::read_to_string`]: crate::fs::read_to_string::read_to_string + /// [`crate::fs::read_to_string`]: fn@crate::fs::read_to_string fn read_to_string<'a>(&'a mut self, dst: &'a mut String) -> ReadToString<'a, Self> where Self: Unpin, @@ -1078,7 +1080,11 @@ cfg_io_util! { /// This function returns a new instance of `AsyncRead` which will read /// at most `limit` bytes, after which it will always return EOF /// (`Ok(0)`). Any read errors will not count towards the number of - /// bytes read and future calls to [`read()`][read] may succeed. + /// bytes read and future calls to [`read()`] may succeed. + /// + /// [`read()`]: fn@crate::io::AsyncReadExt::read + /// + /// [read]: AsyncReadExt::read /// /// # Examples /// diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs index c7a0f72..351900b 100644 --- a/src/io/util/async_seek_ext.rs +++ b/src/io/util/async_seek_ext.rs @@ -2,65 +2,73 @@ use crate::io::seek::{seek, Seek}; use crate::io::AsyncSeek; use std::io::SeekFrom; -/// An extension trait which adds utility methods to [`AsyncSeek`] types. -/// -/// As a convenience, this trait may be imported using the [`prelude`]: -/// -/// # Examples -/// -/// ``` -/// use std::io::{Cursor, SeekFrom}; -/// use tokio::prelude::*; -/// -/// #[tokio::main] -/// async fn main() -> io::Result<()> { -/// let mut cursor = Cursor::new(b"abcdefg"); -/// -/// // the `seek` method is defined by this trait -/// cursor.seek(SeekFrom::Start(3)).await?; -/// -/// let mut buf = [0; 1]; -/// let n = cursor.read(&mut buf).await?; -/// assert_eq!(n, 1); -/// assert_eq!(buf, [b'd']); -/// -/// Ok(()) -/// } -/// ``` -/// -/// See [module][crate::io] documentation for more details. -/// -/// [`AsyncSeek`]: AsyncSeek -/// [`prelude`]: crate::prelude -pub trait AsyncSeekExt: AsyncSeek { - /// Creates a future which will seek an IO object, and then yield the - /// new position in the object and the object itself. +cfg_io_util! { + /// An extension trait which adds utility methods to [`AsyncSeek`] types. /// - /// In the case of an error the buffer and the object will be discarded, with - /// the error yielded. + /// As a convenience, this trait may be imported using the [`prelude`]: /// /// # Examples /// - /// ```no_run - /// use tokio::fs::File; + /// ``` + /// use std::io::{Cursor, SeekFrom}; /// use tokio::prelude::*; /// - /// use std::io::SeekFrom; + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut cursor = Cursor::new(b"abcdefg"); + /// + /// // the `seek` method is defined by this trait + /// cursor.seek(SeekFrom::Start(3)).await?; /// - /// # async fn dox() -> std::io::Result<()> { - /// let mut file = File::open("foo.txt").await?; - /// file.seek(SeekFrom::Start(6)).await?; + /// let mut buf = [0; 1]; + /// let n = cursor.read(&mut buf).await?; + /// assert_eq!(n, 1); + /// assert_eq!(buf, [b'd']); /// - /// let mut contents = vec![0u8; 10]; - /// file.read_exact(&mut contents).await?; - /// # Ok(()) - /// # } + /// Ok(()) + /// } /// ``` - fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> - where - Self: Unpin, - { - seek(self, pos) + /// + /// See [module][crate::io] documentation for more details. + /// + /// [`AsyncSeek`]: AsyncSeek + /// [`prelude`]: crate::prelude + pub trait AsyncSeekExt: AsyncSeek { + /// Creates a future which will seek an IO object, and then yield the + /// new position in the object and the object itself. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn seek(&mut self, pos: SeekFrom) -> io::Result<u64>; + /// ``` + /// + /// In the case of an error the buffer and the object will be discarded, with + /// the error yielded. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::fs::File; + /// use tokio::prelude::*; + /// + /// use std::io::SeekFrom; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut file = File::open("foo.txt").await?; + /// file.seek(SeekFrom::Start(6)).await?; + /// + /// let mut contents = vec![0u8; 10]; + /// file.read_exact(&mut contents).await?; + /// # Ok(()) + /// # } + /// ``` + fn seek(&mut self, pos: SeekFrom) -> Seek<'_, Self> + where + Self: Unpin, + { + seek(self, pos) + } } } diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index fa41097..e6ef5b2 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -119,6 +119,7 @@ cfg_io_util! { write(self, src) } + /// Writes a buffer into this writer, advancing the buffer's internal /// cursor. /// @@ -134,7 +135,7 @@ cfg_io_util! { /// internal cursor is advanced by the number of bytes written. A /// subsequent call to `write_buf` using the **same** `buf` value will /// resume from the point that the first call to `write_buf` completed. - /// A call to `write` represents *at most one* attempt to write to any + /// A call to `write_buf` represents *at most one* attempt to write to any /// wrapped object. /// /// # Return @@ -976,6 +977,8 @@ cfg_io_util! { /// no longer attempt to write to the stream. For example, the /// `TcpStream` implementation will issue a `shutdown(Write)` sys call. /// + /// [`flush`]: fn@crate::io::AsyncWriteExt::flush + /// /// # Examples /// /// ```no_run diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index a1c5990..271f61b 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -1,10 +1,8 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; -use bytes::Buf; use pin_project_lite::pin_project; -use std::io::{self, Read}; -use std::mem::MaybeUninit; +use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, fmt}; @@ -44,21 +42,12 @@ impl<R: AsyncRead> BufReader<R> { /// Creates a new `BufReader` with the specified buffer capacity. pub fn with_capacity(capacity: usize, inner: R) -> Self { - unsafe { - let mut buffer = Vec::with_capacity(capacity); - buffer.set_len(capacity); - - { - // Convert to MaybeUninit - let b = &mut *(&mut buffer[..] as *mut [u8] as *mut [MaybeUninit<u8>]); - inner.prepare_uninitialized_buffer(b); - } - Self { - inner, - buf: buffer.into_boxed_slice(), - pos: 0, - cap: 0, - } + let buffer = vec![0; capacity]; + Self { + inner, + buf: buffer.into_boxed_slice(), + pos: 0, + cap: 0, } } @@ -110,25 +99,21 @@ impl<R: AsyncRead> AsyncRead for BufReader<R> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { // If we don't have any buffered data and we're doing a massive read // (larger than our internal buffer), bypass our internal buffer // entirely. - if self.pos == self.cap && buf.len() >= self.buf.len() { + if self.pos == self.cap && buf.remaining() >= self.buf.len() { let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); self.discard_buffer(); return Poll::Ready(res); } - let mut rem = ready!(self.as_mut().poll_fill_buf(cx))?; - let nread = rem.read(buf)?; - self.consume(nread); - Poll::Ready(Ok(nread)) - } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + let rem = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = std::cmp::min(rem.len(), buf.remaining()); + buf.put_slice(&rem[..amt]); + self.consume(amt); + Poll::Ready(Ok(())) } } @@ -142,7 +127,9 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> { // to tell the compiler that the pos..cap slice is always valid. if *me.pos >= *me.cap { debug_assert!(*me.pos == *me.cap); - *me.cap = ready!(me.inner.poll_read(cx, me.buf))?; + let mut buf = ReadBuf::new(me.buf); + ready!(me.inner.poll_read(cx, &mut buf))?; + *me.cap = buf.filled().len(); *me.pos = 0; } Poll::Ready(Ok(&me.buf[*me.pos..*me.cap])) @@ -163,14 +150,6 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { self.get_pin_mut().poll_write(cx, buf) } - fn poll_write_buf<B: Buf>( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll<io::Result<usize>> { - self.get_pin_mut().poll_write_buf(cx, buf) - } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.get_pin_mut().poll_flush(cx) } diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs index a56a451..cc857e2 100644 --- a/src/io/util/buf_stream.rs +++ b/src/io/util/buf_stream.rs @@ -1,9 +1,8 @@ use crate::io::util::{BufReader, BufWriter}; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -137,15 +136,10 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.project().inner.poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } } impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> { diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs index efd053e..5e3d4b7 100644 --- a/src/io/util/buf_writer.rs +++ b/src/io/util/buf_writer.rs @@ -1,10 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; use std::io::{self, Write}; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; @@ -147,15 +146,10 @@ impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { self.get_pin_mut().poll_read(cx, buf) } - - // we can't skip unconditionally because of the large buffer case in read. - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.get_ref().prepare_uninitialized_buffer(buf) - } } impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { diff --git a/src/io/util/chain.rs b/src/io/util/chain.rs index 8ba9194..84f37fc 100644 --- a/src/io/util/chain.rs +++ b/src/io/util/chain.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; @@ -84,26 +84,20 @@ where T: AsyncRead, U: AsyncRead, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - if self.first.prepare_uninitialized_buffer(buf) { - return true; - } - if self.second.prepare_uninitialized_buffer(buf) { - return true; - } - false - } fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { let me = self.project(); if !*me.done_first { - match ready!(me.first.poll_read(cx, buf)?) { - 0 if !buf.is_empty() => *me.done_first = true, - n => return Poll::Ready(Ok(n)), + let rem = buf.remaining(); + ready!(me.first.poll_read(cx, buf))?; + if buf.remaining() == rem { + *me.done_first = true; + } else { + return Poll::Ready(Ok(())); } } me.second.poll_read(cx, buf) diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs index 7bfe296..c5981cf 100644 --- a/src/io/util/copy.rs +++ b/src/io/util/copy.rs @@ -1,30 +1,25 @@ -use crate::io::{AsyncRead, AsyncWrite}; +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { - /// A future that asynchronously copies the entire contents of a reader into a - /// writer. - /// - /// This struct is generally created by calling [`copy`][copy]. Please - /// see the documentation of `copy()` for more details. - /// - /// [copy]: copy() - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - pub struct Copy<'a, R: ?Sized, W: ?Sized> { - reader: &'a mut R, - read_done: bool, - writer: &'a mut W, - pos: usize, - cap: usize, - amt: u64, - buf: Box<[u8]>, - } +/// A future that asynchronously copies the entire contents of a reader into a +/// writer. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +struct Copy<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + read_done: bool, + writer: &'a mut W, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} +cfg_io_util! { /// Asynchronously copies the entire contents of a reader into a writer. /// /// This function returns a future that will continuously read data from @@ -58,7 +53,7 @@ cfg_io_util! { /// # Ok(()) /// # } /// ``` - pub fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> Copy<'a, R, W> + pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> where R: AsyncRead + Unpin + ?Sized, W: AsyncWrite + Unpin + ?Sized, @@ -71,7 +66,7 @@ cfg_io_util! { pos: 0, cap: 0, buf: vec![0; 2048].into_boxed_slice(), - } + }.await } } @@ -88,7 +83,9 @@ where // continue. if self.pos == self.cap && !self.read_done { let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + let mut buf = ReadBuf::new(&mut me.buf); + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?; + let n = buf.filled().len(); if n == 0 { self.read_done = true; } else { @@ -122,14 +119,3 @@ where } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Copy<'_, PhantomPinned, PhantomPinned>>(); - } -} diff --git a/src/io/util/copy_buf.rs b/src/io/util/copy_buf.rs new file mode 100644 index 0000000..6831580 --- /dev/null +++ b/src/io/util/copy_buf.rs @@ -0,0 +1,102 @@ +use crate::io::{AsyncBufRead, AsyncWrite}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +cfg_io_util! { + /// A future that asynchronously copies the entire contents of a reader into a + /// writer. + /// + /// This struct is generally created by calling [`copy_buf`][copy_buf]. Please + /// see the documentation of `copy_buf()` for more details. + /// + /// [copy_buf]: copy_buf() + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + struct CopyBuf<'a, R: ?Sized, W: ?Sized> { + reader: &'a mut R, + writer: &'a mut W, + amt: u64, + } + + /// Asynchronously copies the entire contents of a reader into a writer. + /// + /// This function returns a future that will continuously read data from + /// `reader` and then write it into `writer` in a streaming fashion until + /// `reader` returns EOF. + /// + /// On success, the total number of bytes that were copied from `reader` to + /// `writer` is returned. + /// + /// + /// # Errors + /// + /// The returned future will finish with an error will return an error + /// immediately if any call to `poll_fill_buf` or `poll_write` returns an + /// error. + /// + /// # Examples + /// + /// ``` + /// use tokio::io; + /// + /// # async fn dox() -> std::io::Result<()> { + /// let mut reader: &[u8] = b"hello"; + /// let mut writer: Vec<u8> = vec![]; + /// + /// io::copy_buf(&mut reader, &mut writer).await?; + /// + /// assert_eq!(b"hello", &writer[..]); + /// # Ok(()) + /// # } + /// ``` + pub async fn copy_buf<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> + where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, + { + CopyBuf { + reader, + writer, + amt: 0, + }.await + } +} + +impl<R, W> Future for CopyBuf<'_, R, W> +where + R: AsyncBufRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + type Output = io::Result<u64>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let me = &mut *self; + let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?; + if buffer.is_empty() { + ready!(Pin::new(&mut self.writer).poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + + let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?; + if i == 0 { + return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into())); + } + self.amt += i as u64; + Pin::new(&mut *self.reader).consume(i); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn assert_unpin() { + use std::marker::PhantomPinned; + crate::is_unpin::<CopyBuf<'_, PhantomPinned, PhantomPinned>>(); + } +} diff --git a/src/io/util/empty.rs b/src/io/util/empty.rs index 576058d..f964d18 100644 --- a/src/io/util/empty.rs +++ b/src/io/util/empty.rs @@ -1,4 +1,4 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use std::fmt; use std::io; @@ -47,16 +47,13 @@ cfg_io_util! { } impl AsyncRead for Empty { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - _: &mut [u8], - ) -> Poll<io::Result<usize>> { - Poll::Ready(Ok(0)) + _: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) } } diff --git a/src/io/util/flush.rs b/src/io/util/flush.rs index 534a516..88d60b8 100644 --- a/src/io/util/flush.rs +++ b/src/io/util/flush.rs @@ -1,18 +1,24 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future used to fully flush an I/O object. /// /// Created by the [`AsyncWriteExt::flush`][flush] function. /// [flush]: crate::io::AsyncWriteExt::flush #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Flush<'a, A: ?Sized> { a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +27,10 @@ pub(super) fn flush<A>(a: &mut A) -> Flush<'_, A> where A: AsyncWrite + Unpin + ?Sized, { - Flush { a } + Flush { + a, + _pin: PhantomPinned, + } } impl<A> Future for Flush<'_, A> @@ -30,19 +39,8 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); Pin::new(&mut *me.a).poll_flush(cx) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Flush<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs index ee27400..b41f04a 100644 --- a/src/io/util/lines.rs +++ b/src/io/util/lines.rs @@ -83,8 +83,7 @@ impl<R> Lines<R> where R: AsyncBufRead, { - #[doc(hidden)] - pub fn poll_next_line( + fn poll_next_line( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<String>>> { diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs new file mode 100644 index 0000000..e91a932 --- /dev/null +++ b/src/io/util/mem.rs @@ -0,0 +1,223 @@ +//! In-process memory IO types. + +use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::loom::sync::Mutex; + +use bytes::{Buf, BytesMut}; +use std::{ + pin::Pin, + sync::Arc, + task::{self, Poll, Waker}, +}; + +/// A bidirectional pipe to read and write bytes in memory. +/// +/// A pair of `DuplexStream`s are created together, and they act as a "channel" +/// that can be used as in-memory IO types. Writing to one of the pairs will +/// allow that data to be read from the other, and vice versa. +/// +/// # Example +/// +/// ``` +/// # async fn ex() -> std::io::Result<()> { +/// # use tokio::io::{AsyncReadExt, AsyncWriteExt}; +/// let (mut client, mut server) = tokio::io::duplex(64); +/// +/// client.write_all(b"ping").await?; +/// +/// let mut buf = [0u8; 4]; +/// server.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"ping"); +/// +/// server.write_all(b"pong").await?; +/// +/// client.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"pong"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct DuplexStream { + read: Arc<Mutex<Pipe>>, + write: Arc<Mutex<Pipe>>, +} + +/// A unidirectional IO over a piece of memory. +/// +/// Data can be written to the pipe, and reading will return that data. +#[derive(Debug)] +struct Pipe { + /// The buffer storing the bytes written, also read from. + /// + /// Using a `BytesMut` because it has efficient `Buf` and `BufMut` + /// functionality already. Additionally, it can try to copy data in the + /// same buffer if there read index has advanced far enough. + buffer: BytesMut, + /// Determines if the write side has been closed. + is_closed: bool, + /// The maximum amount of bytes that can be written before returning + /// `Poll::Pending`. + max_buf_size: usize, + /// If the `read` side has been polled and is pending, this is the waker + /// for that parked task. + read_waker: Option<Waker>, + /// If the `write` side has filled the `max_buf_size` and returned + /// `Poll::Pending`, this is the waker for that parked task. + write_waker: Option<Waker>, +} + +// ===== impl DuplexStream ===== + +/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets. +/// +/// The `max_buf_size` argument is the maximum amount of bytes that can be +/// written to a side before the write returns `Poll::Pending`. +pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { + let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); + + ( + DuplexStream { + read: one.clone(), + write: two.clone(), + }, + DuplexStream { + read: two, + write: one, + }, + ) +} + +impl AsyncRead for DuplexStream { + // Previous rustc required this `self` to be `mut`, even though newer + // versions recognize it isn't needed to call `lock()`. So for + // compatibility, we include the `mut` and `allow` the lint. + // + // See https://github.com/rust-lang/rust/issues/73592 + #[allow(unused_mut)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.read.lock()).poll_read(cx, buf) + } +} + +impl AsyncWrite for DuplexStream { + #[allow(unused_mut)] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Pin::new(&mut *self.write.lock()).poll_write(cx, buf) + } + + #[allow(unused_mut)] + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_flush(cx) + } + + #[allow(unused_mut)] + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + Pin::new(&mut *self.write.lock()).poll_shutdown(cx) + } +} + +impl Drop for DuplexStream { + fn drop(&mut self) { + // notify the other side of the closure + self.write.lock().close(); + } +} + +// ===== impl Pipe ===== + +impl Pipe { + fn new(max_buf_size: usize) -> Self { + Pipe { + buffer: BytesMut::new(), + is_closed: false, + max_buf_size, + read_waker: None, + write_waker: None, + } + } + + fn close(&mut self) { + self.is_closed = true; + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if self.buffer.has_remaining() { + let max = self.buffer.remaining().min(buf.remaining()); + buf.put_slice(&self.buffer[..max]); + self.buffer.advance(max); + if max > 0 { + // The passed `buf` might have been empty, don't wake up if + // no bytes have been moved. + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } + Poll::Ready(Ok(())) + } else if self.is_closed { + Poll::Ready(Ok(())) + } else { + self.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let len = buf.len().min(avail); + self.buffer.extend_from_slice(&buf[..len]); + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(len)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _: &mut task::Context<'_>, + ) -> Poll<std::io::Result<()>> { + self.close(); + Poll::Ready(Ok(())) + } +} diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs index c4754ab..e75ea03 100644 --- a/src/io/util/mod.rs +++ b/src/io/util/mod.rs @@ -25,7 +25,10 @@ cfg_io_util! { mod chain; mod copy; - pub use copy::{copy, Copy}; + pub use copy::copy; + + mod copy_buf; + pub use copy_buf::copy_buf; mod empty; pub use empty::{empty, Empty}; @@ -35,6 +38,9 @@ cfg_io_util! { mod lines; pub use lines::Lines; + mod mem; + pub use mem::{duplex, DuplexStream}; + mod read; mod read_buf; mod read_exact; @@ -60,11 +66,6 @@ cfg_io_util! { mod split; pub use split::Split; - cfg_stream! { - mod stream_reader; - pub use stream_reader::{stream_reader, StreamReader}; - } - mod take; pub use take::Take; diff --git a/src/io/util/read.rs b/src/io/util/read.rs index a8ca370..edc9d5a 100644 --- a/src/io/util/read.rs +++ b/src/io/util/read.rs @@ -1,7 +1,9 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -15,10 +17,14 @@ pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R> where R: AsyncRead + Unpin + ?Sized, { - Read { reader, buf } + Read { + reader, + buf, + _pin: PhantomPinned, + } } -cfg_io_util! { +pin_project! { /// A future which can be used to easily read available number of bytes to fill /// a buffer. /// @@ -28,6 +34,9 @@ cfg_io_util! { pub struct Read<'a, R: ?Sized> { reader: &'a mut R, buf: &'a mut [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -37,19 +46,10 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.reader).poll_read(cx, me.buf) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Read<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + let mut buf = ReadBuf::new(*me.buf); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?; + Poll::Ready(Ok(buf.filled().len())) } } diff --git a/src/io/util/read_buf.rs b/src/io/util/read_buf.rs index 6ee3d24..696deef 100644 --- a/src/io/util/read_buf.rs +++ b/src/io/util/read_buf.rs @@ -1,8 +1,10 @@ use crate::io::AsyncRead; use bytes::BufMut; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; @@ -11,16 +13,22 @@ where R: AsyncRead + Unpin, B: BufMut, { - ReadBuf { reader, buf } + ReadBuf { + reader, + buf, + _pin: PhantomPinned, + } } -cfg_io_util! { +pin_project! { /// Future returned by [`read_buf`](crate::io::AsyncReadExt::read_buf). #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadBuf<'a, R, B> { reader: &'a mut R, buf: &'a mut B, + #[pin] + _pin: PhantomPinned, } } @@ -31,8 +39,34 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.reader).poll_read_buf(cx, me.buf) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + use crate::io::ReadBuf; + use std::mem::MaybeUninit; + + let me = self.project(); + + if !me.buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = me.buf.bytes_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + me.buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) } } diff --git a/src/io/util/read_exact.rs b/src/io/util/read_exact.rs index 86b8412..1e8150e 100644 --- a/src/io/util/read_exact.rs +++ b/src/io/util/read_exact.rs @@ -1,7 +1,9 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -17,12 +19,12 @@ where { ReadExact { reader, - buf, - pos: 0, + buf: ReadBuf::new(buf), + _pin: PhantomPinned, } } -cfg_io_util! { +pin_project! { /// Creates a future which will read exactly enough bytes to fill `buf`, /// returning an error if EOF is hit sooner. /// @@ -31,8 +33,10 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadExact<'a, A: ?Sized> { reader: &'a mut A, - buf: &'a mut [u8], - pos: usize, + buf: ReadBuf<'a>, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -46,32 +50,20 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let mut me = self.project(); + loop { // if our buffer is empty, then we need to read some data to continue. - if self.pos < self.buf.len() { - let me = &mut *self; - let n = ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf[me.pos..]))?; - me.pos += n; - if n == 0 { + let rem = me.buf.remaining(); + if rem != 0 { + ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + if me.buf.remaining() == rem { return Err(eof()).into(); } - } - - if self.pos >= self.buf.len() { - return Poll::Ready(Ok(self.pos)); + } else { + return Poll::Ready(Ok(me.buf.capacity())); } } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadExact<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs index 9d37dc7..5b9fb7b 100644 --- a/src/io/util/read_int.rs +++ b/src/io/util/read_int.rs @@ -1,10 +1,11 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use bytes::Buf; use pin_project_lite::pin_project; use std::future::Future; use std::io; use std::io::ErrorKind::UnexpectedEof; +use std::marker::PhantomPinned; use std::mem::size_of; use std::pin::Pin; use std::task::{Context, Poll}; @@ -16,11 +17,15 @@ macro_rules! reader { ($name:ident, $ty:ty, $reader:ident, $bytes:expr) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<R> { #[pin] src: R, buf: [u8; $bytes], read: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -30,6 +35,7 @@ macro_rules! reader { src, buf: [0; $bytes], read: 0, + _pin: PhantomPinned, } } } @@ -48,17 +54,19 @@ macro_rules! reader { } while *me.read < $bytes as u8 { - *me.read += match me - .src - .as_mut() - .poll_read(cx, &mut me.buf[*me.read as usize..]) - { + let mut buf = ReadBuf::new(&mut me.buf[*me.read as usize..]); + + *me.read += match me.src.as_mut().poll_read(cx, &mut buf) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => { - return Poll::Ready(Err(UnexpectedEof.into())); + Poll::Ready(Ok(())) => { + let n = buf.filled().len(); + if n == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + n as u8 } - Poll::Ready(Ok(n)) => n as u8, }; } @@ -75,15 +83,22 @@ macro_rules! reader8 { pin_project! { /// Future returned from `read_u8` #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<R> { #[pin] reader: R, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<R> $name<R> { pub(crate) fn new(reader: R) -> $name<R> { - $name { reader } + $name { + reader, + _pin: PhantomPinned, + } } } @@ -97,12 +112,17 @@ macro_rules! reader8 { let me = self.project(); let mut buf = [0; 1]; - match me.reader.poll_read(cx, &mut buf[..]) { + let mut buf = ReadBuf::new(&mut buf); + match me.reader.poll_read(cx, &mut buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), - Poll::Ready(Ok(0)) => Poll::Ready(Err(UnexpectedEof.into())), - Poll::Ready(Ok(1)) => Poll::Ready(Ok(buf[0] as $ty)), - Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Ok(())) => { + if buf.filled().len() == 0 { + return Poll::Ready(Err(UnexpectedEof.into())); + } + + Poll::Ready(Ok(buf.filled()[0] as $ty)) + } } } } diff --git a/src/io/util/read_line.rs b/src/io/util/read_line.rs index d625a76..d38ffaf 100644 --- a/src/io/util/read_line.rs +++ b/src/io/util/read_line.rs @@ -1,26 +1,32 @@ use crate::io::util::read_until::read_until_internal; use crate::io::AsyncBufRead; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; +use std::string::FromUtf8Error; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// Future for the [`read_line`](crate::io::AsyncBufReadExt::read_line) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadLine<'a, R: ?Sized> { reader: &'a mut R, - /// This is the buffer we were provided. It will be replaced with an empty string - /// while reading to postpone utf-8 handling until after reading. + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. output: &'a mut String, - /// The actual allocation of the string is moved into a vector instead. + // The actual allocation of the string is moved into this vector instead. buf: Vec<u8>, - /// The number of bytes appended to buf. This can be less than buf.len() if - /// the buffer was not empty when the operation was started. + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -33,6 +39,7 @@ where buf: mem::replace(string, String::new()).into_bytes(), output: string, read: 0, + _pin: PhantomPinned, } } @@ -42,31 +49,33 @@ fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_re *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); } -pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( - reader: Pin<&mut R>, - cx: &mut Context<'_>, +/// This handles the various failure cases and puts the string back into `output`. +/// +/// The `truncate_on_io_error` bool is necessary because `read_to_string` and `read_line` +/// disagree on what should happen when an IO error occurs. +pub(super) fn finish_string_read( + io_res: io::Result<usize>, + utf8_res: Result<String, FromUtf8Error>, + read: usize, output: &mut String, - buf: &mut Vec<u8>, - read: &mut usize, + truncate_on_io_error: bool, ) -> Poll<io::Result<usize>> { - let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); - let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); - - // At this point both buf and output are empty. The allocation is in utf8_res. - - debug_assert!(buf.is_empty()); match (io_res, utf8_res) { (Ok(num_bytes), Ok(string)) => { - debug_assert_eq!(*read, 0); + debug_assert_eq!(read, 0); *output = string; Poll::Ready(Ok(num_bytes)) } (Err(io_err), Ok(string)) => { *output = string; + if truncate_on_io_error { + let original_len = output.len() - read; + output.truncate(original_len); + } Poll::Ready(Err(io_err)) } (Ok(num_bytes), Err(utf8_err)) => { - debug_assert_eq!(*read, 0); + debug_assert_eq!(read, 0); put_back_original_data(output, utf8_err.into_bytes(), num_bytes); Poll::Ready(Err(io::Error::new( @@ -75,35 +84,36 @@ pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( ))) } (Err(io_err), Err(utf8_err)) => { - put_back_original_data(output, utf8_err.into_bytes(), *read); + put_back_original_data(output, utf8_err.into_bytes(), read); Poll::Ready(Err(io_err)) } } } -impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { - type Output = io::Result<usize>; +pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( + reader: Pin<&mut R>, + cx: &mut Context<'_>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, +) -> Poll<io::Result<usize>> { + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - output, - buf, - read, - } = &mut *self; + // At this point both buf and output are empty. The allocation is in utf8_res. - read_line_internal(Pin::new(reader), cx, output, buf, read) - } + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, false) } -#[cfg(test)] -mod tests { - use super::*; +impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> { + type Output = io::Result<usize>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadLine<'_, PhantomPinned>>(); + read_line_internal(Pin::new(*me.reader), cx, me.output, me.buf, me.read) } } diff --git a/src/io/util/read_to_end.rs b/src/io/util/read_to_end.rs index a2cd99b..a974625 100644 --- a/src/io/util/read_to_end.rs +++ b/src/io/util/read_to_end.rs @@ -1,92 +1,105 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; +use pin_project_lite::pin_project; use std::future::Future; use std::io; -use std::mem::MaybeUninit; +use std::marker::PhantomPinned; +use std::mem::{self, MaybeUninit}; use std::pin::Pin; use std::task::{Context, Poll}; -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] -pub struct ReadToEnd<'a, R: ?Sized> { - reader: &'a mut R, - buf: &'a mut Vec<u8>, - start_len: usize, +pin_project! { + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct ReadToEnd<'a, R: ?Sized> { + reader: &'a mut R, + buf: &'a mut Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } } -pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, R> +pub(crate) fn read_to_end<'a, R>(reader: &'a mut R, buffer: &'a mut Vec<u8>) -> ReadToEnd<'a, R> where R: AsyncRead + Unpin + ?Sized, { - let start_len = buf.len(); ReadToEnd { reader, - buf, - start_len, + buf: buffer, + read: 0, + _pin: PhantomPinned, } } -struct Guard<'a> { - buf: &'a mut Vec<u8>, - len: usize, -} - -impl Drop for Guard<'_> { - fn drop(&mut self) { - unsafe { - self.buf.set_len(self.len); +pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>( + buf: &mut Vec<u8>, + mut reader: Pin<&mut R>, + num_read: &mut usize, + cx: &mut Context<'_>, +) -> Poll<io::Result<usize>> { + loop { + // safety: The caller promised to prepare the buffer. + let ret = ready!(poll_read_to_end(buf, reader.as_mut(), cx)); + match ret { + Err(err) => return Poll::Ready(Err(err)), + Ok(0) => return Poll::Ready(Ok(mem::replace(num_read, 0))), + Ok(num) => { + *num_read += num; + } } } } -// This uses an adaptive system to extend the vector when it fills. We want to -// avoid paying to allocate and zero a huge chunk of memory if the reader only -// has 4 bytes while still making large reads if the reader does have a ton -// of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every -// time is 4,500 times (!) slower than this if the reader has a very small -// amount of data to return. -// -// Because we're extending the buffer with uninitialized data for trusted -// readers, we need to make sure to truncate that if any of this panics. -pub(super) fn read_to_end_internal<R: AsyncRead + ?Sized>( - mut rd: Pin<&mut R>, - cx: &mut Context<'_>, +/// Tries to read from the provided AsyncRead. +/// +/// The length of the buffer is increased by the number of bytes read. +fn poll_read_to_end<R: AsyncRead + ?Sized>( buf: &mut Vec<u8>, - start_len: usize, + read: Pin<&mut R>, + cx: &mut Context<'_>, ) -> Poll<io::Result<usize>> { - let mut g = Guard { - len: buf.len(), - buf, - }; - let ret; - loop { - if g.len == g.buf.len() { - unsafe { - g.buf.reserve(32); - let capacity = g.buf.capacity(); - g.buf.set_len(capacity); + // This uses an adaptive system to extend the vector when it fills. We want to + // avoid paying to allocate and zero a huge chunk of memory if the reader only + // has 4 bytes while still making large reads if the reader does have a ton + // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every + // time is 4,500 times (!) slower than this if the reader has a very small + // amount of data to return. + reserve(buf, 32); - let b = &mut *(&mut g.buf[g.len..] as *mut [u8] as *mut [MaybeUninit<u8>]); + let mut unused_capacity = ReadBuf::uninit(get_unused_capacity(buf)); - rd.prepare_uninitialized_buffer(b); - } - } + ready!(read.poll_read(cx, &mut unused_capacity))?; - match ready!(rd.as_mut().poll_read(cx, &mut g.buf[g.len..])) { - Ok(0) => { - ret = Poll::Ready(Ok(g.len - start_len)); - break; - } - Ok(n) => g.len += n, - Err(e) => { - ret = Poll::Ready(Err(e)); - break; - } - } + let n = unused_capacity.filled().len(); + let new_len = buf.len() + n; + + // This should no longer even be possible in safe Rust. An implementor + // would need to have unsafely *replaced* the buffer inside `ReadBuf`, + // which... yolo? + assert!(new_len <= buf.capacity()); + unsafe { + buf.set_len(new_len); } + Poll::Ready(Ok(n)) +} - ret +/// Allocates more memory and ensures that the unused capacity is prepared for use +/// with the `AsyncRead`. +fn reserve(buf: &mut Vec<u8>, bytes: usize) { + if buf.capacity() - buf.len() >= bytes { + return; + } + buf.reserve(bytes); +} + +/// Returns the unused capacity of the provided vector. +fn get_unused_capacity(buf: &mut Vec<u8>) -> &mut [MaybeUninit<u8>] { + let uninit = bytes::BufMut::bytes_mut(buf); + unsafe { &mut *(uninit as *mut _ as *mut [MaybeUninit<u8>]) } } impl<A> Future for ReadToEnd<'_, A> @@ -95,19 +108,9 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let this = &mut *self; - read_to_end_internal(Pin::new(&mut this.reader), cx, this.buf, this.start_len) - } -} - -#[cfg(test)] -mod tests { - use super::*; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadToEnd<'_, PhantomPinned>>(); + read_to_end_internal(me.buf, Pin::new(*me.reader), me.read, cx) } } diff --git a/src/io/util/read_to_string.rs b/src/io/util/read_to_string.rs index cab0505..e463203 100644 --- a/src/io/util/read_to_string.rs +++ b/src/io/util/read_to_string.rs @@ -1,58 +1,71 @@ +use crate::io::util::read_line::finish_string_read; use crate::io::util::read_to_end::read_to_end_internal; use crate::io::AsyncRead; +use pin_project_lite::pin_project; use std::future::Future; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; -cfg_io_util! { +pin_project! { /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadToString<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut String, - bytes: Vec<u8>, - start_len: usize, + // This is the buffer we were provided. It will be replaced with an empty string + // while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + // The actual allocation of the string is moved into this vector instead. + buf: Vec<u8>, + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. + read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } -pub(crate) fn read_to_string<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadToString<'a, R> +pub(crate) fn read_to_string<'a, R>( + reader: &'a mut R, + string: &'a mut String, +) -> ReadToString<'a, R> where R: AsyncRead + ?Sized + Unpin, { - let start_len = buf.len(); + let buf = mem::replace(string, String::new()).into_bytes(); ReadToString { reader, - bytes: mem::replace(buf, String::new()).into_bytes(), buf, - start_len, + output: string, + read: 0, + _pin: PhantomPinned, } } -fn read_to_string_internal<R: AsyncRead + ?Sized>( +/// # Safety +/// +/// Before first calling this method, the unused capacity must have been +/// prepared for use with the provided AsyncRead. This can be done using the +/// `prepare_buffer` function in `read_to_end.rs`. +unsafe fn read_to_string_internal<R: AsyncRead + ?Sized>( reader: Pin<&mut R>, + output: &mut String, + buf: &mut Vec<u8>, + read: &mut usize, cx: &mut Context<'_>, - buf: &mut String, - bytes: &mut Vec<u8>, - start_len: usize, ) -> Poll<io::Result<usize>> { - let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; - match String::from_utf8(mem::replace(bytes, Vec::new())) { - Ok(string) => { - debug_assert!(buf.is_empty()); - *buf = string; - Poll::Ready(Ok(ret)) - } - Err(e) => { - *bytes = e.into_bytes(); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::InvalidData, - "stream did not contain valid UTF-8", - ))) - } - } + let io_res = ready!(read_to_end_internal(buf, reader, read, cx)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + debug_assert!(output.is_empty()); + finish_string_read(io_res, utf8_res, *read, output, true) } impl<A> Future for ReadToString<'_, A> @@ -61,31 +74,10 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - buf, - bytes, - start_len, - } = &mut *self; - let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len); - if let Poll::Ready(Err(_)) = ret { - // Put back the original string. - bytes.truncate(*start_len); - **buf = String::from_utf8(mem::replace(bytes, Vec::new())) - .expect("original string no longer utf-8"); - } - ret - } -} - -#[cfg(test)] -mod tests { - use super::*; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadToString<'_, PhantomPinned>>(); + // safety: The constructor of ReadToString called `prepare_buffer`. + unsafe { read_to_string_internal(Pin::new(*me.reader), me.output, me.buf, me.read, cx) } } } diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs index 78dac8c..3599cff 100644 --- a/src/io/util/read_until.rs +++ b/src/io/util/read_until.rs @@ -1,12 +1,14 @@ use crate::io::AsyncBufRead; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. /// The delimeter is included in the resulting vector. #[derive(Debug)] @@ -15,9 +17,12 @@ cfg_io_util! { reader: &'a mut R, delimeter: u8, buf: &'a mut Vec<u8>, - /// The number of bytes appended to buf. This can be less than buf.len() if - /// the buffer was not empty when the operation was started. + // The number of bytes appended to buf. This can be less than buf.len() if + // the buffer was not empty when the operation was started. read: usize, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -34,6 +39,7 @@ where delimeter, buf, read: 0, + _pin: PhantomPinned, } } @@ -66,24 +72,8 @@ pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>( impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let Self { - reader, - delimeter, - buf, - read, - } = &mut *self; - read_until_internal(Pin::new(reader), cx, *delimeter, buf, read) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<ReadUntil<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read) } } diff --git a/src/io/util/repeat.rs b/src/io/util/repeat.rs index eeef7cc..1142765 100644 --- a/src/io/util/repeat.rs +++ b/src/io/util/repeat.rs @@ -1,4 +1,4 @@ -use crate::io::AsyncRead; +use crate::io::{AsyncRead, ReadBuf}; use std::io; use std::pin::Pin; @@ -47,19 +47,17 @@ cfg_io_util! { } impl AsyncRead for Repeat { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool { - false - } #[inline] fn poll_read( self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - for byte in &mut *buf { - *byte = self.byte; + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + // TODO: could be faster, but should we unsafe it? + while buf.remaining() != 0 { + buf.put_slice(&[self.byte]); } - Poll::Ready(Ok(buf.len())) + Poll::Ready(Ok(())) } } diff --git a/src/io/util/shutdown.rs b/src/io/util/shutdown.rs index 33ac0ac..6d30b00 100644 --- a/src/io/util/shutdown.rs +++ b/src/io/util/shutdown.rs @@ -1,18 +1,24 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future used to shutdown an I/O object. /// /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. /// [shutdown]: crate::io::AsyncWriteExt::shutdown + #[must_use = "futures do nothing unless you `.await` or poll them"] #[derive(Debug)] pub struct Shutdown<'a, A: ?Sized> { a: &'a mut A, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +27,10 @@ pub(super) fn shutdown<A>(a: &mut A) -> Shutdown<'_, A> where A: AsyncWrite + Unpin + ?Sized, { - Shutdown { a } + Shutdown { + a, + _pin: PhantomPinned, + } } impl<A> Future for Shutdown<'_, A> @@ -30,19 +39,8 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let me = &mut *self; - Pin::new(&mut *me.a).poll_shutdown(cx) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<Shutdown<'_, PhantomPinned>>(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + Pin::new(me.a).poll_shutdown(cx) } } diff --git a/src/io/util/split.rs b/src/io/util/split.rs index f552ed5..492e26a 100644 --- a/src/io/util/split.rs +++ b/src/io/util/split.rs @@ -65,8 +65,7 @@ impl<R> Split<R> where R: AsyncBufRead, { - #[doc(hidden)] - pub fn poll_next_segment( + fn poll_next_segment( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<Option<Vec<u8>>>> { diff --git a/src/io/util/stream_reader.rs b/src/io/util/stream_reader.rs deleted file mode 100644 index b98f8bd..0000000 --- a/src/io/util/stream_reader.rs +++ /dev/null @@ -1,184 +0,0 @@ -use crate::io::{AsyncBufRead, AsyncRead}; -use crate::stream::Stream; -use bytes::{Buf, BufMut}; -use pin_project_lite::pin_project; -use std::io; -use std::mem::MaybeUninit; -use std::pin::Pin; -use std::task::{Context, Poll}; - -pin_project! { - /// Convert a stream of byte chunks into an [`AsyncRead`]. - /// - /// This type is usually created using the [`stream_reader`] function. - /// - /// [`AsyncRead`]: crate::io::AsyncRead - /// [`stream_reader`]: crate::io::stream_reader - #[derive(Debug)] - #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] - #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] - pub struct StreamReader<S, B> { - #[pin] - inner: S, - chunk: Option<B>, - } -} - -/// Convert a stream of byte chunks into an [`AsyncRead`](crate::io::AsyncRead). -/// -/// # Example -/// -/// ``` -/// use bytes::Bytes; -/// use tokio::io::{stream_reader, AsyncReadExt}; -/// # #[tokio::main] -/// # async fn main() -> std::io::Result<()> { -/// -/// // Create a stream from an iterator. -/// let stream = tokio::stream::iter(vec![ -/// Ok(Bytes::from_static(&[0, 1, 2, 3])), -/// Ok(Bytes::from_static(&[4, 5, 6, 7])), -/// Ok(Bytes::from_static(&[8, 9, 10, 11])), -/// ]); -/// -/// // Convert it to an AsyncRead. -/// let mut read = stream_reader(stream); -/// -/// // Read five bytes from the stream. -/// let mut buf = [0; 5]; -/// read.read_exact(&mut buf).await?; -/// assert_eq!(buf, [0, 1, 2, 3, 4]); -/// -/// // Read the rest of the current chunk. -/// assert_eq!(read.read(&mut buf).await?, 3); -/// assert_eq!(&buf[..3], [5, 6, 7]); -/// -/// // Read the next chunk. -/// assert_eq!(read.read(&mut buf).await?, 4); -/// assert_eq!(&buf[..4], [8, 9, 10, 11]); -/// -/// // We have now reached the end. -/// assert_eq!(read.read(&mut buf).await?, 0); -/// -/// # Ok(()) -/// # } -/// ``` -#[cfg_attr(docsrs, doc(cfg(feature = "stream")))] -#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] -pub fn stream_reader<S, B>(stream: S) -> StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - StreamReader::new(stream) -} - -impl<S, B> StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - /// Convert the provided stream into an `AsyncRead`. - fn new(stream: S) -> Self { - Self { - inner: stream, - chunk: None, - } - } - /// Do we have a chunk and is it non-empty? - fn has_chunk(self: Pin<&mut Self>) -> bool { - if let Some(chunk) = self.project().chunk { - chunk.remaining() > 0 - } else { - false - } - } -} - -impl<S, B> AsyncRead for StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<io::Result<usize>> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - let inner_buf = match self.as_mut().poll_fill_buf(cx) { - Poll::Ready(Ok(buf)) => buf, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - let len = std::cmp::min(inner_buf.len(), buf.len()); - (&mut buf[..len]).copy_from_slice(&inner_buf[..len]); - - self.consume(len); - Poll::Ready(Ok(len)) - } - fn poll_read_buf<BM: BufMut>( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut BM, - ) -> Poll<io::Result<usize>> - where - Self: Sized, - { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let inner_buf = match self.as_mut().poll_fill_buf(cx) { - Poll::Ready(Ok(buf)) => buf, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }; - let len = std::cmp::min(inner_buf.len(), buf.remaining_mut()); - buf.put_slice(&inner_buf[..len]); - - self.consume(len); - Poll::Ready(Ok(len)) - } - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit<u8>]) -> bool { - false - } -} - -impl<S, B> AsyncBufRead for StreamReader<S, B> -where - S: Stream<Item = Result<B, io::Error>>, - B: Buf, -{ - fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { - loop { - if self.as_mut().has_chunk() { - // This unwrap is very sad, but it can't be avoided. - let buf = self.project().chunk.as_ref().unwrap().bytes(); - return Poll::Ready(Ok(buf)); - } else { - match self.as_mut().project().inner.poll_next(cx) { - Poll::Ready(Some(Ok(chunk))) => { - // Go around the loop in case the chunk is empty. - *self.as_mut().project().chunk = Some(chunk); - } - Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)), - Poll::Ready(None) => return Poll::Ready(Ok(&[])), - Poll::Pending => return Poll::Pending, - } - } - } - } - fn consume(self: Pin<&mut Self>, amt: usize) { - if amt > 0 { - self.project() - .chunk - .as_mut() - .expect("No chunk present") - .advance(amt); - } - } -} diff --git a/src/io/util/take.rs b/src/io/util/take.rs index 5d6bd90..b5e90c9 100644 --- a/src/io/util/take.rs +++ b/src/io/util/take.rs @@ -1,7 +1,6 @@ -use crate::io::{AsyncBufRead, AsyncRead}; +use crate::io::{AsyncBufRead, AsyncRead, ReadBuf}; use pin_project_lite::pin_project; -use std::mem::MaybeUninit; use std::pin::Pin; use std::task::{Context, Poll}; use std::{cmp, io}; @@ -76,24 +75,27 @@ impl<R: AsyncRead> Take<R> { } impl<R: AsyncRead> AsyncRead for Take<R> { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll<Result<usize, io::Error>> { + buf: &mut ReadBuf<'_>, + ) -> Poll<Result<(), io::Error>> { if self.limit_ == 0 { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } let me = self.project(); - let max = std::cmp::min(buf.len() as u64, *me.limit_) as usize; - let n = ready!(me.inner.poll_read(cx, &mut buf[..max]))?; + let mut b = buf.take(*me.limit_ as usize); + ready!(me.inner.poll_read(cx, &mut b))?; + let n = b.filled().len(); + + // We need to update the original ReadBuf + unsafe { + buf.assume_init(n); + } + buf.advance(n); *me.limit_ -= n as u64; - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } diff --git a/src/io/util/write.rs b/src/io/util/write.rs index 433a421..92169eb 100644 --- a/src/io/util/write.rs +++ b/src/io/util/write.rs @@ -1,17 +1,22 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future to write some of the buffer to an `AsyncWrite`. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Write<'a, W: ?Sized> { writer: &'a mut W, buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -21,7 +26,11 @@ pub(crate) fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> Write<'a, W> where W: AsyncWrite + Unpin + ?Sized, { - Write { writer, buf } + Write { + writer, + buf, + _pin: PhantomPinned, + } } impl<W> Future for Write<'_, W> @@ -30,8 +39,8 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); Pin::new(&mut *me.writer).poll_write(cx, me.buf) } } diff --git a/src/io/util/write_all.rs b/src/io/util/write_all.rs index 898006c..e59d41e 100644 --- a/src/io/util/write_all.rs +++ b/src/io/util/write_all.rs @@ -1,17 +1,22 @@ use crate::io::AsyncWrite; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WriteAll<'a, W: ?Sized> { writer: &'a mut W, buf: &'a [u8], + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } @@ -19,7 +24,11 @@ pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a, where W: AsyncWrite + Unpin + ?Sized, { - WriteAll { writer, buf } + WriteAll { + writer, + buf, + _pin: PhantomPinned, + } } impl<W> Future for WriteAll<'_, W> @@ -28,13 +37,13 @@ where { type Output = io::Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { - let me = &mut *self; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = self.project(); while !me.buf.is_empty() { - let n = ready!(Pin::new(&mut me.writer).poll_write(cx, me.buf))?; + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?; { - let (_, rest) = mem::replace(&mut me.buf, &[]).split_at(n); - me.buf = rest; + let (_, rest) = mem::replace(&mut *me.buf, &[]).split_at(n); + *me.buf = rest; } if n == 0 { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); @@ -44,14 +53,3 @@ where Poll::Ready(Ok(())) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn assert_unpin() { - use std::marker::PhantomPinned; - crate::is_unpin::<WriteAll<'_, PhantomPinned>>(); - } -} diff --git a/src/io/util/write_buf.rs b/src/io/util/write_buf.rs index cedfde6..1310e5c 100644 --- a/src/io/util/write_buf.rs +++ b/src/io/util/write_buf.rs @@ -1,18 +1,22 @@ use crate::io::AsyncWrite; use bytes::Buf; +use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_io_util! { +pin_project! { /// A future to write some of the buffer to an `AsyncWrite`. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WriteBuf<'a, W, B> { writer: &'a mut W, buf: &'a mut B, + #[pin] + _pin: PhantomPinned, } } @@ -23,7 +27,11 @@ where W: AsyncWrite + Unpin, B: Buf, { - WriteBuf { writer, buf } + WriteBuf { + writer, + buf, + _pin: PhantomPinned, + } } impl<W, B> Future for WriteBuf<'_, W, B> @@ -33,8 +41,15 @@ where { type Output = io::Result<usize>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let me = &mut *self; - Pin::new(&mut *me.writer).poll_write_buf(cx, me.buf) + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { + let me = self.project(); + + if !me.buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.bytes()))?; + me.buf.advance(n); + Poll::Ready(Ok(n)) } } diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs index ee992de..13bc191 100644 --- a/src/io/util/write_int.rs +++ b/src/io/util/write_int.rs @@ -4,6 +4,7 @@ use bytes::BufMut; use pin_project_lite::pin_project; use std::future::Future; use std::io; +use std::marker::PhantomPinned; use std::mem::size_of; use std::pin::Pin; use std::task::{Context, Poll}; @@ -15,20 +16,25 @@ macro_rules! writer { ($name:ident, $ty:ty, $writer:ident, $bytes:expr) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<W> { #[pin] dst: W, buf: [u8; $bytes], written: u8, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<W> $name<W> { pub(crate) fn new(w: W, value: $ty) -> Self { - let mut writer = $name { + let mut writer = Self { buf: [0; $bytes], written: 0, dst: w, + _pin: PhantomPinned, }; BufMut::$writer(&mut &mut writer.buf[..], value); writer @@ -72,16 +78,24 @@ macro_rules! writer8 { ($name:ident, $ty:ty) => { pin_project! { #[doc(hidden)] + #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct $name<W> { #[pin] dst: W, byte: $ty, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, } } impl<W> $name<W> { pub(crate) fn new(dst: W, byte: $ty) -> Self { - Self { dst, byte } + Self { + dst, + byte, + _pin: PhantomPinned, + } } } |