diff options
Diffstat (limited to 'src/io/util')
-rw-r--r-- | src/io/util/async_seek_ext.rs | 2 | ||||
-rw-r--r-- | src/io/util/async_write_ext.rs | 10 | ||||
-rw-r--r-- | src/io/util/buf_reader.rs | 3 | ||||
-rw-r--r-- | src/io/util/copy.rs | 62 | ||||
-rw-r--r-- | src/io/util/empty.rs | 20 | ||||
-rw-r--r-- | src/io/util/fill_buf.rs | 6 | ||||
-rw-r--r-- | src/io/util/mem.rs | 64 | ||||
-rw-r--r-- | src/io/util/read_exact.rs | 4 | ||||
-rw-r--r-- | src/io/util/take.rs | 4 | ||||
-rw-r--r-- | src/io/util/vec_with_initialized.rs | 25 | ||||
-rw-r--r-- | src/io/util/write_all.rs | 2 |
11 files changed, 157 insertions, 45 deletions
diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs index 46b3e6c..aadf3a7 100644 --- a/src/io/util/async_seek_ext.rs +++ b/src/io/util/async_seek_ext.rs @@ -69,7 +69,7 @@ cfg_io_util! { /// Creates a future which will rewind to the beginning of the stream. /// - /// This is convenience method, equivalent to to `self.seek(SeekFrom::Start(0))`. + /// This is convenience method, equivalent to `self.seek(SeekFrom::Start(0))`. fn rewind(&mut self) -> Seek<'_, Self> where Self: Unpin, diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index 93a3183..dfdde82 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -406,7 +406,7 @@ cfg_io_util! { /// ``` fn write_u8(&mut self, n: u8) -> WriteU8; - /// Writes an unsigned 8-bit integer to the underlying writer. + /// Writes a signed 8-bit integer to the underlying writer. /// /// Equivalent to: /// @@ -425,7 +425,7 @@ cfg_io_util! { /// /// # Examples /// - /// Write unsigned 8 bit integers to a `AsyncWrite`: + /// Write signed 8 bit integers to a `AsyncWrite`: /// /// ```rust /// use tokio::io::{self, AsyncWriteExt}; @@ -434,10 +434,10 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_u8(2).await?; - /// writer.write_u8(5).await?; + /// writer.write_i8(-2).await?; + /// writer.write_i8(126).await?; /// - /// assert_eq!(writer, b"\x02\x05"); + /// assert_eq!(writer, b"\xFE\x7E"); /// Ok(()) /// } /// ``` diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index 7df610b..60879c0 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -204,7 +204,6 @@ impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> { self.as_mut() .get_pin_mut() .start_seek(SeekFrom::Current(offset))?; - self.as_mut().get_pin_mut().poll_complete(cx)? } else { // seek backwards by our remainder, and then by the offset self.as_mut() @@ -221,8 +220,8 @@ impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> { self.as_mut() .get_pin_mut() .start_seek(SeekFrom::Current(n))?; - self.as_mut().get_pin_mut().poll_complete(cx)? } + self.as_mut().get_pin_mut().poll_complete(cx)? } SeekState::PendingOverflowed(n) => { if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs index d0ab7cb..47dad89 100644 --- a/src/io/util/copy.rs +++ b/src/io/util/copy.rs @@ -27,6 +27,51 @@ impl CopyBuffer { } } + fn poll_fill_buf<R>( + &mut self, + cx: &mut Context<'_>, + reader: Pin<&mut R>, + ) -> Poll<io::Result<()>> + where + R: AsyncRead + ?Sized, + { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + buf.set_filled(me.cap); + + let res = reader.poll_read(cx, &mut buf); + if let Poll::Ready(Ok(_)) = res { + let filled_len = buf.filled().len(); + me.read_done = me.cap == filled_len; + me.cap = filled_len; + } + res + } + + fn poll_write_buf<R, W>( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll<io::Result<usize>> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + let me = &mut *self; + match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { + Poll::Pending => { + // Top up the buffer towards full if we can read a bit more + // data - this should improve the chances of a large write + if !me.read_done && me.cap < me.buf.len() { + ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + } + Poll::Pending + } + res => res, + } + } + pub(super) fn poll_copy<R, W>( &mut self, cx: &mut Context<'_>, @@ -41,10 +86,10 @@ impl CopyBuffer { // If our buffer is empty, then we need to read some data to // continue. if self.pos == self.cap && !self.read_done { - let me = &mut *self; - let mut buf = ReadBuf::new(&mut me.buf); + self.pos = 0; + self.cap = 0; - match reader.as_mut().poll_read(cx, &mut buf) { + match self.poll_fill_buf(cx, reader.as_mut()) { Poll::Ready(Ok(_)) => (), Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => { @@ -58,20 +103,11 @@ impl CopyBuffer { return Poll::Pending; } } - - let n = buf.filled().len(); - if n == 0 { - self.read_done = true; - } else { - self.pos = 0; - self.cap = n; - } } // If our buffer has some data, let's write it out! while self.pos < self.cap { - let me = &mut *self; - let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; + let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; if i == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, diff --git a/src/io/util/empty.rs b/src/io/util/empty.rs index f964d18..9e648f8 100644 --- a/src/io/util/empty.rs +++ b/src/io/util/empty.rs @@ -50,16 +50,18 @@ impl AsyncRead for Empty { #[inline] fn poll_read( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, _: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { + ready!(poll_proceed_and_make_progress(cx)); Poll::Ready(Ok(())) } } impl AsyncBufRead for Empty { #[inline] - fn poll_fill_buf(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + ready!(poll_proceed_and_make_progress(cx)); Poll::Ready(Ok(&[])) } @@ -73,6 +75,20 @@ impl fmt::Debug for Empty { } } +cfg_coop! { + fn poll_proceed_and_make_progress(cx: &mut Context<'_>) -> Poll<()> { + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + } +} + +cfg_not_coop! { + fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/io/util/fill_buf.rs b/src/io/util/fill_buf.rs index 3655c01..bb07c76 100644 --- a/src/io/util/fill_buf.rs +++ b/src/io/util/fill_buf.rs @@ -40,6 +40,12 @@ impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for FillBuf<'a, R> { // Safety: This is necessary only due to a limitation in the // borrow checker. Once Rust starts using the polonius borrow // checker, this can be simplified. + // + // The safety of this transmute relies on the fact that the + // value of `reader` is `None` when we return in this branch. + // Otherwise the caller could poll us again after + // completion, and access the mutable reference while the + // returned immutable reference still exists. let slice = std::mem::transmute::<&[u8], &'a [u8]>(slice); Poll::Ready(Ok(slice)) }, diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs index 4eefe7b..31884b3 100644 --- a/src/io/util/mem.rs +++ b/src/io/util/mem.rs @@ -177,10 +177,8 @@ impl Pipe { waker.wake(); } } -} -impl AsyncRead for Pipe { - fn poll_read( + fn poll_read_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, @@ -204,10 +202,8 @@ impl AsyncRead for Pipe { Poll::Pending } } -} -impl AsyncWrite for Pipe { - fn poll_write( + fn poll_write_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], @@ -228,6 +224,62 @@ impl AsyncWrite for Pipe { } Poll::Ready(Ok(len)) } +} + +impl AsyncRead for Pipe { + cfg_coop! { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + + let ret = self.poll_read_internal(cx, buf); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + self.poll_read_internal(cx, buf) + } + } +} + +impl AsyncWrite for Pipe { + cfg_coop! { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + + let ret = self.poll_write_internal(cx, buf); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + self.poll_write_internal(cx, buf) + } + } fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> { Poll::Ready(Ok(())) diff --git a/src/io/util/read_exact.rs b/src/io/util/read_exact.rs index 1e8150e..dbdd58b 100644 --- a/src/io/util/read_exact.rs +++ b/src/io/util/read_exact.rs @@ -51,13 +51,13 @@ where type Output = io::Result<usize>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { - let mut me = self.project(); + let me = self.project(); loop { // if our buffer is empty, then we need to read some data to continue. let rem = me.buf.remaining(); if rem != 0 { - ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut me.buf))?; + ready!(Pin::new(&mut *me.reader).poll_read(cx, me.buf))?; if me.buf.remaining() == rem { return Err(eof()).into(); } diff --git a/src/io/util/take.rs b/src/io/util/take.rs index b5e90c9..df2f61b 100644 --- a/src/io/util/take.rs +++ b/src/io/util/take.rs @@ -86,7 +86,11 @@ impl<R: AsyncRead> AsyncRead for Take<R> { let me = self.project(); let mut b = buf.take(*me.limit_ as usize); + + let buf_ptr = b.filled().as_ptr(); ready!(me.inner.poll_read(cx, &mut b))?; + assert_eq!(b.filled().as_ptr(), buf_ptr); + let n = b.filled().len(); // We need to update the original ReadBuf diff --git a/src/io/util/vec_with_initialized.rs b/src/io/util/vec_with_initialized.rs index 208cc93..a9b94e3 100644 --- a/src/io/util/vec_with_initialized.rs +++ b/src/io/util/vec_with_initialized.rs @@ -1,19 +1,18 @@ use crate::io::ReadBuf; use std::mem::MaybeUninit; -mod private { - pub trait Sealed {} - - impl Sealed for Vec<u8> {} - impl Sealed for &mut Vec<u8> {} -} +/// Something that looks like a `Vec<u8>`. +/// +/// # Safety +/// +/// The implementor must guarantee that the vector returned by the +/// `as_mut` and `as_mut` methods do not change from one call to +/// another. +pub(crate) unsafe trait VecU8: AsRef<Vec<u8>> + AsMut<Vec<u8>> {} -/// A sealed trait that constrains the generic type parameter in `VecWithInitialized<V>`. That struct's safety relies -/// on certain invariants upheld by `Vec<u8>`. -pub(crate) trait VecU8: AsMut<Vec<u8>> + private::Sealed {} +unsafe impl VecU8 for Vec<u8> {} +unsafe impl VecU8 for &mut Vec<u8> {} -impl VecU8 for Vec<u8> {} -impl VecU8 for &mut Vec<u8> {} /// This struct wraps a `Vec<u8>` or `&mut Vec<u8>`, combining it with a /// `num_initialized`, which keeps track of the number of initialized bytes /// in the unused capacity. @@ -64,8 +63,8 @@ where } #[cfg(feature = "io-util")] - pub(crate) fn is_empty(&mut self) -> bool { - self.vec.as_mut().is_empty() + pub(crate) fn is_empty(&self) -> bool { + self.vec.as_ref().is_empty() } pub(crate) fn get_read_buf<'a>(&'a mut self) -> ReadBuf<'a> { diff --git a/src/io/util/write_all.rs b/src/io/util/write_all.rs index e59d41e..abd3e39 100644 --- a/src/io/util/write_all.rs +++ b/src/io/util/write_all.rs @@ -42,7 +42,7 @@ where while !me.buf.is_empty() { let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf))?; { - let (_, rest) = mem::replace(&mut *me.buf, &[]).split_at(n); + let (_, rest) = mem::take(&mut *me.buf).split_at(n); *me.buf = rest; } if n == 0 { |