diff options
Diffstat (limited to 'doh/connection/driver.rs')
-rw-r--r-- | doh/connection/driver.rs | 369 |
1 files changed, 369 insertions, 0 deletions
diff --git a/doh/connection/driver.rs b/doh/connection/driver.rs new file mode 100644 index 00000000..4fd1e266 --- /dev/null +++ b/doh/connection/driver.rs @@ -0,0 +1,369 @@ +/* +* Copyright (C) 2021 The Android Open Source Project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +//! Defines a backing task to keep a HTTP/3 connection running + +use crate::boot_time; +use crate::boot_time::BootTime; +use log::warn; +use quiche::h3; +use std::collections::HashMap; +use std::default::Default; +use std::future; +use std::io; +use std::pin::Pin; +use thiserror::Error; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::sync::{mpsc, oneshot}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("network IO error: {0}")] + Network(#[from] io::Error), + #[error("QUIC error: {0}")] + Quic(#[from] quiche::Error), + #[error("HTTP/3 error: {0}")] + H3(#[from] h3::Error), + #[error("Response delivery error: {0}")] + StreamSend(#[from] mpsc::error::SendError<Stream>), + #[error("Connection closed")] + Closed, +} + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(Debug)] +/// HTTP/3 Request to be sent on the connection +pub struct Request { + /// Request headers + pub headers: Vec<h3::Header>, + /// Expiry time for the request, relative to `CLOCK_BOOTTIME` + pub expiry: Option<BootTime>, + /// Channel to send the response to + pub response_tx: oneshot::Sender<Stream>, +} + +#[derive(Debug)] +/// HTTP/3 Response +pub struct Stream { + /// Response headers + pub headers: Vec<h3::Header>, + /// Response body + pub data: Vec<u8>, + /// Error code if stream was reset + pub error: Option<u64>, +} + +impl Stream { + fn new(headers: Vec<h3::Header>) -> Self { + Self { headers, data: Vec::new(), error: None } + } +} + +const MAX_UDP_PACKET_SIZE: usize = 65536; + +struct Driver { + request_rx: mpsc::Receiver<Request>, + quiche_conn: Pin<Box<quiche::Connection>>, + socket: UdpSocket, + // This buffer is large, boxing it will keep it + // off the stack and prevent it being copied during + // moves of the driver. + buffer: Box<[u8; MAX_UDP_PACKET_SIZE]>, +} + +struct H3Driver { + driver: Driver, + // h3_conn sometimes can't "fit" a request in its available windows. + // This value holds a peeked request in that case, waiting for + // transmission to become possible. + buffered_request: Option<Request>, + // We can't check if a receiver is dead without potentially receiving a message, and if we poll + // on a dead receiver in a select! it will immediately return None. As a result, we need this + // to gate whether or not to include .recv() in our select! + closing: bool, + h3_conn: h3::Connection, + requests: HashMap<u64, Request>, + streams: HashMap<u64, Stream>, +} + +async fn optional_timeout(timeout: Option<boot_time::Duration>) { + match timeout { + Some(timeout) => boot_time::sleep(timeout).await, + None => future::pending().await, + } +} + +/// Creates a future which when polled will handle events related to a HTTP/3 connection. +/// The returned error code will explain why the connection terminated. +pub async fn drive( + request_rx: mpsc::Receiver<Request>, + quiche_conn: Pin<Box<quiche::Connection>>, + socket: UdpSocket, +) -> Result<()> { + Driver::new(request_rx, quiche_conn, socket).drive().await +} + +impl Driver { + fn new( + request_rx: mpsc::Receiver<Request>, + quiche_conn: Pin<Box<quiche::Connection>>, + socket: UdpSocket, + ) -> Self { + Self { request_rx, quiche_conn, socket, buffer: Box::new([0; MAX_UDP_PACKET_SIZE]) } + } + + async fn drive(mut self) -> Result<()> { + // Prime connection + self.flush_tx().await?; + loop { + self = self.drive_once().await? + } + } + + fn handle_closed(&self) -> Result<()> { + if self.quiche_conn.is_closed() { + Err(Error::Closed) + } else { + Ok(()) + } + } + + async fn drive_once(mut self) -> Result<Self> { + let timer = optional_timeout(self.quiche_conn.timeout()); + select! { + // If a quiche timer would fire, call their callback + _ = timer => self.quiche_conn.on_timeout(), + // If we got packets from our peer, pass them to quiche + Ok((size, from)) = self.socket.recv_from(self.buffer.as_mut()) => { + self.quiche_conn.recv(&mut self.buffer[..size], quiche::RecvInfo { from })?; + } + }; + // Any of the actions in the select could require us to send packets to the peer + self.flush_tx().await?; + + // If the QUIC connection is live, but the HTTP/3 is not, try to bring it up + if self.quiche_conn.is_established() { + let h3_config = h3::Config::new()?; + let h3_conn = h3::Connection::with_transport(&mut self.quiche_conn, &h3_config)?; + return H3Driver::new(self, h3_conn).drive().await; + } + + // If the connection has closed, tear down + self.handle_closed()?; + + Ok(self) + } + + async fn flush_tx(&mut self) -> Result<()> { + let send_buf = self.buffer.as_mut(); + loop { + match self.quiche_conn.send(send_buf) { + Err(quiche::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + Ok((valid_len, send_info)) => { + self.socket.send_to(&send_buf[..valid_len], send_info.to).await?; + } + } + } + } +} + +impl H3Driver { + fn new(driver: Driver, h3_conn: h3::Connection) -> Self { + Self { + driver, + h3_conn, + closing: false, + requests: HashMap::new(), + streams: HashMap::new(), + buffered_request: None, + } + } + + async fn drive(mut self) -> Result<Driver> { + loop { + self.drive_once().await?; + } + } + + async fn drive_once(&mut self) -> Result<()> { + // We can't call self.driver.drive_once at the same time as + // self.driver.request_rx.recv() due to ownership + let timer = optional_timeout(self.driver.quiche_conn.timeout()); + // If we've buffered a request (due to the connection being full) + // try to resend that first + if let Some(request) = self.buffered_request.take() { + self.handle_request(request)?; + } + select! { + // Only attempt to enqueue new requests if we have no buffered request and aren't + // closing + msg = self.driver.request_rx.recv(), if !self.closing && self.buffered_request.is_none() => match msg { + Some(request) => self.handle_request(request)?, + None => self.shutdown(true, b"DONE").await?, + }, + // If a quiche timer would fire, call their callback + _ = timer => self.driver.quiche_conn.on_timeout(), + // If we got packets from our peer, pass them to quiche + Ok((size, from)) = self.driver.socket.recv_from(self.driver.buffer.as_mut()) => { + self.driver.quiche_conn.recv(&mut self.driver.buffer[..size], quiche::RecvInfo { from })?; + } + }; + + // Any of the actions in the select could require us to send packets to the peer + self.driver.flush_tx().await?; + + // Process any incoming HTTP/3 events + self.flush_h3().await?; + + // If the connection has closed, tear down + self.driver.handle_closed() + } + + fn handle_request(&mut self, request: Request) -> Result<()> { + // If the request has already timed out, don't issue it to the server. + if let Some(expiry) = request.expiry { + if BootTime::now() > expiry { + return Ok(()); + } + } + let stream_id = + // If h3_conn says the stream is blocked, this error is recoverable just by trying + // again once the stream has made progress. Buffer the request for a later retry. + match self.h3_conn.send_request(&mut self.driver.quiche_conn, &request.headers, true) { + Err(h3::Error::StreamBlocked) | Err(h3::Error::TransportError(quiche::Error::StreamLimit)) => { + // We only call handle_request on a value that has just come out of + // buffered_request, or when buffered_request is empty. This assert just + // validates that we don't break that assumption later, as it could result in + // requests being dropped on the floor under high load. + assert!(self.buffered_request.is_none()); + self.buffered_request = Some(request); + return Ok(()) + } + result => result?, + }; + self.requests.insert(stream_id, request); + Ok(()) + } + + async fn recv_body(&mut self, stream_id: u64) -> Result<()> { + const STREAM_READ_CHUNK: usize = 4096; + if let Some(stream) = self.streams.get_mut(&stream_id) { + loop { + let base_len = stream.data.len(); + stream.data.resize(base_len + STREAM_READ_CHUNK, 0); + match self.h3_conn.recv_body( + &mut self.driver.quiche_conn, + stream_id, + &mut stream.data[base_len..], + ) { + Err(h3::Error::Done) => { + stream.data.truncate(base_len); + return Ok(()); + } + Err(e) => { + stream.data.truncate(base_len); + return Err(e.into()); + } + Ok(recvd) => stream.data.truncate(base_len + recvd), + } + } + } else { + warn!("Received body for untracked stream ID {}", stream_id); + } + Ok(()) + } + + fn discard_datagram(&mut self, _flow_id: u64) -> Result<()> { + loop { + match self.h3_conn.recv_dgram(&mut self.driver.quiche_conn, self.driver.buffer.as_mut()) + { + Err(h3::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + _ => (), + } + } + } + + async fn flush_h3(&mut self) -> Result<()> { + loop { + match self.h3_conn.poll(&mut self.driver.quiche_conn) { + Err(h3::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + Ok((stream_id, event)) => self.process_h3_event(stream_id, event).await?, + } + } + } + + async fn process_h3_event(&mut self, stream_id: u64, event: h3::Event) -> Result<()> { + if !self.requests.contains_key(&stream_id) { + warn!("Received event {:?} for stream_id {} without a request.", event, stream_id); + } + match event { + h3::Event::Headers { list, has_body } => { + let stream = Stream::new(list); + if self.streams.insert(stream_id, stream).is_some() { + warn!("Re-using stream ID {} before it was completed.", stream_id) + } + if !has_body { + self.respond(stream_id); + } + } + h3::Event::Data => { + self.recv_body(stream_id).await?; + } + h3::Event::Finished => self.respond(stream_id), + // This clause is for quiche 0.10.x, we're still on 0.9.x + //h3::Event::Reset(e) => { + // self.streams.get_mut(&stream_id).map(|stream| stream.error = Some(e)); + // self.respond(stream_id); + //} + h3::Event::Datagram => { + warn!("Unexpected Datagram received"); + // We don't care if something went wrong with the datagram, we didn't + // want it anyways. + let _ = self.discard_datagram(stream_id); + } + h3::Event::GoAway => self.shutdown(false, b"SERVER GOAWAY").await?, + } + Ok(()) + } + + async fn shutdown(&mut self, send_goaway: bool, msg: &[u8]) -> Result<()> { + self.driver.request_rx.close(); + while self.driver.request_rx.recv().await.is_some() {} + self.closing = true; + if send_goaway { + self.h3_conn.send_goaway(&mut self.driver.quiche_conn, 0)?; + } + if self.driver.quiche_conn.close(true, 0, msg).is_err() { + warn!("Trying to close already closed QUIC connection"); + } + Ok(()) + } + + fn respond(&mut self, stream_id: u64) { + match (self.streams.remove(&stream_id), self.requests.remove(&stream_id)) { + (Some(stream), Some(request)) => { + // We don't care about the error, because it means the requestor has left. + let _ = request.response_tx.send(stream); + } + (None, _) => warn!("Tried to deliver untracked stream {}", stream_id), + (_, None) => warn!("Tried to deliver stream {} to untracked requestor", stream_id), + } + } +} |