aboutsummaryrefslogtreecommitdiff
path: root/src/codec/framed_impl.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/codec/framed_impl.rs')
-rw-r--r--src/codec/framed_impl.rs312
1 files changed, 312 insertions, 0 deletions
diff --git a/src/codec/framed_impl.rs b/src/codec/framed_impl.rs
new file mode 100644
index 0000000..8f3fa49
--- /dev/null
+++ b/src/codec/framed_impl.rs
@@ -0,0 +1,312 @@
+use crate::codec::decoder::Decoder;
+use crate::codec::encoder::Encoder;
+
+use futures_core::Stream;
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use bytes::BytesMut;
+use futures_core::ready;
+use futures_sink::Sink;
+use pin_project_lite::pin_project;
+use std::borrow::{Borrow, BorrowMut};
+use std::io;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tracing::trace;
+
+pin_project! {
+ #[derive(Debug)]
+ pub(crate) struct FramedImpl<T, U, State> {
+ #[pin]
+ pub(crate) inner: T,
+ pub(crate) state: State,
+ pub(crate) codec: U,
+ }
+}
+
+const INITIAL_CAPACITY: usize = 8 * 1024;
+
+#[derive(Debug)]
+pub(crate) struct ReadFrame {
+ pub(crate) eof: bool,
+ pub(crate) is_readable: bool,
+ pub(crate) buffer: BytesMut,
+ pub(crate) has_errored: bool,
+}
+
+pub(crate) struct WriteFrame {
+ pub(crate) buffer: BytesMut,
+ pub(crate) backpressure_boundary: usize,
+}
+
+#[derive(Default)]
+pub(crate) struct RWFrames {
+ pub(crate) read: ReadFrame,
+ pub(crate) write: WriteFrame,
+}
+
+impl Default for ReadFrame {
+ fn default() -> Self {
+ Self {
+ eof: false,
+ is_readable: false,
+ buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+ has_errored: false,
+ }
+ }
+}
+
+impl Default for WriteFrame {
+ fn default() -> Self {
+ Self {
+ buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
+ backpressure_boundary: INITIAL_CAPACITY,
+ }
+ }
+}
+
+impl From<BytesMut> for ReadFrame {
+ fn from(mut buffer: BytesMut) -> Self {
+ let size = buffer.capacity();
+ if size < INITIAL_CAPACITY {
+ buffer.reserve(INITIAL_CAPACITY - size);
+ }
+
+ Self {
+ buffer,
+ is_readable: size > 0,
+ eof: false,
+ has_errored: false,
+ }
+ }
+}
+
+impl From<BytesMut> for WriteFrame {
+ fn from(mut buffer: BytesMut) -> Self {
+ let size = buffer.capacity();
+ if size < INITIAL_CAPACITY {
+ buffer.reserve(INITIAL_CAPACITY - size);
+ }
+
+ Self {
+ buffer,
+ backpressure_boundary: INITIAL_CAPACITY,
+ }
+ }
+}
+
+impl Borrow<ReadFrame> for RWFrames {
+ fn borrow(&self) -> &ReadFrame {
+ &self.read
+ }
+}
+impl BorrowMut<ReadFrame> for RWFrames {
+ fn borrow_mut(&mut self) -> &mut ReadFrame {
+ &mut self.read
+ }
+}
+impl Borrow<WriteFrame> for RWFrames {
+ fn borrow(&self) -> &WriteFrame {
+ &self.write
+ }
+}
+impl BorrowMut<WriteFrame> for RWFrames {
+ fn borrow_mut(&mut self) -> &mut WriteFrame {
+ &mut self.write
+ }
+}
+impl<T, U, R> Stream for FramedImpl<T, U, R>
+where
+ T: AsyncRead,
+ U: Decoder,
+ R: BorrowMut<ReadFrame>,
+{
+ type Item = Result<U::Item, U::Error>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ use crate::util::poll_read_buf;
+
+ let mut pinned = self.project();
+ let state: &mut ReadFrame = pinned.state.borrow_mut();
+ // The following loops implements a state machine with each state corresponding
+ // to a combination of the `is_readable` and `eof` flags. States persist across
+ // loop entries and most state transitions occur with a return.
+ //
+ // The initial state is `reading`.
+ //
+ // | state | eof | is_readable | has_errored |
+ // |---------|-------|-------------|-------------|
+ // | reading | false | false | false |
+ // | framing | false | true | false |
+ // | pausing | true | true | false |
+ // | paused | true | false | false |
+ // | errored | <any> | <any> | true |
+ // `decode_eof` returns Err
+ // ┌────────────────────────────────────────────────────────┐
+ // `decode_eof` returns │ │
+ // `Ok(Some)` │ │
+ // ┌─────┐ │ `decode_eof` returns After returning │
+ // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐
+ // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │
+ // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘
+ // Pending read │ │ │ │ │ │
+ // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │
+ // │ │ │ ┌──────┐ │ Pending │ │
+ // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │
+ // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │
+ // └──┬─▲────┘ └─────┬──┬┘ │ │
+ // │ │ │ │ `decode` returns Err │ │
+ // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │
+ // │ read returns Err │
+ // └────────────────────────────────────────────────────────────────────────────────────────────┘
+ loop {
+ // Return `None` if we have encountered an error from the underlying decoder
+ // See: https://github.com/tokio-rs/tokio/issues/3976
+ if state.has_errored {
+ // preparing has_errored -> paused
+ trace!("Returning None and setting paused");
+ state.is_readable = false;
+ state.has_errored = false;
+ return Poll::Ready(None);
+ }
+
+ // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
+ // i.e. it _might_ contain data consumable as a frame or closing frame.
+ // Both signal that there is no such data by returning `None`.
+ //
+ // If `decode` couldn't read a frame and the upstream source has returned eof,
+ // `decode_eof` will attempt to decode the remaining bytes as closing frames.
+ //
+ // If the underlying AsyncRead is resumable, we may continue after an EOF,
+ // but must finish emitting all of it's associated `decode_eof` frames.
+ // Furthermore, we don't want to emit any `decode_eof` frames on retried
+ // reads after an EOF unless we've actually read more data.
+ if state.is_readable {
+ // pausing or framing
+ if state.eof {
+ // pausing
+ let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ err
+ })?;
+ if frame.is_none() {
+ state.is_readable = false; // prepare pausing -> paused
+ }
+ // implicit pausing -> pausing or pausing -> paused
+ return Poll::Ready(frame.map(Ok));
+ }
+
+ // framing
+ trace!("attempting to decode a frame");
+
+ if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ op
+ })? {
+ trace!("frame decoded from buffer");
+ // implicit framing -> framing
+ return Poll::Ready(Some(Ok(frame)));
+ }
+
+ // framing -> reading
+ state.is_readable = false;
+ }
+ // reading or paused
+ // If we can't build a frame yet, try to read more data and try again.
+ // Make sure we've got room for at least one byte to read to ensure
+ // that we don't get a spurious 0 that looks like EOF.
+ state.buffer.reserve(1);
+ let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
+ |err| {
+ trace!("Got an error, going to errored state");
+ state.has_errored = true;
+ err
+ },
+ )? {
+ Poll::Ready(ct) => ct,
+ // implicit reading -> reading or implicit paused -> paused
+ Poll::Pending => return Poll::Pending,
+ };
+ if bytect == 0 {
+ if state.eof {
+ // We're already at an EOF, and since we've reached this path
+ // we're also not readable. This implies that we've already finished
+ // our `decode_eof` handling, so we can simply return `None`.
+ // implicit paused -> paused
+ return Poll::Ready(None);
+ }
+ // prepare reading -> paused
+ state.eof = true;
+ } else {
+ // prepare paused -> framing or noop reading -> framing
+ state.eof = false;
+ }
+
+ // paused -> framing or reading -> framing or reading -> pausing
+ state.is_readable = true;
+ }
+ }
+}
+
+impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
+where
+ T: AsyncWrite,
+ U: Encoder<I>,
+ U::Error: From<io::Error>,
+ W: BorrowMut<WriteFrame>,
+{
+ type Error = U::Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
+ self.as_mut().poll_flush(cx)
+ } else {
+ Poll::Ready(Ok(()))
+ }
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
+ let pinned = self.project();
+ pinned
+ .codec
+ .encode(item, &mut pinned.state.borrow_mut().buffer)?;
+ Ok(())
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ use crate::util::poll_write_buf;
+ trace!("flushing framed transport");
+ let mut pinned = self.project();
+
+ while !pinned.state.borrow_mut().buffer.is_empty() {
+ let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
+ trace!(remaining = buffer.len(), "writing;");
+
+ let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
+
+ if n == 0 {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::WriteZero,
+ "failed to \
+ write frame to transport",
+ )
+ .into()));
+ }
+ }
+
+ // Try flushing the underlying IO
+ ready!(pinned.inner.poll_flush(cx))?;
+
+ trace!("framed transport flushed");
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_flush(cx))?;
+ ready!(self.project().inner.poll_shutdown(cx))?;
+
+ Poll::Ready(Ok(()))
+ }
+}